diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index 76992595c4..2a3914efea 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -110,6 +110,7 @@ def __init__( job_id: Optional[str] = None, pipeline_root: Optional[str] = None, parameter_values: Optional[Dict[str, Any]] = None, + input_artifacts: Optional[Dict[str, str]] = None, enable_caching: Optional[bool] = None, encryption_spec_key_name: Optional[str] = None, labels: Optional[Dict[str, str]] = None, @@ -139,6 +140,9 @@ def __init__( parameter_values (Dict[str, Any]): Optional. The mapping from runtime parameter names to its values that control the pipeline run. + input_artifacts (Dict[str, str]): + Optional. The mapping from the runtime parameter name for this artifact to its resource id. + For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used. enable_caching (bool): Optional. Whether to turn on caching for the run. @@ -235,6 +239,8 @@ def __init__( ) builder.update_pipeline_root(pipeline_root) builder.update_runtime_parameters(parameter_values) + builder.update_input_artifacts(input_artifacts) + builder.update_failure_policy(failure_policy) runtime_config_dict = builder.build() @@ -662,6 +668,7 @@ def clone( job_id: Optional[str] = None, pipeline_root: Optional[str] = None, parameter_values: Optional[Dict[str, Any]] = None, + input_artifacts: Optional[Dict[str, str]] = None, enable_caching: Optional[bool] = None, encryption_spec_key_name: Optional[str] = None, labels: Optional[Dict[str, str]] = None, @@ -685,6 +692,9 @@ def clone( Optional. The mapping from runtime parameter names to its values that control the pipeline run. Defaults to be the same values as original PipelineJob. + input_artifacts (Dict[str, str]): + Optional. The mapping from the runtime parameter name for this artifact to its resource id. Defaults to be the same values as original + PipelineJob. For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used. enable_caching (bool): Optional. Whether to turn on caching for the run. If this is not set, defaults to be the same as original pipeline. @@ -785,6 +795,7 @@ def clone( ) builder.update_pipeline_root(pipeline_root) builder.update_runtime_parameters(parameter_values) + builder.update_input_artifacts(input_artifacts) runtime_config_dict = builder.build() runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(runtime_config_dict, runtime_config) @@ -805,6 +816,7 @@ def from_pipeline_func( # Parameters for the PipelineJob constructor pipeline_func: Callable, parameter_values: Optional[Dict[str, Any]] = None, + input_artifacts: Optional[Dict[str, str]] = None, output_artifacts_gcs_dir: Optional[str] = None, enable_caching: Optional[bool] = None, context_name: Optional[str] = "pipeline", @@ -827,6 +839,8 @@ def from_pipeline_func( parameter_values (Dict[str, Any]): Optional. The mapping from runtime parameter names to its values that control the pipeline run. + input_artifacts (Dict[str, str]): + Optional. The mapping from the runtime parameter name for this artifact to its resource id. For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used. output_artifacts_gcs_dir (str): Optional. The GCS location of the pipeline outputs. A GCS bucket for artifacts will be created if not specified. @@ -907,6 +921,7 @@ def from_pipeline_func( pipeline_job = PipelineJob( template_path=pipeline_file, parameter_values=parameter_values, + input_artifacts=input_artifacts, pipeline_root=output_artifacts_gcs_dir, enable_caching=enable_caching, display_name=display_name, diff --git a/google/cloud/aiplatform/utils/pipeline_utils.py b/google/cloud/aiplatform/utils/pipeline_utils.py index f988cc307e..ded7040c83 100644 --- a/google/cloud/aiplatform/utils/pipeline_utils.py +++ b/google/cloud/aiplatform/utils/pipeline_utils.py @@ -33,6 +33,7 @@ def __init__( schema_version: str, parameter_types: Mapping[str, str], parameter_values: Optional[Dict[str, Any]] = None, + input_artifacts: Optional[Dict[str, str]] = None, failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None, ): """Creates a PipelineRuntimeConfigBuilder object. @@ -46,6 +47,8 @@ def __init__( Required. The mapping from pipeline parameter name to its type. parameter_values (Dict[str, Any]): Optional. The mapping from runtime parameter name to its value. + input_artifacts (Dict[str, str]): + Optional. The mapping from the runtime parameter name for this artifact to its resource id. failure_policy (pipeline_failure_policy.PipelineFailurePolicy): Optional. Represents the failure policy of a pipeline. Currently, the default of a pipeline is that the pipeline will continue to @@ -59,6 +62,7 @@ def __init__( self._schema_version = schema_version self._parameter_types = parameter_types self._parameter_values = copy.deepcopy(parameter_values or {}) + self._input_artifacts = copy.deepcopy(input_artifacts or {}) self._failure_policy = failure_policy @classmethod @@ -129,6 +133,18 @@ def update_runtime_parameters( parameters[k] = json.dumps(v) self._parameter_values.update(parameters) + def update_input_artifacts( + self, input_artifacts: Optional[Mapping[str, str]] + ) -> None: + """Merges runtime input artifacts. + + Args: + input_artifacts (Mapping[str, str]): + Optional. The mapping from the runtime parameter name for this artifact to its resource id. + """ + if input_artifacts: + self._input_artifacts.update(input_artifacts) + def update_failure_policy(self, failure_policy: Optional[str] = None) -> None: """Merges runtime failure policy. @@ -172,6 +188,9 @@ def build(self) -> Dict[str, Any]: for k, v in self._parameter_values.items() if v is not None }, + "inputArtifacts": { + k: {"artifactId": v} for k, v in self._input_artifacts.items() + }, } if self._failure_policy: diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index 9975fedd5c..72d0ebcec1 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -74,6 +74,10 @@ "struct_param": {"key1": 12345, "key2": 67890}, } +_TEST_PIPELINE_INPUT_ARTIFACTS = { + "vertex_model": "456", +} + _TEST_PIPELINE_SPEC_LEGACY_JSON = json.dumps( { "pipelineInfo": {"name": "my-pipeline"}, @@ -469,6 +473,7 @@ def test_run_call_pipeline_service_create( template_path=_TEST_TEMPLATE_PATH, job_id=_TEST_PIPELINE_JOB_ID, parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS, enable_caching=True, ) @@ -485,6 +490,7 @@ def test_run_call_pipeline_service_create( expected_runtime_config_dict = { "gcsOutputDirectory": _TEST_GCS_BUCKET_NAME, "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, + "inputArtifacts": {"vertex_model": {"artifactId": "456"}}, } runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(expected_runtime_config_dict, runtime_config) @@ -1475,6 +1481,7 @@ def test_clone_pipeline_job_with_all_args( job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}", pipeline_root=f"cloned-{_TEST_GCS_BUCKET_NAME}", parameter_values=_TEST_PIPELINE_PARAMETER_VALUES, + input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS, enable_caching=True, credentials=_TEST_CREDENTIALS, project=_TEST_PROJECT, @@ -1490,6 +1497,7 @@ def test_clone_pipeline_job_with_all_args( expected_runtime_config_dict = { "gcsOutputDirectory": f"cloned-{_TEST_GCS_BUCKET_NAME}", "parameterValues": _TEST_PIPELINE_PARAMETER_VALUES, + "inputArtifacts": {"vertex_model": {"artifactId": "456"}}, } runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb json_format.ParseDict(expected_runtime_config_dict, runtime_config) diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 2f460ce9be..081d0ce18a 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -452,6 +452,7 @@ class TestPipelineUtils: "int_param": {"intValue": 42}, "float_param": {"doubleValue": 3.14}, }, + "inputArtifacts": {}, }, } @@ -539,6 +540,7 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates( "list_param": {"stringValue": "[1, 2, 3]"}, "bool_param": {"stringValue": "true"}, }, + "inputArtifacts": {}, "failurePolicy": failure_policy[1], } assert expected_runtime_config == actual_runtime_config