diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index af3815e47f..5ab03b75b5 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2427,7 +2427,8 @@ def update( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. bigquery_tables_log_ttl (int): - Optional. The TTL(time to live) of BigQuery tables in user projects + Optional. The number of days for which the logs are stored. + The TTL(time to live) of BigQuery tables in user projects which stores logs. A day is the basic unit of the TTL and we take the ceil of TTL/86400(a day). e.g. { second: 3600} indicates ttl = 1 @@ -2453,28 +2454,30 @@ def update( will be applied to all deployed models. """ self._sync_gca_resource() - current_job = self.api_client.get_model_deployment_monitoring_job( - name=self._gca_resource.name - ) + current_job = copy.deepcopy(self._gca_resource) update_mask: List[str] = [] if display_name is not None: update_mask.append("display_name") current_job.display_name = display_name if schedule_config is not None: update_mask.append("model_deployment_monitoring_schedule_config") - current_job.model_deployment_monitoring_schedule_config = schedule_config + current_job.model_deployment_monitoring_schedule_config = ( + schedule_config.as_proto() + ) if alert_config is not None: update_mask.append("model_monitoring_alert_config") - current_job.model_monitoring_alert_config = alert_config + current_job.model_monitoring_alert_config = alert_config.as_proto() if logging_sampling_strategy is not None: update_mask.append("logging_sampling_strategy") - current_job.logging_sampling_strategy = logging_sampling_strategy + current_job.logging_sampling_strategy = logging_sampling_strategy.as_proto() if labels is not None: update_mask.append("labels") - current_job.lables = labels + current_job.labels = labels if bigquery_tables_log_ttl is not None: update_mask.append("log_ttl") - current_job.log_ttl = bigquery_tables_log_ttl + current_job.log_ttl = duration_pb2.Duration( + seconds=bigquery_tables_log_ttl * 86400 + ) if enable_monitoring_pipeline_logs is not None: update_mask.append("enable_monitoring_pipeline_logs") current_job.enable_monitoring_pipeline_logs = ( @@ -2491,10 +2494,12 @@ def update( deployed_model_ids=deployed_model_ids, ) ) - self.api_client.update_model_deployment_monitoring_job( + # TODO: b/254285776 add optional_sync support to model monitoring job + lro = self.api_client.update_model_deployment_monitoring_job( model_deployment_monitoring_job=current_job, update_mask=field_mask_pb2.FieldMask(paths=update_mask), ) + self._gca_resource = lro.result() return self def pause(self) -> "ModelDeploymentMonitoringJob": diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 9d788df262..27a9ddea30 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -16,6 +16,7 @@ # import pytest +import copy from unittest import mock from importlib import reload @@ -24,6 +25,7 @@ from google.cloud import storage from google.cloud import bigquery +from google.api_core import operation from google.auth import credentials as auth_credentials from google.cloud import aiplatform @@ -46,7 +48,9 @@ job_service_client, ) from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore +import test_endpoints # noqa: F401 from test_endpoints import get_endpoint_with_models_mock # noqa: F401 _TEST_API_CLIENT = job_service_client.JobServiceClient @@ -175,6 +179,58 @@ _TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}" _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01} +_TEST_MDM_USER_EMAIL = "TEST_EMAIL" +_TEST_MDM_SAMPLE_RATE = 0.5 +_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"} +_TEST_LOG_TTL_IN_DAYS = 1 +_TEST_MDM_NEW_NAME = "NEW_NAME" + +_TEST_MDM_OLD_JOB = ( + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( + name=_TEST_MDM_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + endpoint=_TEST_ENDPOINT, + state=_TEST_JOB_STATE_RUNNING, + ) +) + +_TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( + name=_TEST_MDM_JOB_NAME, + display_name=_TEST_MDM_NEW_NAME, + endpoint=_TEST_ENDPOINT, + state=_TEST_JOB_STATE_RUNNING, + model_deployment_monitoring_objective_configs=[ + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( + deployed_model_id=model_id, + objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( + prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( + drift_thresholds={ + "TEST_KEY": gca_model_monitoring_compat.ThresholdConfig( + value=0.01 + ) + } + ) + ), + ) + for model_id in [model.id for model in test_endpoints._TEST_DEPLOYED_MODELS] + ], + logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy( + random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig( + sample_rate=_TEST_MDM_SAMPLE_RATE + ) + ), + labels=_TEST_MDM_LABEL, + model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig( + email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig( + user_emails=[_TEST_MDM_USER_EMAIL] + ) + ), + model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig( + monitor_interval=duration_pb2.Duration(seconds=3600) + ), + log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), + enable_monitoring_pipeline_logs=True, +) # TODO(b/171333554): Move reusable test fixtures to conftest.py file @@ -988,48 +1044,23 @@ def get_mdm_job_mock(): with mock.patch.object( _TEST_API_CLIENT, "get_model_deployment_monitoring_job" ) as get_mdm_job_mock: - get_mdm_job_mock.return_value = ( - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - name=_TEST_MDM_JOB_NAME, - display_name=_TEST_DISPLAY_NAME, - state=_TEST_JOB_STATE_RUNNING, - endpoint=_TEST_ENDPOINT, - ) - ) + get_mdm_job_mock.side_effect = [ + _TEST_MDM_OLD_JOB, + _TEST_MDM_OLD_JOB, + _TEST_MDM_OLD_JOB, + _TEST_MDM_EXPECTED_NEW_JOB, + ] yield get_mdm_job_mock @pytest.fixture -@pytest.mark.usefixtures("get_mdm_job_mock") def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 with mock.patch.object( _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: - expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( - prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( - drift_thresholds={ - "TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01) - } - ) - ) - all_configs = [] - for model in get_endpoint_with_models_mock.return_value.deployed_models: - all_configs.append( - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( - deployed_model_id=model.id, - objective_config=expected_objective_config, - ) - ) - - update_mdm_job_mock.return_vaue.result_type = ( - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - name=_TEST_MDM_JOB_NAME, - display_name=_TEST_DISPLAY_NAME, - state=_TEST_JOB_STATE_RUNNING, - endpoint=_TEST_ENDPOINT, - model_deployment_monitoring_objective_configs=all_configs, - ) - ) + update_mdm_job_lro_mock = mock.Mock(operation.Operation) + update_mdm_job_lro_mock.result.return_value = _TEST_MDM_EXPECTED_NEW_JOB + update_mdm_job_mock.return_value = update_mdm_job_lro_mock yield update_mdm_job_mock @@ -1046,25 +1077,66 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): job = jobs.ModelDeploymentMonitoringJob( model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME ) + old_job = copy.deepcopy(job._gca_resource) drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig( drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG ) + schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1) + alert_config = aiplatform.model_monitoring.EmailAlertConfig( + user_emails=[_TEST_MDM_USER_EMAIL] + ) + sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig( + sample_rate=_TEST_MDM_SAMPLE_RATE + ) + labels = _TEST_MDM_LABEL + log_ttl = _TEST_LOG_TTL_IN_DAYS + display_name = _TEST_MDM_NEW_NAME new_config = aiplatform.model_monitoring.ObjectiveConfig( drift_detection_config=drift_detection_config ) - job.update(objective_configs=new_config) + job.update( + display_name=display_name, + schedule_config=schedule_config, + alert_config=alert_config, + logging_sampling_strategy=sampling_strategy, + labels=labels, + bigquery_tables_log_ttl=log_ttl, + enable_monitoring_pipeline_logs=True, + objective_configs=new_config, + ) + new_job = job._gca_resource + assert old_job != new_job + assert new_job.display_name == display_name + assert new_job.logging_sampling_strategy == sampling_strategy.as_proto() + assert ( + new_job.model_deployment_monitoring_schedule_config + == schedule_config.as_proto() + ) + assert new_job.labels == labels + assert new_job.model_monitoring_alert_config == alert_config.as_proto() + assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS + assert new_job.enable_monitoring_pipeline_logs assert ( - job._gca_resource.model_deployment_monitoring_objective_configs[ + new_job.model_deployment_monitoring_objective_configs[ 0 ].objective_config.prediction_drift_detection_config == drift_detection_config.as_proto() ) get_mdm_job_mock.assert_called_with( - name=_TEST_MDM_JOB_NAME, + name=_TEST_MDM_JOB_NAME, retry=base._DEFAULT_RETRY ) update_mdm_job_mock.assert_called_once_with( - model_deployment_monitoring_job=get_mdm_job_mock.return_value, + model_deployment_monitoring_job=new_job, update_mask=field_mask_pb2.FieldMask( - paths=["model_deployment_monitoring_objective_configs"] + paths=[ + "display_name", + "model_deployment_monitoring_schedule_config", + "model_monitoring_alert_config", + "logging_sampling_strategy", + "labels", + "log_ttl", + "enable_monitoring_pipeline_logs", + "model_deployment_monitoring_objective_configs", + ] ), )