Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions google/cloud/aiplatform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1077,6 +1077,13 @@ def _list(
Returns:
List[VertexAiResourceNoun] - A list of SDK resource objects
"""
if parent:
parent_resources = utils.extract_project_and_location_from_parent(parent)
if parent_resources:
project, location = (
parent_resources["project"],
parent_resources["location"],
)

resource = cls._empty_constructor(
project=project, location=location, credentials=credentials
Expand Down
28 changes: 28 additions & 0 deletions google/cloud/aiplatform/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,34 @@ def extract_bucket_and_prefix_from_gcs_path(gcs_path: str) -> Tuple[str, Optiona
return (gcs_bucket, gcs_blob_prefix)


def extract_project_and_location_from_parent(
parent: str,
) -> Dict[str, str]:
"""Given a complete parent resource name, return the project and location as a dict.

Example Usage:

parent_resources = extract_project_and_location_from_parent(
"projects/123/locations/us-central1/datasets/456"
)

parent_resources["project"] = "123"
parent_resources["location"] = "us-central1"

Args:
parent (str):
Required. A complete parent resource name.

Returns:
Dict[str, str]
A project, location dict from provided parent resource name.
"""
parent_resources = re.match(
Comment thread
nayaknishant marked this conversation as resolved.
r"^projects/(?P<project>.+?)/locations/(?P<location>.+?)(/|$)", parent
)
return parent_resources.groupdict() if parent_resources else {}

Comment thread
nayaknishant marked this conversation as resolved.

class ClientWithOverride:
class WrappedClient:
"""Wrapper class for client that creates client at API invocation
Expand Down
37 changes: 35 additions & 2 deletions tests/unit/aiplatform/test_featurestores.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,7 +1023,23 @@ def test_list_entity_types(self, list_entity_types_mock):
aiplatform.init(project=_TEST_PROJECT)
Comment thread
nayaknishant marked this conversation as resolved.

my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_ID
featurestore_name=_TEST_FEATURESTORE_ID,
)
my_entity_type_list = my_featurestore.list_entity_types()

list_entity_types_mock.assert_called_once_with(
request={"parent": _TEST_FEATURESTORE_NAME}
)
assert len(my_entity_type_list) == len(_TEST_ENTITY_TYPE_LIST)
for my_entity_type in my_entity_type_list:
assert type(my_entity_type) == aiplatform.EntityType

@pytest.mark.usefixtures("get_featurestore_mock")
def test_list_entity_types_with_no_init(self, list_entity_types_mock):
my_featurestore = aiplatform.Featurestore(
featurestore_name=_TEST_FEATURESTORE_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
my_entity_type_list = my_featurestore.list_entity_types()

Expand Down Expand Up @@ -1762,7 +1778,7 @@ def test_update_entity_type(self, update_entity_type_mock):
@pytest.mark.parametrize(
"featurestore_name", [_TEST_FEATURESTORE_NAME, _TEST_FEATURESTORE_ID]
)
def test_list_entity_types(self, featurestore_name, list_entity_types_mock):
def test_list_entity_type(self, featurestore_name, list_entity_types_mock):
aiplatform.init(project=_TEST_PROJECT)

my_entity_type_list = aiplatform.EntityType.list(
Expand Down Expand Up @@ -1790,6 +1806,23 @@ def test_list_features(self, list_features_mock):
for my_feature in my_feature_list:
assert type(my_feature) == aiplatform.Feature

@pytest.mark.usefixtures("get_entity_type_mock")
def test_list_features_with_no_init(self, list_features_mock):
my_entity_type = aiplatform.EntityType(
entity_type_name=_TEST_ENTITY_TYPE_ID,
featurestore_id=_TEST_FEATURESTORE_ID,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
)
my_feature_list = my_entity_type.list_features()

list_features_mock.assert_called_once_with(
request={"parent": _TEST_ENTITY_TYPE_NAME}
)
assert len(my_feature_list) == len(_TEST_FEATURE_LIST)
for my_feature in my_feature_list:
assert type(my_feature) == aiplatform.Feature

@pytest.mark.parametrize("sync", [True, False])
@pytest.mark.usefixtures("get_entity_type_mock", "get_feature_mock")
def test_delete_features(self, delete_feature_mock, sync):
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,30 @@ def test_extract_bucket_and_prefix_from_gcs_path(gcs_path: str, expected: tuple)
assert expected == utils.extract_bucket_and_prefix_from_gcs_path(gcs_path)


@pytest.mark.parametrize(
"parent, expected",
[
(
"projects/123/locations/us-central1/datasets/456",
{"project": "123", "location": "us-central1"},
),
(
"projects/123/locations/us-central1/",
{"project": "123", "location": "us-central1"},
),
(
"projects/123/locations/us-central1",
{"project": "123", "location": "us-central1"},
),
("projects/123/locations/", {}),
("projects/123", {}),
],
)
def test_extract_project_and_location_from_parent(parent: str, expected: tuple):
# Given a parent resource name, ensure correct project and location are extracted
assert expected == utils.extract_project_and_location_from_parent(parent)


@pytest.mark.usefixtures("google_auth_mock")
def test_wrapped_client():
test_client_info = gapic_v1.client_info.ClientInfo()
Expand Down