Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
64163ab
Add input artifact
rui5i Aug 12, 2022
6efaa1d
Add input artifact
rui5i Aug 12, 2022
1a00e6e
Add unit tests
rui5i Aug 17, 2022
446f53e
Add the example on docstring
rui5i Aug 25, 2022
e3472ad
update the unit tests
rui5i Aug 25, 2022
ce24539
fix the key
rui5i Aug 25, 2022
c25744c
update unit test
rui5i Aug 25, 2022
d793e62
Merge branch 'main' into input_artifact
jaycee-li Aug 25, 2022
c31dbda
Merge branch 'main' into input_artifact
jaycee-li Aug 26, 2022
0b0ac36
Merge branch 'main' into input_artifact
rui5i Aug 30, 2022
dee3b7e
update the docstring to be more accurate
rui5i Aug 30, 2022
74cea46
update the docstring
rui5i Aug 30, 2022
2a8fa09
Merge branch 'main' into input_artifact
jaycee-li Aug 31, 2022
a4981e6
Merge branch 'main' into input_artifact
rui5i Aug 31, 2022
9c24502
Merge branch 'main' into input_artifact
jaycee-li Sep 1, 2022
27576c9
fix unit test
rui5i Sep 1, 2022
4874238
Merge branch 'main' into input_artifact
rui5i Sep 1, 2022
c775bed
fix unit test
rui5i Sep 2, 2022
f565492
Merge branch 'main' into input_artifact
rui5i Sep 2, 2022
21867ad
fix lint
rui5i Sep 2, 2022
2212959
Merge branch 'main' into input_artifact
rui5i Sep 2, 2022
b7442f9
Merge branch 'main' into input_artifact
jaycee-li Sep 3, 2022
be1079c
Merge branch 'main' into input_artifact
rui5i Sep 6, 2022
ab92b4c
Merge branch 'main' into input_artifact
rui5i Sep 6, 2022
227ccf4
Merge branch 'main' into input_artifact
rui5i Sep 6, 2022
a9a3887
Merge branch 'main' into input_artifact
rui5i Sep 6, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions google/cloud/aiplatform/pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def __init__(
job_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
parameter_values: Optional[Dict[str, Any]] = None,
input_artifacts: Optional[Dict[str, str]] = None,
enable_caching: Optional[bool] = None,
encryption_spec_key_name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -139,6 +140,9 @@ def __init__(
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter names to its values that
control the pipeline run.
input_artifacts (Dict[str, str]):
Comment thread
SinaChavoshi marked this conversation as resolved.
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
enable_caching (bool):
Optional. Whether to turn on caching for the run.

Expand Down Expand Up @@ -235,6 +239,8 @@ def __init__(
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
builder.update_input_artifacts(input_artifacts)

builder.update_failure_policy(failure_policy)
runtime_config_dict = builder.build()

Expand Down Expand Up @@ -662,6 +668,7 @@ def clone(
job_id: Optional[str] = None,
pipeline_root: Optional[str] = None,
parameter_values: Optional[Dict[str, Any]] = None,
input_artifacts: Optional[Dict[str, str]] = None,
enable_caching: Optional[bool] = None,
encryption_spec_key_name: Optional[str] = None,
labels: Optional[Dict[str, str]] = None,
Expand All @@ -685,6 +692,9 @@ def clone(
Optional. The mapping from runtime parameter names to its values that
control the pipeline run. Defaults to be the same values as original
PipelineJob.
input_artifacts (Dict[str, str]):
Optional. The mapping from the runtime parameter name for this artifact to its resource id. Defaults to be the same values as original
PipelineJob. For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
enable_caching (bool):
Optional. Whether to turn on caching for the run.
If this is not set, defaults to be the same as original pipeline.
Expand Down Expand Up @@ -785,6 +795,7 @@ def clone(
)
builder.update_pipeline_root(pipeline_root)
builder.update_runtime_parameters(parameter_values)
builder.update_input_artifacts(input_artifacts)
runtime_config_dict = builder.build()
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(runtime_config_dict, runtime_config)
Expand All @@ -805,6 +816,7 @@ def from_pipeline_func(
# Parameters for the PipelineJob constructor
pipeline_func: Callable,
parameter_values: Optional[Dict[str, Any]] = None,
input_artifacts: Optional[Dict[str, str]] = None,
output_artifacts_gcs_dir: Optional[str] = None,
enable_caching: Optional[bool] = None,
context_name: Optional[str] = "pipeline",
Expand All @@ -827,6 +839,8 @@ def from_pipeline_func(
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter names to its values that
control the pipeline run.
input_artifacts (Dict[str, str]):
Optional. The mapping from the runtime parameter name for this artifact to its resource id. For example: "vertex_model":"456". Note: full resource name ("projects/123/locations/us-central1/metadataStores/default/artifacts/456") cannot be used.
output_artifacts_gcs_dir (str):
Optional. The GCS location of the pipeline outputs.
A GCS bucket for artifacts will be created if not specified.
Expand Down Expand Up @@ -907,6 +921,7 @@ def from_pipeline_func(
pipeline_job = PipelineJob(
template_path=pipeline_file,
parameter_values=parameter_values,
input_artifacts=input_artifacts,
pipeline_root=output_artifacts_gcs_dir,
enable_caching=enable_caching,
display_name=display_name,
Expand Down
19 changes: 19 additions & 0 deletions google/cloud/aiplatform/utils/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
schema_version: str,
parameter_types: Mapping[str, str],
parameter_values: Optional[Dict[str, Any]] = None,
input_artifacts: Optional[Dict[str, str]] = None,
failure_policy: Optional[pipeline_failure_policy.PipelineFailurePolicy] = None,
):
"""Creates a PipelineRuntimeConfigBuilder object.
Expand All @@ -46,6 +47,8 @@ def __init__(
Required. The mapping from pipeline parameter name to its type.
parameter_values (Dict[str, Any]):
Optional. The mapping from runtime parameter name to its value.
input_artifacts (Dict[str, str]):
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
failure_policy (pipeline_failure_policy.PipelineFailurePolicy):
Optional. Represents the failure policy of a pipeline. Currently, the
default of a pipeline is that the pipeline will continue to
Expand All @@ -59,6 +62,7 @@ def __init__(
self._schema_version = schema_version
self._parameter_types = parameter_types
self._parameter_values = copy.deepcopy(parameter_values or {})
self._input_artifacts = copy.deepcopy(input_artifacts or {})
self._failure_policy = failure_policy

@classmethod
Expand Down Expand Up @@ -129,6 +133,18 @@ def update_runtime_parameters(
parameters[k] = json.dumps(v)
self._parameter_values.update(parameters)

def update_input_artifacts(
self, input_artifacts: Optional[Mapping[str, str]]
) -> None:
"""Merges runtime input artifacts.

Args:
input_artifacts (Mapping[str, str]):
Optional. The mapping from the runtime parameter name for this artifact to its resource id.
"""
if input_artifacts:
self._input_artifacts.update(input_artifacts)

def update_failure_policy(self, failure_policy: Optional[str] = None) -> None:
"""Merges runtime failure policy.

Expand Down Expand Up @@ -172,6 +188,9 @@ def build(self) -> Dict[str, Any]:
for k, v in self._parameter_values.items()
if v is not None
},
"inputArtifacts": {
k: {"artifactId": v} for k, v in self._input_artifacts.items()
},
}

if self._failure_policy:
Expand Down
8 changes: 8 additions & 0 deletions tests/unit/aiplatform/test_pipeline_jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
"struct_param": {"key1": 12345, "key2": 67890},
}

_TEST_PIPELINE_INPUT_ARTIFACTS = {
"vertex_model": "456",
}

_TEST_PIPELINE_SPEC_LEGACY_JSON = json.dumps(
{
"pipelineInfo": {"name": "my-pipeline"},
Expand Down Expand Up @@ -469,6 +473,7 @@ def test_run_call_pipeline_service_create(
template_path=_TEST_TEMPLATE_PATH,
job_id=_TEST_PIPELINE_JOB_ID,
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
)

Expand All @@ -485,6 +490,7 @@ def test_run_call_pipeline_service_create(
expected_runtime_config_dict = {
"gcsOutputDirectory": _TEST_GCS_BUCKET_NAME,
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
"inputArtifacts": {"vertex_model": {"artifactId": "456"}},
}
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand Down Expand Up @@ -1475,6 +1481,7 @@ def test_clone_pipeline_job_with_all_args(
job_id=f"cloned-{_TEST_PIPELINE_JOB_ID}",
pipeline_root=f"cloned-{_TEST_GCS_BUCKET_NAME}",
parameter_values=_TEST_PIPELINE_PARAMETER_VALUES,
input_artifacts=_TEST_PIPELINE_INPUT_ARTIFACTS,
enable_caching=True,
credentials=_TEST_CREDENTIALS,
project=_TEST_PROJECT,
Expand All @@ -1490,6 +1497,7 @@ def test_clone_pipeline_job_with_all_args(
expected_runtime_config_dict = {
"gcsOutputDirectory": f"cloned-{_TEST_GCS_BUCKET_NAME}",
"parameterValues": _TEST_PIPELINE_PARAMETER_VALUES,
"inputArtifacts": {"vertex_model": {"artifactId": "456"}},
}
runtime_config = gca_pipeline_job.PipelineJob.RuntimeConfig()._pb
json_format.ParseDict(expected_runtime_config_dict, runtime_config)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,7 @@ class TestPipelineUtils:
"int_param": {"intValue": 42},
"float_param": {"doubleValue": 3.14},
},
"inputArtifacts": {},
},
}

Expand Down Expand Up @@ -539,6 +540,7 @@ def test_pipeline_utils_runtime_config_builder_with_merge_updates(
"list_param": {"stringValue": "[1, 2, 3]"},
"bool_param": {"stringValue": "true"},
},
"inputArtifacts": {},
"failurePolicy": failure_policy[1],
}
assert expected_runtime_config == actual_runtime_config
Expand Down