diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index 55ca1c7829..d9c3994d26 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -48,6 +48,7 @@ from google.cloud.aiplatform.compat.types import encryption_spec as gca_encryption_spec from google.cloud.aiplatform.constants import base as base_constants from google.protobuf import json_format +from google.protobuf import field_mask_pb2 as field_mask # This is the default retry callback to be used with get methods. _DEFAULT_RETRY = retry.Retry() @@ -1030,6 +1031,7 @@ def _list( cls_filter: Callable[[proto.Message], bool] = lambda _: True, filter: Optional[str] = None, order_by: Optional[str] = None, + read_mask: Optional[field_mask.FieldMask] = None, project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, @@ -1052,6 +1054,14 @@ def _list( Optional. A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. Supported fields: `display_name`, `create_time`, `update_time` + read_mask (field_mask.FieldMask): + Optional. A FieldMask with a list of strings passed via `paths` + indicating which fields to return for each resource in the response. + For example, passing + field_mask.FieldMask(paths=["create_time", "update_time"]) + as `read_mask` would result in each returned VertexAiResourceNoun + in the result list only having the "create_time" and + "update_time" attributes. project (str): Optional. Project to retrieve list from. If not set, project set in aiplatform.init will be used. @@ -1067,6 +1077,7 @@ def _list( Returns: List[VertexAiResourceNoun] - A list of SDK resource objects """ + resource = cls._empty_constructor( project=project, location=location, credentials=credentials ) @@ -1083,6 +1094,10 @@ def _list( ), } + # `read_mask` is only passed from PipelineJob.list() for now + if read_mask is not None: + list_request["read_mask"] = read_mask + if filter: list_request["filter"] = filter @@ -1105,6 +1120,7 @@ def _list_with_local_order( cls_filter: Callable[[proto.Message], bool] = lambda _: True, filter: Optional[str] = None, order_by: Optional[str] = None, + read_mask: Optional[field_mask.FieldMask] = None, project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, @@ -1127,6 +1143,14 @@ def _list_with_local_order( Optional. A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. Supported fields: `display_name`, `create_time`, `update_time` + read_mask (field_mask.FieldMask): + Optional. A FieldMask with a list of strings passed via `paths` + indicating which fields to return for each resource in the response. + For example, passing + field_mask.FieldMask(paths=["create_time", "update_time"]) + as `read_mask` would result in each returned VertexAiResourceNoun + in the result list only having the "create_time" and + "update_time" attributes. project (str): Optional. Project to retrieve list from. If not set, project set in aiplatform.init will be used. @@ -1145,6 +1169,7 @@ def _list_with_local_order( cls_filter=cls_filter, filter=filter, order_by=None, # This method will handle the ordering locally + read_mask=read_mask, project=project, location=location, credentials=credentials, diff --git a/google/cloud/aiplatform/constants/pipeline.py b/google/cloud/aiplatform/constants/pipeline.py index d4ff2aa32e..12acf8a52e 100644 --- a/google/cloud/aiplatform/constants/pipeline.py +++ b/google/cloud/aiplatform/constants/pipeline.py @@ -37,3 +37,20 @@ # Pattern for an Artifact Registry URL. _VALID_AR_URL = re.compile(r"^https:\/\/([\w-]+)-kfp\.pkg\.dev\/.*") + +# Fields to include in returned PipelineJob when enable_simple_view=True in PipelineJob.list() +_READ_MASK_FIELDS = [ + "name", + "state", + "display_name", + "pipeline_spec.pipeline_info", + "create_time", + "start_time", + "end_time", + "update_time", + "labels", + "template_uri", + "template_metadata.version", + "job_detail.pipeline_run_context", + "job_detail.pipeline_context", +] diff --git a/google/cloud/aiplatform/pipeline_jobs.py b/google/cloud/aiplatform/pipeline_jobs.py index afc5dd86ad..4b9e95730c 100644 --- a/google/cloud/aiplatform/pipeline_jobs.py +++ b/google/cloud/aiplatform/pipeline_jobs.py @@ -38,6 +38,7 @@ from google.cloud.aiplatform.utils import yaml_utils from google.cloud.aiplatform.utils import pipeline_utils from google.protobuf import json_format +from google.protobuf import field_mask_pb2 as field_mask from google.cloud.aiplatform.compat.types import ( pipeline_job as gca_pipeline_job, @@ -56,6 +57,8 @@ # Pattern for an Artifact Registry URL. _VALID_AR_URL = pipeline_constants._VALID_AR_URL +_READ_MASK_FIELDS = pipeline_constants._READ_MASK_FIELDS + def _get_current_time() -> datetime.datetime: """Gets the current timestamp.""" @@ -509,6 +512,7 @@ def list( cls, filter: Optional[str] = None, order_by: Optional[str] = None, + enable_simple_view: Optional[bool] = False, project: Optional[str] = None, location: Optional[str] = None, credentials: Optional[auth_credentials.Credentials] = None, @@ -530,6 +534,17 @@ def list( Optional. A comma-separated list of fields to order by, sorted in ascending order. Use "desc" after a field name for descending. Supported fields: `display_name`, `create_time`, `update_time` + enable_simple_view (bool): + Optional. Whether to pass the `read_mask` parameter to the list call. + This will improve the performance of calling list(). However, the + returned PipelineJob list will not include all fields for each PipelineJob. + Setting this to True will exclude the following fields in your response: + `runtime_config`, `service_account`, `network`, and some subfields of + `pipeline_spec` and `job_detail`. The following fields will be included in + each PipelineJob resource in your response: `state`, `display_name`, + `pipeline_spec.pipeline_info`, `create_time`, `start_time`, `end_time`, + `update_time`, `labels`, `template_uri`, `template_metadata.version`, + `job_detail.pipeline_run_context`, `job_detail.pipeline_context`. project (str): Optional. Project to retrieve list from. If not set, project set in aiplatform.init will be used. @@ -544,9 +559,18 @@ def list( List[PipelineJob] - A list of PipelineJob resource objects """ + read_mask_fields = None + + if enable_simple_view: + read_mask_fields = field_mask.FieldMask(paths=_READ_MASK_FIELDS) + _LOGGER.warn( + "By enabling simple view, the PipelineJob resources returned from this method will not contain all fields." + ) + return cls._list_with_local_order( filter=filter, order_by=order_by, + read_mask=read_mask_fields, project=project, location=location, credentials=credentials, diff --git a/tests/system/aiplatform/test_pipeline_job.py b/tests/system/aiplatform/test_pipeline_job.py index 004ad768a3..13d8207a1d 100644 --- a/tests/system/aiplatform/test_pipeline_job.py +++ b/tests/system/aiplatform/test_pipeline_job.py @@ -20,6 +20,8 @@ from google.cloud import aiplatform from tests.system.aiplatform import e2e_base +from google.protobuf.json_format import MessageToDict + @pytest.mark.usefixtures("tear_down_resources") class TestPipelineJob(e2e_base.TestEndToEnd): @@ -59,3 +61,14 @@ def training_pipeline(number_of_epochs: int = 10): shared_state.setdefault("resources", []).append(job) job.wait() + + list_with_read_mask = aiplatform.PipelineJob.list(enable_simple_view=True) + list_without_read_mask = aiplatform.PipelineJob.list() + + # enable_simple_view=True should apply the `read_mask` filter to limit PipelineJob fields returned + assert "serviceAccount" in MessageToDict( + list_without_read_mask[0].gca_resource._pb + ) + assert "serviceAccount" not in MessageToDict( + list_with_read_mask[0].gca_resource._pb + ) diff --git a/tests/unit/aiplatform/test_pipeline_jobs.py b/tests/unit/aiplatform/test_pipeline_jobs.py index 4976fd6d4c..a608e7df07 100644 --- a/tests/unit/aiplatform/test_pipeline_jobs.py +++ b/tests/unit/aiplatform/test_pipeline_jobs.py @@ -28,6 +28,7 @@ from google.cloud import aiplatform from google.cloud.aiplatform import base from google.cloud.aiplatform import initializer +from google.cloud.aiplatform.constants import pipeline as pipeline_constants from google.cloud.aiplatform_v1 import Context as GapicContext from google.cloud.aiplatform_v1 import MetadataStore as GapicMetadataStore from google.cloud.aiplatform.metadata import constants @@ -37,6 +38,7 @@ from google.cloud.aiplatform.utils import gcs_utils from google.cloud import storage from google.protobuf import json_format +from google.protobuf import field_mask_pb2 as field_mask from google.cloud.aiplatform.compat.services import ( pipeline_service_client, @@ -62,6 +64,9 @@ _TEST_NETWORK = f"projects/{_TEST_PROJECT}/global/networks/{_TEST_PIPELINE_JOB_ID}" _TEST_PIPELINE_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/pipelineJobs/{_TEST_PIPELINE_JOB_ID}" +_TEST_PIPELINE_JOB_LIST_READ_MASK = field_mask.FieldMask( + paths=pipeline_constants._READ_MASK_FIELDS +) _TEST_PIPELINE_PARAMETER_VALUES_LEGACY = {"string_param": "hello"} _TEST_PIPELINE_PARAMETER_VALUES = { @@ -332,6 +337,17 @@ def mock_pipeline_service_list(): with mock.patch.object( pipeline_service_client.PipelineServiceClient, "list_pipeline_jobs" ) as mock_list_pipeline_jobs: + mock_list_pipeline_jobs.return_value = [ + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + make_pipeline_job( + gca_pipeline_state.PipelineState.PIPELINE_STATE_SUCCEEDED + ), + ] yield mock_list_pipeline_jobs @@ -1354,6 +1370,47 @@ def test_list_pipeline_job( request={"parent": _TEST_PARENT} ) + @pytest.mark.usefixtures( + "mock_pipeline_service_create", + "mock_pipeline_service_get", + "mock_pipeline_bucket_exists", + ) + @pytest.mark.parametrize( + "job_spec", + [ + _TEST_PIPELINE_SPEC_JSON, + _TEST_PIPELINE_SPEC_YAML, + _TEST_PIPELINE_JOB, + _TEST_PIPELINE_SPEC_LEGACY_JSON, + _TEST_PIPELINE_SPEC_LEGACY_YAML, + _TEST_PIPELINE_JOB_LEGACY, + ], + ) + def test_list_pipeline_job_with_read_mask( + self, mock_pipeline_service_list, mock_load_yaml_and_json + ): + aiplatform.init( + project=_TEST_PROJECT, + staging_bucket=_TEST_GCS_BUCKET_NAME, + credentials=_TEST_CREDENTIALS, + ) + + job = pipeline_jobs.PipelineJob( + display_name=_TEST_PIPELINE_JOB_DISPLAY_NAME, + template_path=_TEST_TEMPLATE_PATH, + job_id=_TEST_PIPELINE_JOB_ID, + ) + + job.run() + job.list(enable_simple_view=True) + + mock_pipeline_service_list.assert_called_once_with( + request={ + "parent": _TEST_PARENT, + "read_mask": _TEST_PIPELINE_JOB_LIST_READ_MASK, + }, + ) + @pytest.mark.usefixtures( "mock_pipeline_service_create", "mock_pipeline_service_get",