From 27e5f55c32a2d78ce6c9628863794d826f6c819d Mon Sep 17 00:00:00 2001 From: Kevin Naughton Date: Thu, 28 Jul 2022 18:19:31 +0000 Subject: [PATCH 1/5] Add google.ClassificationMetrics, google.RegressionMetrics, and google.ForecastingMetrics Artifact types to metadata schema with unit tests. --- .../metadata/schema/google/artifact_schema.py | 248 ++++++++++++++++++ tests/unit/aiplatform/test_metadata_schema.py | 138 ++++++++++ 2 files changed, 386 insertions(+) diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py index 99e0fb0ba6..74df93ec32 100644 --- a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -268,3 +268,251 @@ def __init__( metadata=extended_metadata, state=state, ) + + +class ClassificationMetrics(base_artifact.BaseArtifactSchema): + """A Google artifact representing evaluation Classification Metrics.""" + + schema_title = "google.ClassificationMetrics" + + def __init__( + self, + *, + au_prc: Optional[float] = None, + au_roc: Optional[float] = None, + log_loss: Optional[float] = None, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + au_prc (float): + Optional. The Area Under Precision-Recall Curve metric. + Micro-averaged for the overall evaluation. + au_roc (float): + Optional. The Area Under Receiver Operating Characteristic curve metric. + Micro-averaged for the overall evaluation. + log_loss (float): + Optional. The Log Loss metric. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if au_prc is not None: + extended_metadata["auPrc"] = au_prc + if au_roc is not None: + extended_metadata["auRoc"] = au_roc + if log_loss is not None: + extended_metadata["logLoss"] = log_loss + + super(UnmanagedContainerModel, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class RegressionMetrics(base_artifact.BaseArtifactSchema): + """A Google artifact representing evaluation Regression Metrics.""" + + schema_title = "google.RegressionMetrics" + + def __init__( + self, + *, + root_mean_squared_error: Optional[float] = None, + mean_absolute_error: Optional[float] = None, + mean_absolute_percentage_error: Optional[float] = None, + r_squared: Optional[float] = None, + root_mean_squared_log_error: Optional[float] = None, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + root_mean_squared_error (float): + Optional. Root Mean Squared Error (RMSE). + mean_absolute_error (float): + Optional. Mean Absolute Error (MAE). + mean_absolute_percentage_error (float): + Optional. Mean absolute percentage error. + r_squared (float): + Optional. Coefficient of determination as Pearson correlation coefficient. + root_mean_squared_log_error (float): + Optional. Root mean squared log error. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if root_mean_squared_error is not None: + extended_metadata["rootMeanSquaredError"] = root_mean_squared_error + if mean_absolute_error is not None: + extended_metadata["meanAbsoluteError"] = mean_absolute_error + if mean_absolute_percentage_error is not None: + extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error + if r_squared is not None: + extended_metadata["rSquared"] = r_squared + if root_mean_squared_log_error is not None: + extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error + + super(RegressionMetrics, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class ForecastingMetrics(base_artifact.BaseArtifactSchema): + """A Google artifact representing evaluation Forecasting Metrics.""" + + schema_title = "google.ForecastingMetrics" + + def __init__( + self, + *, + root_mean_squared_error: Optional[float] = None, + mean_absolute_error: Optional[float] = None, + mean_absolute_percentage_error: Optional[float] = None, + r_squared: Optional[float] = None, + root_mean_squared_log_error: Optional[float] = None, + weighted_absolute_percentage_error: Optional[float] = None, + root_mean_squared_percentage_error: Optional[float] = None, + symmetric_mean_absolute_percentage_error: Optional[float] = None, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + root_mean_squared_error (float): + Optional. Root Mean Squared Error (RMSE). + mean_absolute_error (float): + Optional. Mean Absolute Error (MAE). + mean_absolute_percentage_error (float): + Optional. Mean absolute percentage error. + r_squared (float): + Optional. Coefficient of determination as Pearson correlation coefficient. + root_mean_squared_log_error (float): + Optional. Root mean squared log error. + weighted_absolute_percentage_error (float): + Optional. Weighted Absolute Percentage Error. + Does not use weights, this is just what the metric is called. + Undefined if actual values sum to zero. + Will be very large if actual values sum to a very small number. + root_mean_squared_percentage_error (float): + Optional. Root Mean Square Percentage Error. Square root of MSPE. + Undefined/imaginary when MSPE is negative. + symmetric_mean_absolute_percentage_error (float): + Optional. Symmetric Mean Absolute Percentage Error. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if root_mean_squared_error is not None: + extended_metadata["rootMeanSquaredError"] = root_mean_squared_error + if mean_absolute_error is not None: + extended_metadata["meanAbsoluteError"] = mean_absolute_error + if mean_absolute_percentage_error is not None: + extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error + if r_squared is not None: + extended_metadata["rSquared"] = r_squared + if root_mean_squared_log_error is not None: + extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error + if weighted_absolute_percentage_error is not None: + extended_metadata["weightedAbsolutePercentageError"] = weighted_absolute_percentage_error + if root_mean_squared_percentage_error is not None: + extended_metadata["rootMeanSquaredPercentageError"] = root_mean_squared_percentage_error + if symmetric_mean_absolute_percentage_error is not None: + extended_metadata["symmetricMeanAbsolutePercentageError"] = symmetric_mean_absolute_percentage_error + + super(ForecastingMetrics, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index f550f61ab4..ffb812686a 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -461,6 +461,144 @@ def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self assert artifact.schema_version == _TEST_SCHEMA_VERSION + def test_classification_metrics_title_is_set_correctly(self): + artifact = google_artifact_schema.ClassificationMetrics() + assert artifact.schema_title == "google.ClassificationMetrics" + + def test_classification_metrics_constructor_parameters_are_set_correctly(self): + au_prc = 1.0 + au_roc = 2.0 + log_loss = 0.5 + + artifact = google_artifact_schema.ClassificationMetrics( + au_prc=au_prc, + au_roc=au_roc, + log_loss=log_loss, + artifact_id=_TEST_ARTIFACT_ID, + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + expected_metadata = { + "test-param1": 2.0, + "test-param2": "test-value-1", + "test-param3": False, + "auPrc": 1.0, + "auRoc": 2.0, + "logLoss": 0.5, + } + + assert artifact.artifact_id == _TEST_ARTIFACT_ID + assert artifact.uri == _TEST_URI + assert artifact.display_name == _TEST_DISPLAY_NAME + assert artifact.description == _TEST_DESCRIPTION + assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps( + expected_metadata, sort_keys=True + ) + assert artifact.schema_version == _TEST_SCHEMA_VERSION + + + def test_regression_metrics_title_is_set_correctly(self): + artifact = google_artifact_schema.RegressionMetrics() + assert artifact.schema_title == "google.RegressionMetrics" + + def test_regression_metrics_constructor_parameters_are_set_correctly(self): + root_mean_squared_error = 1.0 + mean_absolute_error = 2.0 + mean_absolute_percentage_error = 0.2 + r_squared = 0.5 + root_mean_squared_log_error = 0.9 + + artifact = google_artifact_schema.RegressionMetrics( + root_mean_squared_error=root_mean_squared_error, + mean_absolute_error=mean_absolute_error, + mean_absolute_percentage_error=mean_absolute_percentage_error, + r_squared=r_squared, + root_mean_squared_log_error=root_mean_squared_log_error, + artifact_id=_TEST_ARTIFACT_ID, + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + expected_metadata = { + "test-param1": 2.0, + "test-param2": "test-value-1", + "test-param3": False, + "rootMeanSquaredError": 1.0, + "meanAbsoluteError": 2.0, + "meanAbsolutePercentageError": 0.2, + "rSquared": 0.5, + "rootMeanSquaredLogError": 0.9, + } + + assert artifact.artifact_id == _TEST_ARTIFACT_ID + assert artifact.uri == _TEST_URI + assert artifact.display_name == _TEST_DISPLAY_NAME + assert artifact.description == _TEST_DESCRIPTION + assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps( + expected_metadata, sort_keys=True + ) + assert artifact.schema_version == _TEST_SCHEMA_VERSION + + + def test_forecasting_metrics_title_is_set_correctly(self): + artifact = google_artifact_schema.ForecastingMetrics() + assert artifact.schema_title == "google.ForecastingMetrics" + + def test_forecasting_metrics_constructor_parameters_are_set_correctly(self): + root_mean_squared_error = 1.0 + mean_absolute_error = 2.0 + mean_absolute_percentage_error = 0.2 + r_squared = 0.5 + root_mean_squared_log_error = 0.9 + weighted_absolute_percentage_error = 4.0 + root_mean_squared_percentage_error = 0.7 + symmetric_mean_absolute_percentage_error = 0.8 + + artifact = google_artifact_schema.UnmanagedContainerModel( + root_mean_squared_error=root_mean_squared_error, + mean_absolute_error=mean_absolute_error, + mean_absolute_percentage_error=mean_absolute_percentage_error, + r_squared=r_squared, + root_mean_squared_log_error=root_mean_squared_log_error, + weighted_absolute_percentage_error=weighted_absolute_percentage_error, + root_mean_squared_percentage_error=root_mean_squared_percentage_error, + symmetric_mean_absolute_percentage_error=symmetric_mean_absolute_percentage_error, + artifact_id=_TEST_ARTIFACT_ID, + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + expected_metadata = { + "test-param1": 2.0, + "test-param2": "test-value-1", + "test-param3": False, + "rootMeanSquaredError": 1.0, + "meanAbsoluteError": 2.0, + "meanAbsolutePercentageError": 0.2, + "rSquared": 0.5, + "rootMeanSquaredLogError": 0.9, + "weightedAbsolutePercentageError": 4.0, + "rootMeanSquaredPercentageError": 0.7, + "symmetricMeanAbsolutePercentageError": 0.8, + } + + assert artifact.artifact_id == _TEST_ARTIFACT_ID + assert artifact.uri == _TEST_URI + assert artifact.display_name == _TEST_DISPLAY_NAME + assert artifact.description == _TEST_DESCRIPTION + assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps( + expected_metadata, sort_keys=True + ) + assert artifact.schema_version == _TEST_SCHEMA_VERSION + + @pytest.mark.usefixtures("google_auth_mock") class TestMetadataSystemArtifactSchema: def setup_method(self): From eb951131a862d2f64019f99b28b38066c8fd969d Mon Sep 17 00:00:00 2001 From: Kevin Naughton Date: Thu, 28 Jul 2022 19:50:20 +0000 Subject: [PATCH 2/5] fix implicit false --- .../metadata/schema/google/artifact_schema.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py index 74df93ec32..cd6b67d1e4 100644 --- a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -323,11 +323,11 @@ def __init__( check the validity of state transitions. """ extended_metadata = copy.deepcopy(metadata) if metadata else {} - if au_prc is not None: + if au_prc: extended_metadata["auPrc"] = au_prc - if au_roc is not None: + if au_roc: extended_metadata["auRoc"] = au_roc - if log_loss is not None: + if log_loss: extended_metadata["logLoss"] = log_loss super(UnmanagedContainerModel, self).__init__( @@ -398,15 +398,15 @@ def __init__( check the validity of state transitions. """ extended_metadata = copy.deepcopy(metadata) if metadata else {} - if root_mean_squared_error is not None: + if root_mean_squared_error: extended_metadata["rootMeanSquaredError"] = root_mean_squared_error - if mean_absolute_error is not None: + if mean_absolute_error: extended_metadata["meanAbsoluteError"] = mean_absolute_error - if mean_absolute_percentage_error is not None: + if mean_absolute_percentage_error: extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error - if r_squared is not None: + if r_squared: extended_metadata["rSquared"] = r_squared - if root_mean_squared_log_error is not None: + if root_mean_squared_log_error: extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error super(RegressionMetrics, self).__init__( @@ -490,21 +490,21 @@ def __init__( check the validity of state transitions. """ extended_metadata = copy.deepcopy(metadata) if metadata else {} - if root_mean_squared_error is not None: + if root_mean_squared_error: extended_metadata["rootMeanSquaredError"] = root_mean_squared_error - if mean_absolute_error is not None: + if mean_absolute_error: extended_metadata["meanAbsoluteError"] = mean_absolute_error - if mean_absolute_percentage_error is not None: + if mean_absolute_percentage_error: extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error - if r_squared is not None: + if r_squared: extended_metadata["rSquared"] = r_squared - if root_mean_squared_log_error is not None: + if root_mean_squared_log_error: extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error - if weighted_absolute_percentage_error is not None: + if weighted_absolute_percentage_error: extended_metadata["weightedAbsolutePercentageError"] = weighted_absolute_percentage_error - if root_mean_squared_percentage_error is not None: + if root_mean_squared_percentage_error: extended_metadata["rootMeanSquaredPercentageError"] = root_mean_squared_percentage_error - if symmetric_mean_absolute_percentage_error is not None: + if symmetric_mean_absolute_percentage_error: extended_metadata["symmetricMeanAbsolutePercentageError"] = symmetric_mean_absolute_percentage_error super(ForecastingMetrics, self).__init__( From 358e87605c00358f0db86dfbcca3c086a1f88fd9 Mon Sep 17 00:00:00 2001 From: Kevin Naughton Date: Thu, 28 Jul 2022 21:50:20 +0000 Subject: [PATCH 3/5] Fix typo --- .../cloud/aiplatform/metadata/schema/google/artifact_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py index cd6b67d1e4..7366d461e5 100644 --- a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -330,7 +330,7 @@ def __init__( if log_loss: extended_metadata["logLoss"] = log_loss - super(UnmanagedContainerModel, self).__init__( + super(ClassificationMetrics, self).__init__( uri=uri, artifact_id=artifact_id, display_name=display_name, From a5ce95e5c51fe502fb8f374028edd6d9f94c16cd Mon Sep 17 00:00:00 2001 From: Kevin Naughton Date: Mon, 1 Aug 2022 19:20:10 +0000 Subject: [PATCH 4/5] Running nox -s blacken and nox -s lint --- .../metadata/schema/google/artifact_schema.py | 22 ++++++++++++++----- tests/unit/aiplatform/test_metadata_schema.py | 3 --- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py index 7366d461e5..e52f2f98b5 100644 --- a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -296,7 +296,7 @@ def __init__( au_roc (float): Optional. The Area Under Receiver Operating Characteristic curve metric. Micro-averaged for the overall evaluation. - log_loss (float): + log_loss (float): Optional. The Log Loss metric. artifact_id (str): Optional. The portion of the Artifact name with @@ -403,7 +403,9 @@ def __init__( if mean_absolute_error: extended_metadata["meanAbsoluteError"] = mean_absolute_error if mean_absolute_percentage_error: - extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error + extended_metadata[ + "meanAbsolutePercentageError" + ] = mean_absolute_percentage_error if r_squared: extended_metadata["rSquared"] = r_squared if root_mean_squared_log_error: @@ -495,17 +497,25 @@ def __init__( if mean_absolute_error: extended_metadata["meanAbsoluteError"] = mean_absolute_error if mean_absolute_percentage_error: - extended_metadata["meanAbsolutePercentageError"] = mean_absolute_percentage_error + extended_metadata[ + "meanAbsolutePercentageError" + ] = mean_absolute_percentage_error if r_squared: extended_metadata["rSquared"] = r_squared if root_mean_squared_log_error: extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error if weighted_absolute_percentage_error: - extended_metadata["weightedAbsolutePercentageError"] = weighted_absolute_percentage_error + extended_metadata[ + "weightedAbsolutePercentageError" + ] = weighted_absolute_percentage_error if root_mean_squared_percentage_error: - extended_metadata["rootMeanSquaredPercentageError"] = root_mean_squared_percentage_error + extended_metadata[ + "rootMeanSquaredPercentageError" + ] = root_mean_squared_percentage_error if symmetric_mean_absolute_percentage_error: - extended_metadata["symmetricMeanAbsolutePercentageError"] = symmetric_mean_absolute_percentage_error + extended_metadata[ + "symmetricMeanAbsolutePercentageError" + ] = symmetric_mean_absolute_percentage_error super(ForecastingMetrics, self).__init__( uri=uri, diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index ffb812686a..d3476ffc43 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -460,7 +460,6 @@ def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self ) assert artifact.schema_version == _TEST_SCHEMA_VERSION - def test_classification_metrics_title_is_set_correctly(self): artifact = google_artifact_schema.ClassificationMetrics() assert artifact.schema_title == "google.ClassificationMetrics" @@ -499,7 +498,6 @@ def test_classification_metrics_constructor_parameters_are_set_correctly(self): ) assert artifact.schema_version == _TEST_SCHEMA_VERSION - def test_regression_metrics_title_is_set_correctly(self): artifact = google_artifact_schema.RegressionMetrics() assert artifact.schema_title == "google.RegressionMetrics" @@ -544,7 +542,6 @@ def test_regression_metrics_constructor_parameters_are_set_correctly(self): ) assert artifact.schema_version == _TEST_SCHEMA_VERSION - def test_forecasting_metrics_title_is_set_correctly(self): artifact = google_artifact_schema.ForecastingMetrics() assert artifact.schema_title == "google.ForecastingMetrics" From b94e637cb9b2ac186bca757add8ca0d5a5aeca66 Mon Sep 17 00:00:00 2001 From: Kevin Naughton Date: Mon, 1 Aug 2022 21:43:34 +0000 Subject: [PATCH 5/5] fix typo in unit test --- tests/unit/aiplatform/test_metadata_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index d3476ffc43..e85229c555 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -556,7 +556,7 @@ def test_forecasting_metrics_constructor_parameters_are_set_correctly(self): root_mean_squared_percentage_error = 0.7 symmetric_mean_absolute_percentage_error = 0.8 - artifact = google_artifact_schema.UnmanagedContainerModel( + artifact = google_artifact_schema.ForecastingMetrics( root_mean_squared_error=root_mean_squared_error, mean_absolute_error=mean_absolute_error, mean_absolute_percentage_error=mean_absolute_percentage_error,