diff --git a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py index 99e0fb0ba6..e52f2f98b5 100644 --- a/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py +++ b/google/cloud/aiplatform/metadata/schema/google/artifact_schema.py @@ -268,3 +268,261 @@ def __init__( metadata=extended_metadata, state=state, ) + + +class ClassificationMetrics(base_artifact.BaseArtifactSchema): + """A Google artifact representing evaluation Classification Metrics.""" + + schema_title = "google.ClassificationMetrics" + + def __init__( + self, + *, + au_prc: Optional[float] = None, + au_roc: Optional[float] = None, + log_loss: Optional[float] = None, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + au_prc (float): + Optional. The Area Under Precision-Recall Curve metric. + Micro-averaged for the overall evaluation. + au_roc (float): + Optional. The Area Under Receiver Operating Characteristic curve metric. + Micro-averaged for the overall evaluation. + log_loss (float): + Optional. The Log Loss metric. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if au_prc: + extended_metadata["auPrc"] = au_prc + if au_roc: + extended_metadata["auRoc"] = au_roc + if log_loss: + extended_metadata["logLoss"] = log_loss + + super(ClassificationMetrics, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class RegressionMetrics(base_artifact.BaseArtifactSchema): + """A Google artifact representing evaluation Regression Metrics.""" + + schema_title = "google.RegressionMetrics" + + def __init__( + self, + *, + root_mean_squared_error: Optional[float] = None, + mean_absolute_error: Optional[float] = None, + mean_absolute_percentage_error: Optional[float] = None, + r_squared: Optional[float] = None, + root_mean_squared_log_error: Optional[float] = None, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + root_mean_squared_error (float): + Optional. Root Mean Squared Error (RMSE). + mean_absolute_error (float): + Optional. Mean Absolute Error (MAE). + mean_absolute_percentage_error (float): + Optional. Mean absolute percentage error. + r_squared (float): + Optional. Coefficient of determination as Pearson correlation coefficient. + root_mean_squared_log_error (float): + Optional. Root mean squared log error. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if root_mean_squared_error: + extended_metadata["rootMeanSquaredError"] = root_mean_squared_error + if mean_absolute_error: + extended_metadata["meanAbsoluteError"] = mean_absolute_error + if mean_absolute_percentage_error: + extended_metadata[ + "meanAbsolutePercentageError" + ] = mean_absolute_percentage_error + if r_squared: + extended_metadata["rSquared"] = r_squared + if root_mean_squared_log_error: + extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error + + super(RegressionMetrics, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) + + +class ForecastingMetrics(base_artifact.BaseArtifactSchema): + """A Google artifact representing evaluation Forecasting Metrics.""" + + schema_title = "google.ForecastingMetrics" + + def __init__( + self, + *, + root_mean_squared_error: Optional[float] = None, + mean_absolute_error: Optional[float] = None, + mean_absolute_percentage_error: Optional[float] = None, + r_squared: Optional[float] = None, + root_mean_squared_log_error: Optional[float] = None, + weighted_absolute_percentage_error: Optional[float] = None, + root_mean_squared_percentage_error: Optional[float] = None, + symmetric_mean_absolute_percentage_error: Optional[float] = None, + artifact_id: Optional[str] = None, + uri: Optional[str] = None, + display_name: Optional[str] = None, + schema_version: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict] = None, + state: Optional[gca_artifact.Artifact.State] = gca_artifact.Artifact.State.LIVE, + ): + """Args: + root_mean_squared_error (float): + Optional. Root Mean Squared Error (RMSE). + mean_absolute_error (float): + Optional. Mean Absolute Error (MAE). + mean_absolute_percentage_error (float): + Optional. Mean absolute percentage error. + r_squared (float): + Optional. Coefficient of determination as Pearson correlation coefficient. + root_mean_squared_log_error (float): + Optional. Root mean squared log error. + weighted_absolute_percentage_error (float): + Optional. Weighted Absolute Percentage Error. + Does not use weights, this is just what the metric is called. + Undefined if actual values sum to zero. + Will be very large if actual values sum to a very small number. + root_mean_squared_percentage_error (float): + Optional. Root Mean Square Percentage Error. Square root of MSPE. + Undefined/imaginary when MSPE is negative. + symmetric_mean_absolute_percentage_error (float): + Optional. Symmetric Mean Absolute Percentage Error. + artifact_id (str): + Optional. The portion of the Artifact name with + the format. This is globally unique in a metadataStore: + projects/123/locations/us-central1/metadataStores//artifacts/. + uri (str): + Optional. The uniform resource identifier of the artifact file. May be empty if there is no actual + artifact file. + display_name (str): + Optional. The user-defined name of the Artifact. + schema_version (str): + Optional. schema_version specifies the version used by the Artifact. + If not set, defaults to use the latest version. + description (str): + Optional. Describes the purpose of the Artifact to be created. + metadata (Dict): + Optional. Contains the metadata information that will be stored in the Artifact. + state (google.cloud.gapic.types.Artifact.State): + Optional. The state of this Artifact. This is a + property of the Artifact, and does not imply or + capture any ongoing process. This property is + managed by clients (such as Vertex AI + Pipelines), and the system does not prescribe or + check the validity of state transitions. + """ + extended_metadata = copy.deepcopy(metadata) if metadata else {} + if root_mean_squared_error: + extended_metadata["rootMeanSquaredError"] = root_mean_squared_error + if mean_absolute_error: + extended_metadata["meanAbsoluteError"] = mean_absolute_error + if mean_absolute_percentage_error: + extended_metadata[ + "meanAbsolutePercentageError" + ] = mean_absolute_percentage_error + if r_squared: + extended_metadata["rSquared"] = r_squared + if root_mean_squared_log_error: + extended_metadata["rootMeanSquaredLogError"] = root_mean_squared_log_error + if weighted_absolute_percentage_error: + extended_metadata[ + "weightedAbsolutePercentageError" + ] = weighted_absolute_percentage_error + if root_mean_squared_percentage_error: + extended_metadata[ + "rootMeanSquaredPercentageError" + ] = root_mean_squared_percentage_error + if symmetric_mean_absolute_percentage_error: + extended_metadata[ + "symmetricMeanAbsolutePercentageError" + ] = symmetric_mean_absolute_percentage_error + + super(ForecastingMetrics, self).__init__( + uri=uri, + artifact_id=artifact_id, + display_name=display_name, + schema_version=schema_version, + description=description, + metadata=extended_metadata, + state=state, + ) diff --git a/tests/unit/aiplatform/test_metadata_schema.py b/tests/unit/aiplatform/test_metadata_schema.py index f550f61ab4..e85229c555 100644 --- a/tests/unit/aiplatform/test_metadata_schema.py +++ b/tests/unit/aiplatform/test_metadata_schema.py @@ -460,6 +460,141 @@ def test_unmanaged_container_model_constructor_parameters_are_set_correctly(self ) assert artifact.schema_version == _TEST_SCHEMA_VERSION + def test_classification_metrics_title_is_set_correctly(self): + artifact = google_artifact_schema.ClassificationMetrics() + assert artifact.schema_title == "google.ClassificationMetrics" + + def test_classification_metrics_constructor_parameters_are_set_correctly(self): + au_prc = 1.0 + au_roc = 2.0 + log_loss = 0.5 + + artifact = google_artifact_schema.ClassificationMetrics( + au_prc=au_prc, + au_roc=au_roc, + log_loss=log_loss, + artifact_id=_TEST_ARTIFACT_ID, + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + expected_metadata = { + "test-param1": 2.0, + "test-param2": "test-value-1", + "test-param3": False, + "auPrc": 1.0, + "auRoc": 2.0, + "logLoss": 0.5, + } + + assert artifact.artifact_id == _TEST_ARTIFACT_ID + assert artifact.uri == _TEST_URI + assert artifact.display_name == _TEST_DISPLAY_NAME + assert artifact.description == _TEST_DESCRIPTION + assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps( + expected_metadata, sort_keys=True + ) + assert artifact.schema_version == _TEST_SCHEMA_VERSION + + def test_regression_metrics_title_is_set_correctly(self): + artifact = google_artifact_schema.RegressionMetrics() + assert artifact.schema_title == "google.RegressionMetrics" + + def test_regression_metrics_constructor_parameters_are_set_correctly(self): + root_mean_squared_error = 1.0 + mean_absolute_error = 2.0 + mean_absolute_percentage_error = 0.2 + r_squared = 0.5 + root_mean_squared_log_error = 0.9 + + artifact = google_artifact_schema.RegressionMetrics( + root_mean_squared_error=root_mean_squared_error, + mean_absolute_error=mean_absolute_error, + mean_absolute_percentage_error=mean_absolute_percentage_error, + r_squared=r_squared, + root_mean_squared_log_error=root_mean_squared_log_error, + artifact_id=_TEST_ARTIFACT_ID, + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + expected_metadata = { + "test-param1": 2.0, + "test-param2": "test-value-1", + "test-param3": False, + "rootMeanSquaredError": 1.0, + "meanAbsoluteError": 2.0, + "meanAbsolutePercentageError": 0.2, + "rSquared": 0.5, + "rootMeanSquaredLogError": 0.9, + } + + assert artifact.artifact_id == _TEST_ARTIFACT_ID + assert artifact.uri == _TEST_URI + assert artifact.display_name == _TEST_DISPLAY_NAME + assert artifact.description == _TEST_DESCRIPTION + assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps( + expected_metadata, sort_keys=True + ) + assert artifact.schema_version == _TEST_SCHEMA_VERSION + + def test_forecasting_metrics_title_is_set_correctly(self): + artifact = google_artifact_schema.ForecastingMetrics() + assert artifact.schema_title == "google.ForecastingMetrics" + + def test_forecasting_metrics_constructor_parameters_are_set_correctly(self): + root_mean_squared_error = 1.0 + mean_absolute_error = 2.0 + mean_absolute_percentage_error = 0.2 + r_squared = 0.5 + root_mean_squared_log_error = 0.9 + weighted_absolute_percentage_error = 4.0 + root_mean_squared_percentage_error = 0.7 + symmetric_mean_absolute_percentage_error = 0.8 + + artifact = google_artifact_schema.ForecastingMetrics( + root_mean_squared_error=root_mean_squared_error, + mean_absolute_error=mean_absolute_error, + mean_absolute_percentage_error=mean_absolute_percentage_error, + r_squared=r_squared, + root_mean_squared_log_error=root_mean_squared_log_error, + weighted_absolute_percentage_error=weighted_absolute_percentage_error, + root_mean_squared_percentage_error=root_mean_squared_percentage_error, + symmetric_mean_absolute_percentage_error=symmetric_mean_absolute_percentage_error, + artifact_id=_TEST_ARTIFACT_ID, + uri=_TEST_URI, + display_name=_TEST_DISPLAY_NAME, + schema_version=_TEST_SCHEMA_VERSION, + description=_TEST_DESCRIPTION, + metadata=_TEST_UPDATED_METADATA, + ) + expected_metadata = { + "test-param1": 2.0, + "test-param2": "test-value-1", + "test-param3": False, + "rootMeanSquaredError": 1.0, + "meanAbsoluteError": 2.0, + "meanAbsolutePercentageError": 0.2, + "rSquared": 0.5, + "rootMeanSquaredLogError": 0.9, + "weightedAbsolutePercentageError": 4.0, + "rootMeanSquaredPercentageError": 0.7, + "symmetricMeanAbsolutePercentageError": 0.8, + } + + assert artifact.artifact_id == _TEST_ARTIFACT_ID + assert artifact.uri == _TEST_URI + assert artifact.display_name == _TEST_DISPLAY_NAME + assert artifact.description == _TEST_DESCRIPTION + assert json.dumps(artifact.metadata, sort_keys=True) == json.dumps( + expected_metadata, sort_keys=True + ) + assert artifact.schema_version == _TEST_SCHEMA_VERSION + @pytest.mark.usefixtures("google_auth_mock") class TestMetadataSystemArtifactSchema: