diff --git a/tests/unit/vertexai/genai/test_evals.py b/tests/unit/vertexai/genai/test_evals.py
index 92e463ccf5..c384cfd40a 100644
--- a/tests/unit/vertexai/genai/test_evals.py
+++ b/tests/unit/vertexai/genai/test_evals.py
@@ -3624,8 +3624,7 @@ def test_run_inference_with_litellm_string_prompt_format(
         ) as mock_litellm, mock.patch(
             "vertexai._genai._evals_common._call_litellm_completion"
         ) as mock_call_litellm_completion:
-            # fmt: on
-            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
             prompt_df = pd.DataFrame([{"prompt": "What is LiteLLM?"}])
             expected_messages = [{"role": "user", "content": "What is LiteLLM?"}]
 
@@ -3676,17 +3675,18 @@ def test_run_inference_with_litellm_openai_request_format(
         mock_api_client_fixture,
     ):
         """Tests inference with LiteLLM where the row contains a chat completion request body."""
-        # fmt: off
         with (
-            mock.patch(
-                "vertexai._genai._evals_common.litellm"
-            ) as mock_litellm,
+            mock.patch("vertexai._genai._evals_common.litellm") as mock_litellm,
             mock.patch(
                 "vertexai._genai._evals_common._call_litellm_completion"
             ) as mock_call_litellm_completion,
         ):
-            # fmt: on
-            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            mock_litellm.get_llm_provider.return_value = (
+                "gpt-4o",
+                "openai",
+                None,
+                None,
+            )
             prompt_df = pd.DataFrame(
                 [
                     {
@@ -3755,7 +3755,9 @@ def test_run_inference_with_unsupported_model_string(
         with mock.patch(
             "vertexai._genai._evals_common.litellm"
         ) as mock_litellm_package:
-            mock_litellm_package.utils.get_valid_models.return_value = []
+            mock_litellm_package.get_llm_provider.side_effect = ValueError(
+                "unsupported model"
+            )
             evals_module = evals.Evals(api_client_=mock_api_client_fixture)
             prompt_df = pd.DataFrame([{"prompt": "test"}])
 
@@ -3822,7 +3824,7 @@ def test_run_inference_with_litellm_parsing(
         # fmt: off
         with mock.patch("vertexai._genai._evals_common.litellm") as mock_litellm:
             # fmt: on
-            mock_litellm.utils.get_valid_models.return_value = ["gpt-4o"]
+            mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
             inference_result = self.client.evals.run_inference(
                 model="gpt-4o",
                 src=mock_df,
diff --git a/vertexai/_genai/_evals_common.py b/vertexai/_genai/_evals_common.py
index 6aff133733..2b37abf04e 100644
--- a/vertexai/_genai/_evals_common.py
+++ b/vertexai/_genai/_evals_common.py
@@ -735,7 +735,14 @@ def _is_litellm_vertex_maas_model(model: str) -> bool:
 
 def _is_litellm_model(model: str) -> bool:
     """Checks if the model name corresponds to a valid LiteLLM model name."""
-    return model in litellm.utils.get_valid_models(model)
+    if litellm is None:
+        return False
+
+    try:
+        litellm.get_llm_provider(model)
+        return True
+    except ValueError:
+        return False
 
 
 def _is_gemini_model(model: str) -> bool:
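
Reviewer note: a minimal, standalone sketch (not part of the patch) of why the mocks above changed shape. litellm.get_llm_provider(model) returns a (model, provider, dynamic_api_key, api_base) tuple for names it can route, and raises for names it cannot resolve; the patch assumes that failure surfaces as a ValueError. The check() helper below is hypothetical and only mirrors the patched _is_litellm_model for illustration.

# Illustrative only; mirrors the patched _is_litellm_model and the test mocks.
from unittest import mock


def check(model: str, litellm_module) -> bool:
    # Same logic as the patched _is_litellm_model: a model name is
    # LiteLLM-routable iff a provider can be resolved for it.
    if litellm_module is None:
        return False
    try:
        litellm_module.get_llm_provider(model)
        return True
    except ValueError:
        return False


mock_litellm = mock.MagicMock()
# Success path: get_llm_provider returns a 4-tuple, as stubbed in the tests.
mock_litellm.get_llm_provider.return_value = ("gpt-4o", "openai", None, None)
assert check("gpt-4o", mock_litellm)

# Failure path: an unresolvable name raises, as stubbed via side_effect.
mock_litellm.get_llm_provider.side_effect = ValueError("unsupported model")
assert not check("bad-model", mock_litellm)

# litellm not installed at all.
assert not check("gpt-4o", None)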