From 77e4bd89ec2825b926a5556721a677405f796ef5 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Wed, 5 Oct 2022 13:18:18 -0700 Subject: [PATCH 01/17] fix: added proto message conversion to MDMJob.update fields --- google/cloud/aiplatform/jobs.py | 13 +++++----- tests/unit/aiplatform/test_jobs.py | 41 ++++++++++++++++++++++++++++-- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index af3815e47f..c6fd25f712 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2427,7 +2427,8 @@ def update( are allowed. See https://goo.gl/xmQnxf for more information and examples of labels. bigquery_tables_log_ttl (int): - Optional. The TTL(time to live) of BigQuery tables in user projects + Optional. The number of days for which the logs are stored. + The TTL(time to live) of BigQuery tables in user projects which stores logs. A day is the basic unit of the TTL and we take the ceil of TTL/86400(a day). e.g. { second: 3600} indicates ttl = 1 @@ -2462,19 +2463,19 @@ def update( current_job.display_name = display_name if schedule_config is not None: update_mask.append("model_deployment_monitoring_schedule_config") - current_job.model_deployment_monitoring_schedule_config = schedule_config + current_job.model_deployment_monitoring_schedule_config = schedule_config.as_proto() if alert_config is not None: update_mask.append("model_monitoring_alert_config") - current_job.model_monitoring_alert_config = alert_config + current_job.model_monitoring_alert_config = alert_config.as_proto() if logging_sampling_strategy is not None: update_mask.append("logging_sampling_strategy") - current_job.logging_sampling_strategy = logging_sampling_strategy + current_job.logging_sampling_strategy = logging_sampling_strategy.as_proto() if labels is not None: update_mask.append("labels") - current_job.lables = labels + current_job.labels = labels if bigquery_tables_log_ttl is not None: update_mask.append("log_ttl") - current_job.log_ttl = bigquery_tables_log_ttl + current_job.log_ttl = duration_pb2.Duration(seconds = bigquery_tables_log_ttl * 86400) if enable_monitoring_pipeline_logs is not None: update_mask.append("enable_monitoring_pipeline_logs") current_job.enable_monitoring_pipeline_logs = ( diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 9d788df262..c80eb1ad94 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -1049,10 +1049,39 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig( drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG ) + schedule_config = aiplatform.model_monitoring.ScheduleConfig( + monitor_interval = 1 + ) + alert_config = aiplatform.model_monitoring.EmailAlertConfig( + user_emails = ["TEST"] + ) + sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig( + sample_rate = 0.5 + ) + labels = {"TEST KEY":"TEST VAL"} + log_ttl = 1 + display_name = "NEW_NAME" new_config = aiplatform.model_monitoring.ObjectiveConfig( drift_detection_config=drift_detection_config ) - job.update(objective_configs=new_config) + job.update( + display_name = display_name, + schedule_config = schedule_config, + alert_config = alert_config, + logging_sampling_strategy = sampling_strategy, + labels = labels, + bigquery_tables_log_ttl = log_ttl, + enable_monitoring_pipeline_logs = True, + objective_configs=new_config + ) + gapic_job = job._gca_resource + assert(gapic_job.display_name == display_name) + assert(gapic_job.logging_sampling_strategy == sampling_strategy.as_proto()) + assert(gapic_job.model_deployment_monitoring_schedule_config == schedule_config.as_proto()) + assert(gapic_job.labels == labels) + assert(gapic_job.model_monitoring_alert_config == alert_config.as_proto()) + assert(gapic_job.log_ttl.days == 1) + assert(gapic_job.enable_monitoring_pipeline_logs) assert ( job._gca_resource.model_deployment_monitoring_objective_configs[ 0 @@ -1065,6 +1094,14 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): update_mdm_job_mock.assert_called_once_with( model_deployment_monitoring_job=get_mdm_job_mock.return_value, update_mask=field_mask_pb2.FieldMask( - paths=["model_deployment_monitoring_objective_configs"] + paths=[ + "display_name", + "model_deployment_monitoring_schedule_config", + "model_monitoring_alert_config", + "logging_sampling_strategy", + "labels", + "log_ttl", + "enable_monitoring_pipeline_logs", + "model_deployment_monitoring_objective_configs"] ), ) From cca9b0d338879e70a584bdf2322f77c04daa62aa Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Wed, 5 Oct 2022 20:22:40 +0000 Subject: [PATCH 02/17] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- google/cloud/aiplatform/jobs.py | 8 ++++-- tests/unit/aiplatform/test_jobs.py | 46 ++++++++++++++++-------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index c6fd25f712..513df3ab52 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2463,7 +2463,9 @@ def update( current_job.display_name = display_name if schedule_config is not None: update_mask.append("model_deployment_monitoring_schedule_config") - current_job.model_deployment_monitoring_schedule_config = schedule_config.as_proto() + current_job.model_deployment_monitoring_schedule_config = ( + schedule_config.as_proto() + ) if alert_config is not None: update_mask.append("model_monitoring_alert_config") current_job.model_monitoring_alert_config = alert_config.as_proto() @@ -2475,7 +2477,9 @@ def update( current_job.labels = labels if bigquery_tables_log_ttl is not None: update_mask.append("log_ttl") - current_job.log_ttl = duration_pb2.Duration(seconds = bigquery_tables_log_ttl * 86400) + current_job.log_ttl = duration_pb2.Duration( + seconds=bigquery_tables_log_ttl * 86400 + ) if enable_monitoring_pipeline_logs is not None: update_mask.append("enable_monitoring_pipeline_logs") current_job.enable_monitoring_pipeline_logs = ( diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index c80eb1ad94..9495165f62 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -1049,39 +1049,40 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig( drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG ) - schedule_config = aiplatform.model_monitoring.ScheduleConfig( - monitor_interval = 1 - ) + schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1) alert_config = aiplatform.model_monitoring.EmailAlertConfig( - user_emails = ["TEST"] + user_emails=["TEST"] ) sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig( - sample_rate = 0.5 + sample_rate=0.5 ) - labels = {"TEST KEY":"TEST VAL"} + labels = {"TEST KEY": "TEST VAL"} log_ttl = 1 display_name = "NEW_NAME" new_config = aiplatform.model_monitoring.ObjectiveConfig( drift_detection_config=drift_detection_config ) job.update( - display_name = display_name, - schedule_config = schedule_config, - alert_config = alert_config, - logging_sampling_strategy = sampling_strategy, - labels = labels, - bigquery_tables_log_ttl = log_ttl, - enable_monitoring_pipeline_logs = True, - objective_configs=new_config + display_name=display_name, + schedule_config=schedule_config, + alert_config=alert_config, + logging_sampling_strategy=sampling_strategy, + labels=labels, + bigquery_tables_log_ttl=log_ttl, + enable_monitoring_pipeline_logs=True, + objective_configs=new_config, ) gapic_job = job._gca_resource - assert(gapic_job.display_name == display_name) - assert(gapic_job.logging_sampling_strategy == sampling_strategy.as_proto()) - assert(gapic_job.model_deployment_monitoring_schedule_config == schedule_config.as_proto()) - assert(gapic_job.labels == labels) - assert(gapic_job.model_monitoring_alert_config == alert_config.as_proto()) - assert(gapic_job.log_ttl.days == 1) - assert(gapic_job.enable_monitoring_pipeline_logs) + assert gapic_job.display_name == display_name + assert gapic_job.logging_sampling_strategy == sampling_strategy.as_proto() + assert ( + gapic_job.model_deployment_monitoring_schedule_config + == schedule_config.as_proto() + ) + assert gapic_job.labels == labels + assert gapic_job.model_monitoring_alert_config == alert_config.as_proto() + assert gapic_job.log_ttl.days == 1 + assert gapic_job.enable_monitoring_pipeline_logs assert ( job._gca_resource.model_deployment_monitoring_objective_configs[ 0 @@ -1102,6 +1103,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): "labels", "log_ttl", "enable_monitoring_pipeline_logs", - "model_deployment_monitoring_objective_configs"] + "model_deployment_monitoring_objective_configs", + ] ), ) From 2ff747ae7a8b35fd0271b57a8bbb74564531b13b Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Thu, 6 Oct 2022 13:44:51 -0700 Subject: [PATCH 03/17] addressed PR comment --- tests/unit/aiplatform/test_jobs.py | 56 ++++++++++++------------------ 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 9495165f62..12082c9e9f 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -46,6 +46,7 @@ job_service_client, ) from google.protobuf import field_mask_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from test_endpoints import get_endpoint_with_models_mock # noqa: F401 @@ -175,6 +176,11 @@ _TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}" _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01} +_TEST_MDM_USER_EMAIL = "TEST_EMAIL" +_TEST_MDM_SAMPLE_RATE = 0.5 +_TEST_MDM_LABEL = {"TEST KEY":"TEST VAL"} +_TEST_LOG_TTL_IN_DAYS = 1 +_TEST_MDM_NEW_NAME = "NEW_NAME" # TODO(b/171333554): Move reusable test fixtures to conftest.py file @@ -992,7 +998,6 @@ def get_mdm_job_mock(): gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( name=_TEST_MDM_JOB_NAME, display_name=_TEST_DISPLAY_NAME, - state=_TEST_JOB_STATE_RUNNING, endpoint=_TEST_ENDPOINT, ) ) @@ -1000,36 +1005,21 @@ def get_mdm_job_mock(): @pytest.fixture -@pytest.mark.usefixtures("get_mdm_job_mock") def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 with mock.patch.object( _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: - expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( - prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( - drift_thresholds={ - "TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01) - } - ) - ) - all_configs = [] - for model in get_endpoint_with_models_mock.return_value.deployed_models: - all_configs.append( - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( - deployed_model_id=model.id, - objective_config=expected_objective_config, - ) - ) - - update_mdm_job_mock.return_vaue.result_type = ( - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - name=_TEST_MDM_JOB_NAME, - display_name=_TEST_DISPLAY_NAME, - state=_TEST_JOB_STATE_RUNNING, - endpoint=_TEST_ENDPOINT, - model_deployment_monitoring_objective_configs=all_configs, - ) + expected_output = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( + display_name = "NEW_NAME", + endpoint = _TEST_ENDPOINT, + model_deployment_monitoring_objective_configs=[gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(deployed_model_id = model_id, objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(prediction_drift_detection_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(drift_thresholds = _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG))) for model_id in get_endpoint_with_models_mock.deployed_models], + logging_sampling_strategy = gca_model_monitoring_compat.SamplingStrategy(random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(sample_rate = _TEST_MDM_SAMPLE_RATE)), + labels = _TEST_MDM_LABEL, + model_monitoring_alert_config = gca_model_monitoring_compat.ModelMonitoringAlertConfig(email_alert_config = gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(user_emails = [_TEST_MDM_USER_EMAIL])), + log_ttl = duration_pb2.Duration(seconds = _TEST_LOG_TTL_IN_DAYS * 86400), + enable_monitoring_pipeline_logs = True ) + update_mdm_job_mock.return_value.result_type = expected_output yield update_mdm_job_mock @@ -1051,14 +1041,14 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): ) schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1) alert_config = aiplatform.model_monitoring.EmailAlertConfig( - user_emails=["TEST"] + user_emails=[_TEST_MDM_USER_EMAIL] ) sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig( - sample_rate=0.5 + sample_rate=_TEST_MDM_SAMPLE_RATE ) - labels = {"TEST KEY": "TEST VAL"} - log_ttl = 1 - display_name = "NEW_NAME" + labels = _TEST_MDM_LABEL + log_ttl = _TEST_LOG_TTL_IN_DAYS + display_name = _TEST_MDM_NEW_NAME new_config = aiplatform.model_monitoring.ObjectiveConfig( drift_detection_config=drift_detection_config ) @@ -1081,7 +1071,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): ) assert gapic_job.labels == labels assert gapic_job.model_monitoring_alert_config == alert_config.as_proto() - assert gapic_job.log_ttl.days == 1 + assert gapic_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS assert gapic_job.enable_monitoring_pipeline_logs assert ( job._gca_resource.model_deployment_monitoring_objective_configs[ @@ -1093,7 +1083,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): name=_TEST_MDM_JOB_NAME, ) update_mdm_job_mock.assert_called_once_with( - model_deployment_monitoring_job=get_mdm_job_mock.return_value, + model_deployment_monitoring_job=gapic_job, update_mask=field_mask_pb2.FieldMask( paths=[ "display_name", From d53f7fc849fed90e1d049c4f66b9bb3c61c05c07 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Thu, 6 Oct 2022 20:46:21 +0000 Subject: [PATCH 04/17] formatting --- tests/unit/aiplatform/test_jobs.py | 38 ++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 12082c9e9f..e27341758b 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -46,7 +46,7 @@ job_service_client, ) from google.protobuf import field_mask_pb2 # type: ignore -from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from test_endpoints import get_endpoint_with_models_mock # noqa: F401 @@ -178,7 +178,7 @@ _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01} _TEST_MDM_USER_EMAIL = "TEST_EMAIL" _TEST_MDM_SAMPLE_RATE = 0.5 -_TEST_MDM_LABEL = {"TEST KEY":"TEST VAL"} +_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"} _TEST_LOG_TTL_IN_DAYS = 1 _TEST_MDM_NEW_NAME = "NEW_NAME" @@ -1010,14 +1010,32 @@ def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: expected_output = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - display_name = "NEW_NAME", - endpoint = _TEST_ENDPOINT, - model_deployment_monitoring_objective_configs=[gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(deployed_model_id = model_id, objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(prediction_drift_detection_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(drift_thresholds = _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG))) for model_id in get_endpoint_with_models_mock.deployed_models], - logging_sampling_strategy = gca_model_monitoring_compat.SamplingStrategy(random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(sample_rate = _TEST_MDM_SAMPLE_RATE)), - labels = _TEST_MDM_LABEL, - model_monitoring_alert_config = gca_model_monitoring_compat.ModelMonitoringAlertConfig(email_alert_config = gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(user_emails = [_TEST_MDM_USER_EMAIL])), - log_ttl = duration_pb2.Duration(seconds = _TEST_LOG_TTL_IN_DAYS * 86400), - enable_monitoring_pipeline_logs = True + display_name="NEW_NAME", + endpoint=_TEST_ENDPOINT, + model_deployment_monitoring_objective_configs=[ + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( + deployed_model_id=model_id, + objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( + prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( + drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG + ) + ), + ) + for model_id in get_endpoint_with_models_mock.deployed_models + ], + logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy( + random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig( + sample_rate=_TEST_MDM_SAMPLE_RATE + ) + ), + labels=_TEST_MDM_LABEL, + model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig( + email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig( + user_emails=[_TEST_MDM_USER_EMAIL] + ) + ), + log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), + enable_monitoring_pipeline_logs=True, ) update_mdm_job_mock.return_value.result_type = expected_output yield update_mdm_job_mock From 07a2c3dc45053726777273c79453e521c11a0f0c Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Thu, 6 Oct 2022 20:47:56 +0000 Subject: [PATCH 05/17] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/unit/aiplatform/test_jobs.py | 38 ++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 12082c9e9f..e27341758b 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -46,7 +46,7 @@ job_service_client, ) from google.protobuf import field_mask_pb2 # type: ignore -from google.protobuf import duration_pb2 # type: ignore +from google.protobuf import duration_pb2 # type: ignore from test_endpoints import get_endpoint_with_models_mock # noqa: F401 @@ -178,7 +178,7 @@ _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01} _TEST_MDM_USER_EMAIL = "TEST_EMAIL" _TEST_MDM_SAMPLE_RATE = 0.5 -_TEST_MDM_LABEL = {"TEST KEY":"TEST VAL"} +_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"} _TEST_LOG_TTL_IN_DAYS = 1 _TEST_MDM_NEW_NAME = "NEW_NAME" @@ -1010,14 +1010,32 @@ def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: expected_output = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - display_name = "NEW_NAME", - endpoint = _TEST_ENDPOINT, - model_deployment_monitoring_objective_configs=[gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(deployed_model_id = model_id, objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(prediction_drift_detection_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(drift_thresholds = _TEST_MDM_JOB_DRIFT_DETECTION_CONFIG))) for model_id in get_endpoint_with_models_mock.deployed_models], - logging_sampling_strategy = gca_model_monitoring_compat.SamplingStrategy(random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(sample_rate = _TEST_MDM_SAMPLE_RATE)), - labels = _TEST_MDM_LABEL, - model_monitoring_alert_config = gca_model_monitoring_compat.ModelMonitoringAlertConfig(email_alert_config = gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(user_emails = [_TEST_MDM_USER_EMAIL])), - log_ttl = duration_pb2.Duration(seconds = _TEST_LOG_TTL_IN_DAYS * 86400), - enable_monitoring_pipeline_logs = True + display_name="NEW_NAME", + endpoint=_TEST_ENDPOINT, + model_deployment_monitoring_objective_configs=[ + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( + deployed_model_id=model_id, + objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( + prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( + drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG + ) + ), + ) + for model_id in get_endpoint_with_models_mock.deployed_models + ], + logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy( + random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig( + sample_rate=_TEST_MDM_SAMPLE_RATE + ) + ), + labels=_TEST_MDM_LABEL, + model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig( + email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig( + user_emails=[_TEST_MDM_USER_EMAIL] + ) + ), + log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), + enable_monitoring_pipeline_logs=True, ) update_mdm_job_mock.return_value.result_type = expected_output yield update_mdm_job_mock From f8c16cc70b88ca810aad0d419188bd0e42227ea2 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Fri, 7 Oct 2022 13:33:23 -0700 Subject: [PATCH 06/17] replaced string literal with constant --- tests/unit/aiplatform/test_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index e27341758b..4d58ee196c 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -1010,7 +1010,7 @@ def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: expected_output = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - display_name="NEW_NAME", + display_name=_TEST_MDM_NEW_NAME, endpoint=_TEST_ENDPOINT, model_deployment_monitoring_objective_configs=[ gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( From fb60c395e342103f9cfe440fa7b852d6cd0937c5 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Tue, 11 Oct 2022 13:50:37 -0700 Subject: [PATCH 07/17] adding _gca_resource re-assignmnet to mdm job class --- google/cloud/aiplatform/jobs.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 513df3ab52..5a65f9270b 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2457,6 +2457,7 @@ def update( current_job = self.api_client.get_model_deployment_monitoring_job( name=self._gca_resource.name ) + mdm_job_name = self._gca_resource.name update_mask: List[str] = [] if display_name is not None: update_mask.append("display_name") @@ -2500,6 +2501,9 @@ def update( model_deployment_monitoring_job=current_job, update_mask=field_mask_pb2.FieldMask(paths=update_mask), ) + self._gca_resource = self.api_client.get_model_deployment_monitoring_job( + name=mdm_job_name + ) return self def pause(self) -> "ModelDeploymentMonitoringJob": From 3ebd054b12a83c8b46d309659caa07a6fe0d3da4 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Mon, 17 Oct 2022 23:50:47 +0000 Subject: [PATCH 08/17] Added side effects in get_mdm_job pytest mock --- tests/unit/aiplatform/test_jobs.py | 107 +++++++++++++++++------------ 1 file changed, 62 insertions(+), 45 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 4d58ee196c..8502ef637a 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -48,6 +48,7 @@ from google.protobuf import field_mask_pb2 # type: ignore from google.protobuf import duration_pb2 # type: ignore +import test_endpoints # noqa: F401 from test_endpoints import get_endpoint_with_models_mock # noqa: F401 _TEST_API_CLIENT = job_service_client.JobServiceClient @@ -182,6 +183,48 @@ _TEST_LOG_TTL_IN_DAYS = 1 _TEST_MDM_NEW_NAME = "NEW_NAME" +_TEST_MDM_OLD_JOB = ( + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( + name=_TEST_MDM_JOB_NAME, + display_name=_TEST_DISPLAY_NAME, + endpoint=_TEST_ENDPOINT, + ) +) + +_TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( + name=_TEST_MDM_JOB_NAME, + display_name=_TEST_MDM_NEW_NAME, + endpoint=_TEST_ENDPOINT, + model_deployment_monitoring_objective_configs=[ + gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( + deployed_model_id=model_id, + objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( + prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( + drift_thresholds={ + "TEST_KEY": gca_model_monitoring_compat.ThresholdConfig( + value=0.01 + ) + } + ) + ), + ) + for model_id in [model.id for model in test_endpoints._TEST_DEPLOYED_MODELS] + ], + logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy( + random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig( + sample_rate=_TEST_MDM_SAMPLE_RATE + ) + ), + labels=_TEST_MDM_LABEL, + model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig( + email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig( + user_emails=[_TEST_MDM_USER_EMAIL] + ) + ), + log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), + enable_monitoring_pipeline_logs=True, +) + # TODO(b/171333554): Move reusable test fixtures to conftest.py file @@ -994,13 +1037,13 @@ def get_mdm_job_mock(): with mock.patch.object( _TEST_API_CLIENT, "get_model_deployment_monitoring_job" ) as get_mdm_job_mock: - get_mdm_job_mock.return_value = ( - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - name=_TEST_MDM_JOB_NAME, - display_name=_TEST_DISPLAY_NAME, - endpoint=_TEST_ENDPOINT, - ) - ) + get_mdm_job_mock.side_effect = [ + _TEST_MDM_OLD_JOB, + _TEST_MDM_OLD_JOB, + _TEST_MDM_OLD_JOB, + _TEST_MDM_EXPECTED_NEW_JOB, + _TEST_MDM_EXPECTED_NEW_JOB, + ] yield get_mdm_job_mock @@ -1009,35 +1052,7 @@ def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 with mock.patch.object( _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: - expected_output = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob( - display_name=_TEST_MDM_NEW_NAME, - endpoint=_TEST_ENDPOINT, - model_deployment_monitoring_objective_configs=[ - gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( - deployed_model_id=model_id, - objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig( - prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig( - drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG - ) - ), - ) - for model_id in get_endpoint_with_models_mock.deployed_models - ], - logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy( - random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig( - sample_rate=_TEST_MDM_SAMPLE_RATE - ) - ), - labels=_TEST_MDM_LABEL, - model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig( - email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig( - user_emails=[_TEST_MDM_USER_EMAIL] - ) - ), - log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), - enable_monitoring_pipeline_logs=True, - ) - update_mdm_job_mock.return_value.result_type = expected_output + update_mdm_job_mock.return_value.result_type = _TEST_MDM_EXPECTED_NEW_JOB yield update_mdm_job_mock @@ -1070,6 +1085,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): new_config = aiplatform.model_monitoring.ObjectiveConfig( drift_detection_config=drift_detection_config ) + old_job = job._gca_resource job.update( display_name=display_name, schedule_config=schedule_config, @@ -1080,17 +1096,18 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): enable_monitoring_pipeline_logs=True, objective_configs=new_config, ) - gapic_job = job._gca_resource - assert gapic_job.display_name == display_name - assert gapic_job.logging_sampling_strategy == sampling_strategy.as_proto() + new_job = job._gca_resource + assert old_job != new_job + assert new_job.display_name == display_name + assert new_job.logging_sampling_strategy == sampling_strategy.as_proto() assert ( - gapic_job.model_deployment_monitoring_schedule_config + new_job.model_deployment_monitoring_schedule_config == schedule_config.as_proto() ) - assert gapic_job.labels == labels - assert gapic_job.model_monitoring_alert_config == alert_config.as_proto() - assert gapic_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS - assert gapic_job.enable_monitoring_pipeline_logs + assert new_job.labels == labels + assert new_job.model_monitoring_alert_config == alert_config.as_proto() + assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS + assert new_job.enable_monitoring_pipeline_logs assert ( job._gca_resource.model_deployment_monitoring_objective_configs[ 0 @@ -1101,7 +1118,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): name=_TEST_MDM_JOB_NAME, ) update_mdm_job_mock.assert_called_once_with( - model_deployment_monitoring_job=gapic_job, + model_deployment_monitoring_job=new_job, update_mask=field_mask_pb2.FieldMask( paths=[ "display_name", From bdc1bceac134c4469cea236cb20bba5204ac2060 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Mon, 17 Oct 2022 17:46:07 -0700 Subject: [PATCH 09/17] fixing side effects --- tests/unit/aiplatform/test_jobs.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 8502ef637a..79eec56300 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -16,6 +16,7 @@ # import pytest +import copy from unittest import mock from importlib import reload @@ -221,6 +222,9 @@ user_emails=[_TEST_MDM_USER_EMAIL] ) ), + model_deployment_monitoring_schedule_config = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig( + monitor_interval = duration_pb2.Duration(seconds=3600) + ), log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), enable_monitoring_pipeline_logs=True, ) @@ -1041,7 +1045,7 @@ def get_mdm_job_mock(): _TEST_MDM_OLD_JOB, _TEST_MDM_OLD_JOB, _TEST_MDM_OLD_JOB, - _TEST_MDM_EXPECTED_NEW_JOB, + _TEST_MDM_OLD_JOB, _TEST_MDM_EXPECTED_NEW_JOB, ] yield get_mdm_job_mock @@ -1069,6 +1073,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): job = jobs.ModelDeploymentMonitoringJob( model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME ) + old_job = copy.deepcopy(job._gca_resource) drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig( drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG ) @@ -1085,7 +1090,6 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): new_config = aiplatform.model_monitoring.ObjectiveConfig( drift_detection_config=drift_detection_config ) - old_job = job._gca_resource job.update( display_name=display_name, schedule_config=schedule_config, From 650166e92ab863c1c79a9711b0610c86d2998c8a Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Tue, 18 Oct 2022 00:46:58 +0000 Subject: [PATCH 10/17] formatting --- tests/unit/aiplatform/test_jobs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 79eec56300..7ae15118d9 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -222,8 +222,8 @@ user_emails=[_TEST_MDM_USER_EMAIL] ) ), - model_deployment_monitoring_schedule_config = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig( - monitor_interval = duration_pb2.Duration(seconds=3600) + model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig( + monitor_interval=duration_pb2.Duration(seconds=3600) ), log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), enable_monitoring_pipeline_logs=True, From fcc0638a3518eccad089e3780e3da4c553abae9c Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 18 Oct 2022 00:50:40 +0000 Subject: [PATCH 11/17] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/unit/aiplatform/test_jobs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 79eec56300..7ae15118d9 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -222,8 +222,8 @@ user_emails=[_TEST_MDM_USER_EMAIL] ) ), - model_deployment_monitoring_schedule_config = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig( - monitor_interval = duration_pb2.Duration(seconds=3600) + model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig( + monitor_interval=duration_pb2.Duration(seconds=3600) ), log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400), enable_monitoring_pipeline_logs=True, From 6dc5f12c771c05efb4ec304054400b1b887a0980 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Mon, 17 Oct 2022 17:51:57 -0700 Subject: [PATCH 12/17] minor edits to variable names --- tests/unit/aiplatform/test_jobs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 7ae15118d9..dc826cbe0d 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -1113,7 +1113,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS assert new_job.enable_monitoring_pipeline_logs assert ( - job._gca_resource.model_deployment_monitoring_objective_configs[ + new_job.model_deployment_monitoring_objective_configs[ 0 ].objective_config.prediction_drift_detection_config == drift_detection_config.as_proto() From d6d300cb88c232eb0209ebe09490c4dbb9e0d29e Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Tue, 18 Oct 2022 10:44:19 -0700 Subject: [PATCH 13/17] Addressed PR feedback --- google/cloud/aiplatform/jobs.py | 13 ++++++------- tests/unit/aiplatform/test_jobs.py | 9 +++++++-- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index 5a65f9270b..dab42e8a8d 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -64,6 +64,7 @@ gca_job_state.JobState.JOB_STATE_FAILED, gca_job_state.JobState.JOB_STATE_CANCELLED, gca_job_state.JobState.JOB_STATE_PAUSED, + gca_job_state.JobState.JOB_STATE_RUNNING ) _JOB_ERROR_STATES = ( @@ -2454,9 +2455,7 @@ def update( will be applied to all deployed models. """ self._sync_gca_resource() - current_job = self.api_client.get_model_deployment_monitoring_job( - name=self._gca_resource.name - ) + current_job = self._gca_resource mdm_job_name = self._gca_resource.name update_mask: List[str] = [] if display_name is not None: @@ -2497,13 +2496,13 @@ def update( deployed_model_ids=deployed_model_ids, ) ) - self.api_client.update_model_deployment_monitoring_job( + operation = self.api_client.update_model_deployment_monitoring_job( model_deployment_monitoring_job=current_job, update_mask=field_mask_pb2.FieldMask(paths=update_mask), ) - self._gca_resource = self.api_client.get_model_deployment_monitoring_job( - name=mdm_job_name - ) + # TODO: b/254285776 add optional_sync support to model monitoring job + self._block_until_complete() + self._gca_resource = operation.result() return self def pause(self) -> "ModelDeploymentMonitoringJob": diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index dc826cbe0d..491fc8aa07 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -25,6 +25,7 @@ from google.cloud import storage from google.cloud import bigquery +from google.api_core import operation from google.auth import credentials as auth_credentials from google.cloud import aiplatform @@ -189,6 +190,7 @@ name=_TEST_MDM_JOB_NAME, display_name=_TEST_DISPLAY_NAME, endpoint=_TEST_ENDPOINT, + state = _TEST_JOB_STATE_RUNNING ) ) @@ -196,6 +198,7 @@ name=_TEST_MDM_JOB_NAME, display_name=_TEST_MDM_NEW_NAME, endpoint=_TEST_ENDPOINT, + state = _TEST_JOB_STATE_RUNNING, model_deployment_monitoring_objective_configs=[ gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( deployed_model_id=model_id, @@ -1045,7 +1048,6 @@ def get_mdm_job_mock(): _TEST_MDM_OLD_JOB, _TEST_MDM_OLD_JOB, _TEST_MDM_OLD_JOB, - _TEST_MDM_OLD_JOB, _TEST_MDM_EXPECTED_NEW_JOB, ] yield get_mdm_job_mock @@ -1056,7 +1058,9 @@ def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811 with mock.patch.object( _TEST_API_CLIENT, "update_model_deployment_monitoring_job" ) as update_mdm_job_mock: - update_mdm_job_mock.return_value.result_type = _TEST_MDM_EXPECTED_NEW_JOB + update_mdm_job_lro_mock = mock.Mock(operation.Operation) + update_mdm_job_lro_mock.result.return_value = _TEST_MDM_EXPECTED_NEW_JOB + update_mdm_job_mock.return_value = update_mdm_job_lro_mock yield update_mdm_job_mock @@ -1120,6 +1124,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): ) get_mdm_job_mock.assert_called_with( name=_TEST_MDM_JOB_NAME, + retry=base._DEFAULT_RETRY ) update_mdm_job_mock.assert_called_once_with( model_deployment_monitoring_job=new_job, From 0b87db719cf292cba90decc08e79cb4b23651c2e Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 18 Oct 2022 17:49:08 +0000 Subject: [PATCH 14/17] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20?= =?UTF-8?q?post-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- google/cloud/aiplatform/jobs.py | 2 +- tests/unit/aiplatform/test_jobs.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index dab42e8a8d..b839f56ff2 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -64,7 +64,7 @@ gca_job_state.JobState.JOB_STATE_FAILED, gca_job_state.JobState.JOB_STATE_CANCELLED, gca_job_state.JobState.JOB_STATE_PAUSED, - gca_job_state.JobState.JOB_STATE_RUNNING + gca_job_state.JobState.JOB_STATE_RUNNING, ) _JOB_ERROR_STATES = ( diff --git a/tests/unit/aiplatform/test_jobs.py b/tests/unit/aiplatform/test_jobs.py index 491fc8aa07..27a9ddea30 100644 --- a/tests/unit/aiplatform/test_jobs.py +++ b/tests/unit/aiplatform/test_jobs.py @@ -190,7 +190,7 @@ name=_TEST_MDM_JOB_NAME, display_name=_TEST_DISPLAY_NAME, endpoint=_TEST_ENDPOINT, - state = _TEST_JOB_STATE_RUNNING + state=_TEST_JOB_STATE_RUNNING, ) ) @@ -198,7 +198,7 @@ name=_TEST_MDM_JOB_NAME, display_name=_TEST_MDM_NEW_NAME, endpoint=_TEST_ENDPOINT, - state = _TEST_JOB_STATE_RUNNING, + state=_TEST_JOB_STATE_RUNNING, model_deployment_monitoring_objective_configs=[ gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig( deployed_model_id=model_id, @@ -1123,8 +1123,7 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock): == drift_detection_config.as_proto() ) get_mdm_job_mock.assert_called_with( - name=_TEST_MDM_JOB_NAME, - retry=base._DEFAULT_RETRY + name=_TEST_MDM_JOB_NAME, retry=base._DEFAULT_RETRY ) update_mdm_job_mock.assert_called_once_with( model_deployment_monitoring_job=new_job, From 000428b91d12c1447479848a4f1b64cf16ac221b Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Tue, 18 Oct 2022 18:26:14 -0700 Subject: [PATCH 15/17] addressed more PR commentes --- google/cloud/aiplatform/jobs.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index b839f56ff2..d7bd16efce 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -64,7 +64,6 @@ gca_job_state.JobState.JOB_STATE_FAILED, gca_job_state.JobState.JOB_STATE_CANCELLED, gca_job_state.JobState.JOB_STATE_PAUSED, - gca_job_state.JobState.JOB_STATE_RUNNING, ) _JOB_ERROR_STATES = ( @@ -2455,7 +2454,7 @@ def update( will be applied to all deployed models. """ self._sync_gca_resource() - current_job = self._gca_resource + current_job = copy.deepcopy(self._gca_resource) mdm_job_name = self._gca_resource.name update_mask: List[str] = [] if display_name is not None: @@ -2496,13 +2495,11 @@ def update( deployed_model_ids=deployed_model_ids, ) ) - operation = self.api_client.update_model_deployment_monitoring_job( + response = self.api_client.update_model_deployment_monitoring_job( model_deployment_monitoring_job=current_job, update_mask=field_mask_pb2.FieldMask(paths=update_mask), - ) - # TODO: b/254285776 add optional_sync support to model monitoring job - self._block_until_complete() - self._gca_resource = operation.result() + ).result() + self._gca_resource = response return self def pause(self) -> "ModelDeploymentMonitoringJob": From ccbfbf184a6cde2f269fada87bbd217819eb4e59 Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Wed, 19 Oct 2022 10:52:02 -0700 Subject: [PATCH 16/17] addressed PR comments --- google/cloud/aiplatform/jobs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index d7bd16efce..c3f38f42b5 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2495,11 +2495,12 @@ def update( deployed_model_ids=deployed_model_ids, ) ) - response = self.api_client.update_model_deployment_monitoring_job( + # TODO: b/254285776 add optional_sync support to model monitoring job + lro = self.api_client.update_model_deployment_monitoring_job( model_deployment_monitoring_job=current_job, update_mask=field_mask_pb2.FieldMask(paths=update_mask), - ).result() - self._gca_resource = response + ) + self._gca_resource = lro.result() return self def pause(self) -> "ModelDeploymentMonitoringJob": From de5489d35278473f5ef09041a54bfad32bf86a2d Mon Sep 17 00:00:00 2001 From: Rosie Zou Date: Wed, 19 Oct 2022 20:11:07 +0000 Subject: [PATCH 17/17] fix linter errors --- google/cloud/aiplatform/jobs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/google/cloud/aiplatform/jobs.py b/google/cloud/aiplatform/jobs.py index c3f38f42b5..5ab03b75b5 100644 --- a/google/cloud/aiplatform/jobs.py +++ b/google/cloud/aiplatform/jobs.py @@ -2455,7 +2455,6 @@ def update( """ self._sync_gca_resource() current_job = copy.deepcopy(self._gca_resource) - mdm_job_name = self._gca_resource.name update_mask: List[str] = [] if display_name is not None: update_mask.append("display_name")