diff --git a/google/cloud/aiplatform/base.py b/google/cloud/aiplatform/base.py index d9c3994d26..10db4c0dd2 100644 --- a/google/cloud/aiplatform/base.py +++ b/google/cloud/aiplatform/base.py @@ -1077,6 +1077,13 @@ def _list( Returns: List[VertexAiResourceNoun] - A list of SDK resource objects """ + if parent: + parent_resources = utils.extract_project_and_location_from_parent(parent) + if parent_resources: + project, location = ( + parent_resources["project"], + parent_resources["location"], + ) resource = cls._empty_constructor( project=project, location=location, credentials=credentials diff --git a/google/cloud/aiplatform/utils/__init__.py b/google/cloud/aiplatform/utils/__init__.py index 8049c9376b..f0847aefae 100644 --- a/google/cloud/aiplatform/utils/__init__.py +++ b/google/cloud/aiplatform/utils/__init__.py @@ -325,6 +325,34 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona return (gcs_bucket, gcs_blob_prefix) +def extract_project_and_location_from_parent( + parent: str, +) -> Dict[str, str]: + """Given a complete parent resource name, return the project and location as a dict. + + Example Usage: + + parent_resources = extract_project_and_location_from_parent( + "projects/123/locations/us-central1/datasets/456" + ) + + parent_resources["project"] = "123" + parent_resources["location"] = "us-central1" + + Args: + parent (str): + Required. A complete parent resource name. + + Returns: + Dict[str, str] + A project, location dict from provided parent resource name. + """ + parent_resources = re.match( + r"^projects/(?P.+?)/locations/(?P.+?)(/|$)", parent + ) + return parent_resources.groupdict() if parent_resources else {} + + class ClientWithOverride: class WrappedClient: """Wrapper class for client that creates client at API invocation diff --git a/tests/unit/aiplatform/test_featurestores.py b/tests/unit/aiplatform/test_featurestores.py index 08d57be0c5..1dda189324 100644 --- a/tests/unit/aiplatform/test_featurestores.py +++ b/tests/unit/aiplatform/test_featurestores.py @@ -1023,7 +1023,23 @@ def test_list_entity_types(self, list_entity_types_mock): aiplatform.init(project=_TEST_PROJECT) my_featurestore = aiplatform.Featurestore( - featurestore_name=_TEST_FEATURESTORE_ID + featurestore_name=_TEST_FEATURESTORE_ID, + ) + my_entity_type_list = my_featurestore.list_entity_types() + + list_entity_types_mock.assert_called_once_with( + request={"parent": _TEST_FEATURESTORE_NAME} + ) + assert len(my_entity_type_list) == len(_TEST_ENTITY_TYPE_LIST) + for my_entity_type in my_entity_type_list: + assert type(my_entity_type) == aiplatform.EntityType + + @pytest.mark.usefixtures("get_featurestore_mock") + def test_list_entity_types_with_no_init(self, list_entity_types_mock): + my_featurestore = aiplatform.Featurestore( + featurestore_name=_TEST_FEATURESTORE_ID, + project=_TEST_PROJECT, + location=_TEST_LOCATION, ) my_entity_type_list = my_featurestore.list_entity_types() @@ -1762,7 +1778,7 @@ def test_update_entity_type(self, update_entity_type_mock): @pytest.mark.parametrize( "featurestore_name", [_TEST_FEATURESTORE_NAME, _TEST_FEATURESTORE_ID] ) - def test_list_entity_types(self, featurestore_name, list_entity_types_mock): + def test_list_entity_type(self, featurestore_name, list_entity_types_mock): aiplatform.init(project=_TEST_PROJECT) my_entity_type_list = aiplatform.EntityType.list( @@ -1790,6 +1806,23 @@ def test_list_features(self, list_features_mock): for my_feature in my_feature_list: assert type(my_feature) == aiplatform.Feature + @pytest.mark.usefixtures("get_entity_type_mock") + def test_list_features_with_no_init(self, list_features_mock): + my_entity_type = aiplatform.EntityType( + entity_type_name=_TEST_ENTITY_TYPE_ID, + featurestore_id=_TEST_FEATURESTORE_ID, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + my_feature_list = my_entity_type.list_features() + + list_features_mock.assert_called_once_with( + request={"parent": _TEST_ENTITY_TYPE_NAME} + ) + assert len(my_feature_list) == len(_TEST_FEATURE_LIST) + for my_feature in my_feature_list: + assert type(my_feature) == aiplatform.Feature + @pytest.mark.parametrize("sync", [True, False]) @pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock") def test_delete_features(self, delete_feature_mock, sync): diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index dd4587125e..e81866bfef 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -320,6 +320,30 @@ def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple) assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path) +@pytest.mark.parametrize( + "parent, expected", + [ + ( + "projects/123/locations/us-central1/datasets/456", + {"project": "123", "location": "us-central1"}, + ), + ( + "projects/123/locations/us-central1/", + {"project": "123", "location": "us-central1"}, + ), + ( + "projects/123/locations/us-central1", + {"project": "123", "location": "us-central1"}, + ), + ("projects/123/locations/", {}), + ("projects/123", {}), + ], +) +def test_extract_project_and_location_from_parent(parent: str, expected: tuple): + # Given a parent resource name, ensure correct project and location are extracted + assert expected == utils.extract_project_and_location_from_parent(parent) + + @pytest.mark.usefixtures("google_auth_mock") def test_wrapped_client(): test_client_info = gapic_v1.client_info.ClientInfo()