diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 203362d7a1..af3815e47f 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2484,32 +2484,31 @@ def update( update_mask.append("model_deployment_monitoring_objective_configs") current_job.model_deployment_monitoring_objective_configs = ( ModelDeploymentMonitoringJob._parse_configs( - objective_configs, - current_job.endpoint, - deployed_model_ids, + objective_configs=objective_configs, + endpoint=aiplatform.Endpoint( + current_job.endpoint, credentials=self.credentials + ), + deployed_model_ids=deployed_model_ids, ) ) - if self.state == gca_job_state.JobState.JOB_STATE_RUNNING: - self.api_client.update_model_deployment_monitoring_job( - model_deployment_monitoring_job=current_job, - update_mask=field_mask_pb2.FieldMask(paths=update_mask), - ) + self.api_client.update_model_deployment_monitoring_job( + model_deployment_monitoring_job=current_job, + update_mask=field_mask_pb2.FieldMask(paths=update_mask), + ) return self def pause(self) -> "ModelDeploymentMonitoringJob": """Pause a running MDM job.""" - if self.state == gca_job_state.JobState.JOB_STATE_RUNNING: - self.api_client.pause_model_deployment_monitoring_job( - name=self._gca_resource.name - ) + self.api_client.pause_model_deployment_monitoring_job( + name=self._gca_resource.name + ) return self def resume(self) -> "ModelDeploymentMonitoringJob": """Resumes a paused MDM job.""" - if self.state == gca_job_state.JobState.JOB_STATE_PAUSED: - self.api_client.resume_model_deployment_monitoring_job( - name=self._gca_resource.name - ) + self.api_client.resume_model_deployment_monitoring_job( + name=self._gca_resource.name + ) return self def delete(self) -> None: diff --git a/tests/system/aiplatform/test_model_monitoring.py b/tests/system/aiplatform/test_model_monitoring.py index ab59e9e3de..cf347337ff 100644 --- a/tests/system/aiplatform/test_model_monitoring.py +++ 
b/tests/system/aiplatform/test_model_monitoring.py @@ -31,8 +31,28 @@ # constants used for testing USER_EMAIL = "" -PERMANENT_CHURN_ENDPOINT_ID = "8289570005524152320" +PERMANENT_CHURN_ENDPOINT_ID = "1843089351408353280" CHURN_MODEL_PATH = "gs://mco-mm/churn" +DEFAULT_INPUT = { + "cnt_ad_reward": 0, + "cnt_challenge_a_friend": 0, + "cnt_completed_5_levels": 1, + "cnt_level_complete_quickplay": 3, + "cnt_level_end_quickplay": 5, + "cnt_level_reset_quickplay": 2, + "cnt_level_start_quickplay": 6, + "cnt_post_score": 34, + "cnt_spend_virtual_currency": 0, + "cnt_use_extra_steps": 0, + "cnt_user_engagement": 120, + "country": "Denmark", + "dayofweek": 3, + "julianday": 254, + "language": "da-dk", + "month": 9, + "operating_system": "IOS", + "user_pseudo_id": "104B0770BAE16E8B53DF330C95881893", +} JOB_NAME = "churn" @@ -117,10 +137,7 @@ def test_mdm_two_models_one_valid_config(self): project=e2e_base._PROJECT, location=e2e_base._LOCATION, endpoint=self.endpoint, - predict_instance_schema_uri="", - analysis_instance_schema_uri="", ) - assert job is not None gapic_job = job._gca_resource assert ( @@ -156,22 +173,77 @@ def test_mdm_two_models_one_valid_config(self): gca_obj_config.prediction_drift_detection_config == drift_config.as_proto() ) + # delete this job and re-configure it to only enable drift detection for faster testing + job.delete() job_resource = job._gca_resource.name - # test job update and delete() - timeout = time.time() + 3600 - new_obj_config = model_monitoring.ObjectiveConfig(skew_config) + # test job delete + with pytest.raises(core_exceptions.NotFound): + job.api_client.get_model_deployment_monitoring_job(name=job_resource) + + def test_mdm_pause_and_update_config(self): + """Test objective config updates for existing MDM job""" + job = aiplatform.ModelDeploymentMonitoringJob.create( + display_name=self._make_display_name(key=JOB_NAME), + logging_sampling_strategy=sampling_strategy, + schedule_config=schedule_config, + alert_config=alert_config, + 
objective_configs=model_monitoring.ObjectiveConfig( + drift_detection_config=drift_config + ), + create_request_timeout=3600, + project=e2e_base._PROJECT, + location=e2e_base._LOCATION, + endpoint=self.endpoint, + ) + # test unsuccessful job update when it's pending + DRIFT_THRESHOLDS["cnt_user_engagement"] += 0.01 + new_obj_config = model_monitoring.ObjectiveConfig( + drift_detection_config=model_monitoring.DriftDetectionConfig( + drift_thresholds=DRIFT_THRESHOLDS, + attribute_drift_thresholds=ATTRIB_DRIFT_THRESHOLDS, + ) + ) + if job.state == gca_job_state.JobState.JOB_STATE_PENDING: + with pytest.raises(core_exceptions.FailedPrecondition): + job.update(objective_configs=new_obj_config) + + # generate traffic to force MDM job to come online + for i in range(2000): + DEFAULT_INPUT["cnt_user_engagement"] += i + self.endpoint.predict([DEFAULT_INPUT], use_raw_predict=True) - while time.time() < timeout: + # test job update + while True: + time.sleep(1) if job.state == gca_job_state.JobState.JOB_STATE_RUNNING: job.update(objective_configs=new_obj_config) - assert str(job._gca_resource.prediction_drift_detection_config) == "" break - time.sleep(5) + # verify job update + while True: + time.sleep(1) + if job.state == gca_job_state.JobState.JOB_STATE_RUNNING: + gca_obj_config = ( + job._gca_resource.model_deployment_monitoring_objective_configs[ + 0 + ].objective_config + ) + assert ( + gca_obj_config.prediction_drift_detection_config + == new_obj_config.drift_detection_config.as_proto() + ) + break + + # test pause + job.pause() + while job.state != gca_job_state.JobState.JOB_STATE_PAUSED: + time.sleep(1) job.delete() + + # confirm deletion with pytest.raises(core_exceptions.NotFound): - job.api_client.get_model_deployment_monitoring_job(name=job_resource) + job.state def test_mdm_two_models_two_valid_configs(self): [deployed_model1, deployed_model2] = list( @@ -181,7 +253,6 @@ def test_mdm_two_models_two_valid_configs(self): deployed_model1: objective_config, 
deployed_model2: objective_config2, } - job = None job = aiplatform.ModelDeploymentMonitoringJob.create( display_name=self._make_display_name(key=JOB_NAME), logging_sampling_strategy=sampling_strategy, @@ -192,10 +263,7 @@ def test_mdm_two_models_two_valid_configs(self): project=e2e_base._PROJECT, location=e2e_base._LOCATION, endpoint=self.endpoint, - predict_instance_schema_uri="", - analysis_instance_schema_uri="", ) - assert job is not None gapic_job = job._gca_resource assert ( @@ -246,8 +314,6 @@ def test_mdm_invalid_config_incorrect_model_id(self): project=e2e_base._PROJECT, location=e2e_base._LOCATION, endpoint=self.endpoint, - predict_instance_schema_uri="", - analysis_instance_schema_uri="", deployed_model_ids=[""], ) assert "Invalid model ID" in str(e.value) @@ -265,8 +331,6 @@ def test_mdm_invalid_config_xai(self): project=e2e_base._PROJECT, location=e2e_base._LOCATION, endpoint=self.endpoint, - predict_instance_schema_uri="", - analysis_instance_schema_uri="", ) assert ( "`explanation_config` should only be enabled if the model has `explanation_spec populated" @@ -294,8 +358,6 @@ def test_mdm_two_models_invalid_configs_xai(self): project=e2e_base._PROJECT, location=e2e_base._LOCATION, endpoint=self.endpoint, - predict_instance_schema_uri="", - analysis_instance_schema_uri="", ) assert ( "`explanation_config` should only be enabled if the model has `explanation_spec populated" diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 7381481b30..9d788df262 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -38,11 +38,16 @@ job_state as gca_job_state_compat, machine_resources as gca_machine_resources_compat, manual_batch_tuning_parameters as gca_manual_batch_tuning_parameters_compat, + model_deployment_monitoring_job as gca_model_deployment_monitoring_job_compat, + model_monitoring as gca_model_monitoring_compat, ) from google.cloud.aiplatform.compat.services import ( 
job_service_client, ) +from google.protobuf import field_mask_pb2 # type: ignore + +from test_endpoints import get_endpoint_with_models_mock # noqa: F401 _TEST_API_CLIENT = job_service_client.JobServiceClient @@ -84,6 +89,11 @@ f"bq://{_TEST_BATCH_PREDICTION_BQ_PREFIX}" ) +_TEST_MDM_JOB_NAME = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/modelDeploymentMonitoringJobs/{_TEST_ID}" +_TEST_ENDPOINT = ( + f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/endpoints/{_TEST_ID}" +) + _TEST_JOB_STATE_SUCCESS = gca_job_state_compat.JobState(4) _TEST_JOB_STATE_RUNNING = gca_job_state_compat.JobState(3) _TEST_JOB_STATE_PENDING = gca_job_state_compat.JobState(2) @@ -164,6 +174,8 @@ _TEST_JOB_DELETE_METHOD_NAME = "delete_custom_job" _TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}" +_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01} + # TODO(b/171333554): Move reusable test fixtures to conftest.py file @@ -969,3 +981,90 @@ def test_batch_predict_job_with_versioned_model( ].model == _TEST_VERSIONED_MODEL_NAME ) + + +@pytest.fixture +def get_mdm_job_mock(): + with mock.patch.object( + _TEST_API_CLIENT, "get_model_deployment_monitoring_job" + ) as get_mdm_job_mock: + get_mdm_job_mock.return_value = ( + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( + name=_TEST_MDM_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + state=_TEST_JOB_STATE_RUNNING, + endpoint=_TEST_ENDPOINT, + ) + ) + yield get_mdm_job_mock + + +@pytest.fixture +@pytest.mark.usefixtures("get_mdm_job_mock") +def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 + with mock.patch.object( + _TEST_API_CLIENT, "update_model_deployment_monitoring_job" + ) as update_mdm_job_mock: + expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( + prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( + drift_thresholds={ + "TEST_KEY": 
gca_model_monitoring_compat.ThresholdConfig(value=0.01) +            } +        ) +    ) +        all_configs = [] +        for model in get_endpoint_with_models_mock.return_value.deployed_models: +            all_configs.append( +                gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( +                    deployed_model_id=model.id, +                    objective_config=expected_objective_config, +                ) +            ) + +        update_mdm_job_mock.return_value.result_type = ( +            gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( +                name=_TEST_MDM_JOB_NAME, +                display_name=_TEST_DISPLAY_NAME, +                state=_TEST_JOB_STATE_RUNNING, +                endpoint=_TEST_ENDPOINT, +                model_deployment_monitoring_objective_configs=all_configs, +            ) +        ) +        yield update_mdm_job_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestModelDeploymentMonitoringJob: +    def setup_method(self): +        reload(initializer) +        reload(aiplatform) + +    def teardown_method(self): +        initializer.global_pool.shutdown(wait=True) + +    def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): +        job = jobs.ModelDeploymentMonitoringJob( +            model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME +        ) +        drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig( +            drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG +        ) +        new_config = aiplatform.model_monitoring.ObjectiveConfig( +            drift_detection_config=drift_detection_config +        ) +        job.update(objective_configs=new_config) +        assert ( +            job._gca_resource.model_deployment_monitoring_objective_configs[ +                0 +            ].objective_config.prediction_drift_detection_config +            == drift_detection_config.as_proto() +        ) +        get_mdm_job_mock.assert_called_with( +            name=_TEST_MDM_JOB_NAME, +        ) +        update_mdm_job_mock.assert_called_once_with( +            model_deployment_monitoring_job=get_mdm_job_mock.return_value, +            update_mask=field_mask_pb2.FieldMask( +                paths=["model_deployment_monitoring_objective_configs"] +            ), +        )