feat: support raw_predict for Endpoint (#1620)
Conversation
nayaknishant
left a comment
Is test_model_interactions.py testing permanent models? If so, I think we should keep it as two separate files (as you have it). But if it's using temp models created in test_model_upload.py, wouldn't it make more sense to have a predict() call followed by a raw_predict() call in the same file (and maybe rename it to something broader like test_model.py)?
```python
return Prediction(
    predictions=response_text["predictions"],
    deployed_model_id=raw_predict_response.headers[
        "X-Vertex-AI-Deployed-Model-Id"
```
Interesting; @rosiezou, can you clarify? I'm guessing X-Vertex-AI-Deployed-Model-Id is populated when the response is returned?
I also added constants (`_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"` and `_RAW_PREDICT_MODEL_RESOURCE_KEY = "X-Vertex-AI-Model"`) to replace the hardcoded strings.
nayaknishant
left a comment
I see that system tests have been added, but can we also add unit tests to https://github.com/googleapis/python-aiplatform/blob/main/tests/unit/aiplatform/test_models.py for raw_predict() to ensure the function works as intended during future development?
```python
from google.api_core import operation
from google.api_core import exceptions as api_exceptions
from google.auth import credentials as auth_credentials
from google.auth.transport.requests import AuthorizedSession
```
Import the module instead of the class. Because of the import name conflict, maybe something like `from google.auth.transport import requests as google_auth_requests`.
updated import statement
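The resulting import would look roughly like this; the sketch is guarded so the idea is visible even where google-auth is not installed:

```python
try:
    # Module import under an alias, per the review: avoids shadowing the
    # third-party `requests` package while still exposing the class as
    # google_auth_requests.AuthorizedSession.
    from google.auth.transport import requests as google_auth_requests
except ImportError:
    # google-auth may not be installed in this sketch's environment.
    google_auth_requests = None
```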
```python
timeout (float): Optional. The timeout for this request in seconds.
use_raw_predict (bool):
    Optional. If set to True, the underlying prediction call will be made
    against Endpoint.raw_predict(). Currently, model version information will
```
- Maybe mention in the docstring that the default is `False`.
- "Currently..." <- does this mean that it will become available later?
I will remove "currently". I checked with the service team, and they mentioned that making model version info available for raw_predict is not on the roadmap, due to performance reasons.
updated docstrings
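A sketch of how the revised docstring might read; the exact wording is assumed, and the stub below is not the real method:

```python
def predict(instances, parameters=None, timeout=None, use_raw_predict=False):
    """Illustrative stub carrying only the revised docstring.

    Args:
        timeout (float): Optional. The timeout for this request in seconds.
        use_raw_predict (bool):
            Optional. Defaults to False. If set to True, the underlying
            prediction call is made against Endpoint.raw_predict(). Model
            version information is not returned in this mode.
    """
    return use_raw_predict
```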
```python
    body=json.dumps({"instances": instances, "parameters": parameters}),
    headers={"Content-Type": "application/json"},
)
response_text = json.loads(raw_predict_response.text)
```
After calling json.loads this is no longer text. Please rename the variable `response_text` to something else.
updated variable name
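The point of the rename, in miniature. The name `prediction_response` below is one plausible choice; the thread does not record the final name:

```python
import json

raw_body = '{"predictions": [0.9, 0.1]}'

# json.loads returns a parsed object, not text, so the variable name
# should reflect the parsed structure rather than "text".
prediction_response = json.loads(raw_body)
predictions = prediction_response["predictions"]
```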
```python
body (bytes):
    Required. The body of the prediction request in bytes. This must not exceed 1.5 mb per request.
headers (Dict[str, str]):
    Required. The header of the request as a dictionary. There are no restrictions on the header.
```
There is a default value, so `headers` is not actually required, contrary to what the docstring says.
On the other hand, the `body` argument probably should not have a default?
You're right, neither of these args should have any defaults. I've updated the function signatures and docstrings.
```python
    model_resource_name=prediction_response.model,
)

def raw_predict(
    self, body: bytes = None, headers: Dict[str, str] = None
```
If we are going to allow None, the typing annotation should be `Optional[Dict[str, str]]`.
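The two consistent fixes discussed here, side by side as a sketch: either drop the defaults and make both arguments required (the option the author ultimately took), or keep the `None` defaults and let the annotations admit `None`:

```python
from typing import Dict, Optional


def raw_predict_required(body: bytes, headers: Dict[str, str]):
    """No defaults: both arguments are genuinely required."""
    return body, headers


def raw_predict_optional(
    body: Optional[bytes] = None, headers: Optional[Dict[str, str]] = None
):
    """If the None defaults were kept, the annotations must be Optional."""
    return body, headers
```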
```python
def raw_predict(
    self, body: bytes = None, headers: Dict[str, str] = None
```
Same comments about the arguments' default values here.
```python
) -> requests.models.Response:
    """Make a prediction request using arbitrary headers.

    This method must be called within the network the PrivateEndpoint is peered to.
    The function call will fail otherwise. To check, use `PrivateEndpoint.network`.
```
Do we know the specific error that would be raised? If so perhaps add a Raises: section to the docstring: https://google.github.io/styleguide/pyguide.html#doc-function-raises
I'll test out a few scenarios, but the most common one will be auth errors raised from `google.api_core.exceptions` with error code 401.
I meant specifically the part “the function call will fail” in the docstring.
Got it. I updated the docstrings in PrivateEndpoint.predict and PrivateEndpoint.raw_predict to contain information about the error code. The most common error code will be 404, with a message saying "request not found".
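A sketch of the Raises: section in the style-guide format the reviewer linked. The exception class is an assumption inferred from the 404 code mentioned above, not a confirmed detail of the final docstring:

```python
def raw_predict(body, headers):
    """Make a prediction request using arbitrary headers.

    This method must be called within the network the PrivateEndpoint is
    peered to; use `PrivateEndpoint.network` to check.

    Raises:
        google.api_core.exceptions.NotFound: If the request is made outside
            the peered network, the service responds with a 404
            "request not found" error. (Exception class assumed from the
            error codes discussed in the thread.)
    """
    ...
```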
```python
    {"Content-Type": "application/json"},
)
assert raw_prediction_response.status_code == 200
assert len(json.loads(raw_prediction_response.text).items()) == 1
```
`.items()` is not needed if we are only checking the dictionary's length.
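The simplification in isolation, using a made-up payload for illustration:

```python
import json

raw_text = '{"predictions": [1, 2, 3]}'
parsed = json.loads(raw_text)

# len() works directly on the dict; the extra .items() call adds nothing,
# since a dict and its items view always have the same length.
assert len(parsed) == 1
```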
```python
_DEFAULT_MACHINE_TYPE = "n1-standard-2"
_DEPLOYING_MODEL_TRAFFIC_SPLIT_KEY = "0"
_SUCCESSFUL_HTTP_RESPONSE = 300
_RAW_PREDICT_MODEL_ID_KEY = "X-Vertex-AI-Deployed-Model-Id"
```
Probably call this "deployed model id" instead of "model id" (which sometimes refers to the last part of the resource's full name).
updated variable name
```python
self.authorized_session = None
self.raw_predict_request_url = None
```
Should these attributes be public?
```python
"_authorized_session",
"_raw_predict_request_url",
```
These attribute names start with underscores, but the attributes in models.py do not have leading underscores. Is this correct?
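One consistent resolution would be underscore-prefixed (private) names in both places. A minimal sketch, where the attribute names come from the diff above but the class shape and the `_COPYABLE_ATTRS` name are hypothetical:

```python
class Endpoint:
    """Minimal sketch: private attribute names in __init__ matching the
    underscore-prefixed names listed elsewhere in the class."""

    # Hypothetical name for the list the underscored strings belong to.
    _COPYABLE_ATTRS = (
        "_authorized_session",
        "_raw_predict_request_url",
    )

    def __init__(self):
        # Leading underscores here keep the names consistent with the
        # list above and mark the attributes as non-public.
        self._authorized_session = None
        self._raw_predict_request_url = None
```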