From 95dfce8b533a26fde520183bf02fd4e6f37a8152 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Thu, 5 May 2022 15:10:34 -0700 Subject: [PATCH 1/4] feat: add experimental GDCH support --- google/auth/_default.py | 41 +++++ google/oauth2/_client.py | 72 ++++++--- google/oauth2/gdch_credentials.py | 213 ++++++++++++++++++++++++++ tests/data/gdch_service_account.json | 10 ++ tests/oauth2/test__client.py | 36 ++++- tests/oauth2/test_gdch_credentials.py | 186 ++++++++++++++++++++++ tests/test__default.py | 35 +++++ 7 files changed, 574 insertions(+), 19 deletions(-) create mode 100644 google/oauth2/gdch_credentials.py create mode 100644 tests/data/gdch_service_account.json create mode 100644 tests/oauth2/test_gdch_credentials.py diff --git a/google/auth/_default.py b/google/auth/_default.py index d038438d5..fbf9d944b 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -36,11 +36,13 @@ _SERVICE_ACCOUNT_TYPE = "service_account" _EXTERNAL_ACCOUNT_TYPE = "external_account" _IMPERSONATED_SERVICE_ACCOUNT_TYPE = "impersonated_service_account" +_GDCH_SERVICE_ACCOUNT_TYPE = "gdch_service_account" _VALID_TYPES = ( _AUTHORIZED_USER_TYPE, _SERVICE_ACCOUNT_TYPE, _EXTERNAL_ACCOUNT_TYPE, _IMPERSONATED_SERVICE_ACCOUNT_TYPE, + _GDCH_SERVICE_ACCOUNT_TYPE, ) # Help message when no credentials can be found. @@ -158,6 +160,8 @@ def _load_credentials_from_info( credentials, project_id = _get_impersonated_service_account_credentials( filename, info, scopes ) + elif credential_type == _GDCH_SERVICE_ACCOUNT_TYPE: + credentials, project_id = _get_gdch_service_account_credentials(info) else: raise exceptions.DefaultCredentialsError( "The file {file} does not have a valid type. " @@ -421,6 +425,36 @@ def _get_impersonated_service_account_credentials(filename, info, scopes): return credentials, None +def _get_gdch_service_account_credentials(info): + from google.oauth2 import gdch_credentials + + k8s_ca_cert_path = info.get("k8s_ca_cert_path") + k8s_cert_path = info.get("k8s_cert_path") + k8s_key_path = info.get("k8s_key_path") + k8s_token_endpoint = info.get("k8s_token_endpoint") + ais_ca_cert_path = info.get("ais_ca_cert_path") + ais_token_endpoint = info.get("ais_token_endpoint") + + format_version = info.get("format_version") + if format_version != "v1": + raise exceptions.DefaultCredentialsError( + "format_version is not provided or unsupported. Supported version is: v1" + ) + + return ( + gdch_credentials.ServiceAccountCredentials( + k8s_ca_cert_path, + k8s_cert_path, + k8s_key_path, + k8s_token_endpoint, + ais_ca_cert_path, + ais_token_endpoint, + None, + ), + None, + ) + + def _apply_quota_project_id(credentials, quota_project_id): if quota_project_id: credentials = credentials.with_quota_project(quota_project_id) @@ -456,6 +490,11 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non endpoint. The project ID returned in this case is the one corresponding to the underlying workload identity pool resource if determinable. + + If the environment variable is set to the path of a valid GDCH service + account JSON file (`Google Distributed Cloud Hosted`_), then a GDCH + credential will be returned. The project ID returned is None unless it + is set via `GOOGLE_CLOUD_PROJECT` environment variable. 2. If the `Google Cloud SDK`_ is installed and has application default credentials set they are loaded and returned. @@ -490,6 +529,8 @@ def default(scopes=None, request=None, quota_project_id=None, default_scopes=Non .. _Metadata Service: https://cloud.google.com/compute/docs\ /storing-retrieving-metadata .. _Cloud Run: https://cloud.google.com/run + .. _Google Distributed Cloud Hosted: https://cloud.google.com/blog/topics\ + /hybrid-cloud/announcing-google-distributed-cloud-edge-and-hosted Example:: diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 2f4e8474b..473e92cbc 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -44,11 +44,13 @@ def _handle_error_response(response_data): """Translates an error response into an exception. Args: - response_data (Mapping): The decoded response data. + response_data (Mapping | str): The decoded response data. Raises: google.auth.exceptions.RefreshError: The errors contained in response_data. """ + if isinstance(response_data, six.string_types): + raise exceptions.RefreshError(response_data) try: error_details = "{}: {}".format( response_data["error"], response_data.get("error_description") @@ -79,7 +81,13 @@ def _parse_expiry(response_data): def _token_endpoint_request_no_throw( - request, token_uri, body, access_token=None, use_json=False + request, + token_uri, + body, + access_token=None, + use_json=False, + expected_status_code=http_client.OK, + **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. This function doesn't throw on response errors. @@ -93,6 +101,10 @@ def _token_endpoint_request_no_throw( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + expected_status_code (Optional(int)): The expected the status code of + the token response. The default value is 200. We may expect other + status code like 201 for GDCH credentials. + kwargs: Additional arguments passed on to the request method. Returns: Tuple(bool, Mapping[str, str]): A boolean indicating if the request is @@ -112,32 +124,46 @@ def _token_endpoint_request_no_throw( # retry to fetch token for maximum of two times if any internal failure # occurs. while True: - response = request(method="POST", url=token_uri, headers=headers, body=body) + response = request( + method="POST", url=token_uri, headers=headers, body=body, **kwargs + ) response_body = ( response.data.decode("utf-8") if hasattr(response.data, "decode") else response.data ) - response_data = json.loads(response_body) - if response.status == http_client.OK: + if response.status == expected_status_code: + # response_body should be a JSON + response_data = json.loads(response_body) break else: - error_desc = response_data.get("error_description") or "" - error_code = response_data.get("error") or "" - if ( - any(e == "internal_failure" for e in (error_code, error_desc)) - and retry < 1 - ): - retry += 1 - continue - return response.status == http_client.OK, response_data - - return response.status == http_client.OK, response_data + # For a failed response, response_body could be a string + try: + response_data = json.loads(response_body) + error_desc = response_data.get("error_description") or "" + error_code = response_data.get("error") or "" + if ( + any(e == "internal_failure" for e in (error_code, error_desc)) + and retry < 1 + ): + retry += 1 + continue + except ValueError: + response_data = response_body + return response.status == expected_status_code, response_data + + return response.status == expected_status_code, response_data def _token_endpoint_request( - request, token_uri, body, access_token=None, use_json=False + request, + token_uri, + body, + access_token=None, + use_json=False, + expected_status_code=http_client.OK, + **kwargs ): """Makes a request to the OAuth 2.0 authorization server's token endpoint. @@ -150,6 +176,10 @@ def _token_endpoint_request( access_token (Optional(str)): The access token needed to make the request. use_json (Optional(bool)): Use urlencoded format or json format for the content type. The default value is False. + expected_status_code (Optional(int)): The expected the status code of + the token response. The default value is 200. We may expect other + status code like 201 for GDCH credentials. + kwargs: Additional arguments passed on to the request method. Returns: Mapping[str, str]: The JSON-decoded response data. @@ -159,7 +189,13 @@ def _token_endpoint_request( an error. """ response_status_ok, response_data = _token_endpoint_request_no_throw( - request, token_uri, body, access_token=access_token, use_json=use_json + request, + token_uri, + body, + access_token=access_token, + use_json=use_json, + expected_status_code=expected_status_code, + **kwargs ) if not response_status_ok: _handle_error_response(response_data) diff --git a/google/oauth2/gdch_credentials.py b/google/oauth2/gdch_credentials.py new file mode 100644 index 000000000..b71b623de --- /dev/null +++ b/google/oauth2/gdch_credentials.py @@ -0,0 +1,213 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Experimental GDCH credentials support. +""" + +import six +from six.moves import http_client + +from google.auth import _helpers +from google.auth import credentials +from google.auth import exceptions +from google.oauth2 import _client + + +TOKEN_EXCHANGE_TYPE = "urn:ietf:params:oauth:token-type:token-exchange" +ACCESS_TOKEN_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:access_token" +JWT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" +SERVICE_ACCOUNT_TOKEN_TYPE = "urn:k8s:params:oauth:token-type:serviceaccount" + + +class ServiceAccountCredentials(credentials.CredentialsWithQuotaProject): + """Credentials for GDCH (`Google Distributed Cloud Hosted`_) for service + account users. + + .. _Google Distributed Cloud Hosted: + https://cloud.google.com/blog/topics/hybrid-cloud/\ + announcing-google-distributed-cloud-edge-and-hosted + + Besides the constructor, a GDCH credential can be created via application + default credentials. + + To do so, user first creates a JSON file of the + following format:: + + { + "type":"gdch_service_account", + "format_version":"v1", + "k8s_ca_cert_path":"", + "k8s_cert_path":"", + "k8s_key_path":"", + "k8s_token_endpoint":"", + "ais_ca_cert_path":"", + "ais_token_endpoint":"" + } + + Here "k8s_*" files are used to request a k8s token from k8s token endpoint + using mutual TLS connection. The k8s token is then sent to AIS token endpoint + to exchange for an AIS token. The AIS token will be used to talk to Google + API services. + + "k8s_ca_cert_path" field is not needed if the k8s server uses well known CA. + "ais_ca_cert_path" field is not needed if the AIS server uses well known CA. + These two fields can be used for testing environments. + + The "format_version" field stands for the format of the JSON file. For now + it is always "v1". + + After the JSON file is created, set `GOOGLE_APPLICATION_CREDENTIALS` environment + variable to the JSON file path, then use the following code to create the + credential:: + + import google.auth + + credential, _ = google.auth.default() + credential = credential.with_audience("") + + The audience denotes the scope the AIS token is requested, for example, it + could be either a k8s cluster or API service. + """ + + def __init__( + self, + k8s_ca_cert_path, + k8s_cert_path, + k8s_key_path, + k8s_token_endpoint, + ais_ca_cert_path, + ais_token_endpoint, + audience, + quota_project_id=None, + ): + """ + Args: + k8s_ca_cert_path (str): CA cert path for k8s calls. This field is + useful if the specific k8s server doesn't use well known CA, + for instance, a testing k8s server. If the CA is well known, + you can pass `None` for this parameter. + k8s_cert_path (str): Certificate path for k8s calls + k8s_key_path (str): Key path for k8s calls + k8s_token_endpoint (str): k8s token endpoint url + ais_ca_cert_path (str): CA cert path for AIS token endpoint calls. + This field is useful if the specific AIS token server doesn't + uses well known CA, for instance, a testing AIS server. If the + CA is well known, you can pass `None` for this parameter. + ais_token_endpoint (str): AIS token endpoint url + audience (str): The audience for the requested AIS token. For + example, it could be a k8s cluster or API service. + quota_project_id (Optional[str]): The project ID used for quota + and billing. This project may be different from the project + used to create the credentials. + """ + super(ServiceAccountCredentials, self).__init__() + self._k8s_ca_cert_path = k8s_ca_cert_path + self._k8s_cert_path = k8s_cert_path + self._k8s_key_path = k8s_key_path + self._k8s_token_endpoint = k8s_token_endpoint + self._ais_ca_cert_path = ais_ca_cert_path + self._ais_token_endpoint = ais_token_endpoint + self._audience = audience + self._quota_project_id = quota_project_id + + def _make_k8s_token_request(self, request): + k8s_request_body = { + "kind": "TokenRequest", + "apiVersion": "authentication.k8s.io/v1", + "spec": {"audiences": [self._ais_token_endpoint]}, + } + # mTLS connection to k8s token endpoint to get a k8s token. + k8s_response_data = _client._token_endpoint_request( + request, + self._k8s_token_endpoint, + k8s_request_body, + None, + True, + http_client.CREATED, + cert=(self._k8s_cert_path, self._k8s_key_path), + verify=self._k8s_ca_cert_path, + ) + + try: + k8s_token = k8s_response_data["status"]["token"] + return k8s_token + except KeyError as caught_exc: + new_exc = exceptions.RefreshError( + "No access token in k8s token response.", k8s_response_data + ) + six.raise_from(new_exc, caught_exc) + + def _make_ais_token_request(self, k8s_token, request): + # send a request to AIS token point with the k8s token + ais_request_body = { + "grant_type": TOKEN_EXCHANGE_TYPE, + "audience": self._audience, + "requested_token_type": ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": k8s_token, + "subject_token_type": SERVICE_ACCOUNT_TOKEN_TYPE, + } + ais_response_data = _client._token_endpoint_request( + request, + self._ais_token_endpoint, + ais_request_body, + None, + True, + verify=self._ais_ca_cert_path, + ) + ais_token, _, ais_expiry, _ = _client._handle_refresh_grant_response( + ais_response_data, None + ) + return ais_token, ais_expiry + + @_helpers.copy_docstring(credentials.Credentials) + def refresh(self, request): + import google.auth.transport.requests + + if not isinstance(request, google.auth.transport.requests.Request): + raise exceptions.RefreshError( + "For GDCH service account credentials, request must be a google.auth.transport.requests.Request object" + ) + + k8s_token = self._make_k8s_token_request(request) + self.token, self.expiry = self._make_ais_token_request(k8s_token, request) + + def with_audience(self, audience): + """Create a copy of GDCH credentials with the specified audience. + + Args: + audience (str): The intended audience for GDCH credentials. + """ + return self.__class__( + self._k8s_ca_cert_path, + self._k8s_cert_path, + self._k8s_key_path, + self._k8s_token_endpoint, + self._ais_ca_cert_path, + self._ais_token_endpoint, + audience, + self._quota_project_id, + ) + + @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) + def with_quota_project(self, quota_project_id): + return self.__class__( + self._k8s_ca_cert_path, + self._k8s_cert_path, + self._k8s_key_path, + self._k8s_token_endpoint, + self._ais_ca_cert_path, + self._ais_token_endpoint, + self._audience, + quota_project_id, + ) diff --git a/tests/data/gdch_service_account.json b/tests/data/gdch_service_account.json new file mode 100644 index 000000000..c6c441bfd --- /dev/null +++ b/tests/data/gdch_service_account.json @@ -0,0 +1,10 @@ +{ + "type":"gdch_service_account", + "format_version": "v1", + "k8s_ca_cert_path":"./k8s_ca_cert.pem", + "k8s_cert_path":"./k8s_cert.pem", + "k8s_key_path":"./k8s_key.pem", + "k8s_token_endpoint":"https://k8s_endpoint/api/v1/namespaces/sa-token-test/serviceaccounts/sa-token-user/token", + "ais_ca_cert_path":"./ais_ca_cert.pem", + "ais_token_endpoint":"https://ais_endpoint/sts/v1beta/token" +} diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index 5485bed84..f50ee465b 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -56,7 +56,7 @@ def test__handle_error_response(): assert excinfo.match(r"help: I\'m alive") -def test__handle_error_response_non_json(): +def test__handle_error_response_no_error(): response_data = {"foo": "bar"} with pytest.raises(exceptions.RefreshError) as excinfo: @@ -65,6 +65,15 @@ def test__handle_error_response_non_json(): assert excinfo.match(r"{\"foo\": \"bar\"}") +def test__handle_error_response_not_json(): + response_data = "this is an error message" + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._handle_error_response(response_data) + + assert excinfo.match(response_data) + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test__parse_expiry(unused_utcnow): result = _client._parse_expiry({"expires_in": 500}) @@ -156,6 +165,31 @@ def test__token_endpoint_request_internal_failure_error(): ) +def test__token_endpoint_request_string_error(): + response = mock.create_autospec(transport.Response, instance=True) + response.status = http_client.BAD_REQUEST + response.data = "this is an error message" + request = mock.create_autospec(transport.Request) + request.return_value = response + + with pytest.raises(exceptions.RefreshError) as excinfo: + _client._token_endpoint_request(request, "http://example.com", {}) + assert excinfo.match("this is an error message") + + +def test__token_endpoint_request_expected_status_code(): + request = make_request({}, status=http_client.CREATED) + + # It doesn't throw if the response code is the expected one. + _client._token_endpoint_request( + request, "http://example.com", {}, expected_status_code=http_client.CREATED + ) + + # It throws since the default status code is 200 OK, but we are expecting 201 CREATED. + with pytest.raises(exceptions.RefreshError): + _client._token_endpoint_request(request, "http://example.com", {}) + + def verify_request_params(request, params): request_body = request.call_args[1]["body"].decode("utf-8") request_params = urllib.parse.parse_qs(request_body) diff --git a/tests/oauth2/test_gdch_credentials.py b/tests/oauth2/test_gdch_credentials.py new file mode 100644 index 000000000..075452988 --- /dev/null +++ b/tests/oauth2/test_gdch_credentials.py @@ -0,0 +1,186 @@ +# Copyright 2022 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +import mock +import pytest # type: ignore +from six.moves import http_client + +from google.auth import exceptions +from google.auth.transport import requests +from google.oauth2 import gdch_credentials + + +class TestCredentials(object): + K8S_CA_CERT_PATH = "./k8s_ca_cert.pem" + K8S_CERT_PATH = "./k8s_cert.pem" + K8S_KEY_PATH = "./k8s_key.pem" + K8S_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" + AIS_CA_CERT_PATH = "./ais_ca_cert.pem" + AIS_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" + AUDIENCE = "audience_foo" + QUOTA_PROJECT = "project_foo" + + @classmethod + def make_credentials(cls): + return gdch_credentials.ServiceAccountCredentials( + cls.K8S_CA_CERT_PATH, + cls.K8S_CERT_PATH, + cls.K8S_KEY_PATH, + cls.K8S_TOKEN_ENDPOINT, + cls.AIS_CA_CERT_PATH, + cls.AIS_TOKEN_ENDPOINT, + cls.AUDIENCE, + cls.QUOTA_PROJECT, + ) + + def test_with_audience(self): + creds = self.make_credentials() + assert creds._audience == self.AUDIENCE + + new_creds = creds.with_audience("bar") + assert new_creds._audience == "bar" + + def test_with_quota_project(self): + creds = self.make_credentials() + assert creds.quota_project_id == self.QUOTA_PROJECT + + new_creds = creds.with_quota_project("project_bar") + assert new_creds._quota_project_id == "project_bar" + + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + def test__make_k8s_token_request(self, token_endpoint_request): + creds = self.make_credentials() + req = requests.Request() + + token_endpoint_request.return_value = { + "status": { + "token": "k8s_token", + "expirationTimestamp": "2022-02-22T06:51:46Z", + } + } + assert creds._make_k8s_token_request(req) == "k8s_token" + token_endpoint_request.assert_called_with( + req, + creds._k8s_token_endpoint, + { + "kind": "TokenRequest", + "apiVersion": "authentication.k8s.io/v1", + "spec": {"audiences": [creds._ais_token_endpoint]}, + }, + None, + True, + http_client.CREATED, + cert=(creds._k8s_cert_path, creds._k8s_key_path), + verify=creds._k8s_ca_cert_path, + ) + + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + def test__make_k8s_token_request_no_token(self, token_endpoint_request): + creds = self.make_credentials() + req = requests.Request() + + token_endpoint_request.return_value = { + "status": {"expirationTimestamp": "2022-02-22T06:51:46Z"} + } + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds._make_k8s_token_request(req) + assert excinfo.match("No access token in k8s token response") + + @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) + @mock.patch("google.auth._helpers.utcnow", autospec=True) + def test__make_ais_token_request(self, utcnow, token_endpoint_request): + creds = self.make_credentials() + req = requests.Request() + + token_endpoint_request.return_value = { + "access_token": "ais_token", + "expires_in": 3599, + "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", + "token_type": "Bearer", + } + utcnow.return_value = datetime.datetime(2022, 1, 1, 0, 0, 0) + + k8s_token = "k8s_token" + ais_token, ais_expiry = creds._make_ais_token_request(k8s_token, req) + assert ais_token == "ais_token" + assert ais_expiry == datetime.datetime(2022, 1, 1, 0, 59, 59) + token_endpoint_request.assert_called_with( + req, + creds._ais_token_endpoint, + { + "grant_type": gdch_credentials.TOKEN_EXCHANGE_TYPE, + "audience": creds._audience, + "requested_token_type": gdch_credentials.ACCESS_TOKEN_TOKEN_TYPE, + "subject_token": k8s_token, + "subject_token_type": gdch_credentials.SERVICE_ACCOUNT_TOKEN_TYPE, + }, + None, + True, + verify=creds._ais_ca_cert_path, + ) + + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_k8s_token_request", + autospec=True, + ) + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_ais_token_request", + autospec=True, + ) + def test_refresh(self, ais_token_request, k8s_token_request): + k8s_token_request.return_value = "k8s_token" + mock_expiry = mock.Mock() + ais_token_request.return_value = ("ais_token", mock_expiry) + + creds = self.make_credentials() + req = requests.Request() + creds.refresh(req) + + k8s_token_request.assert_called_with(creds, req) + ais_token_request.assert_called_with(creds, "k8s_token", req) + assert creds.token == "ais_token" + assert creds.expiry == mock_expiry + + def test_refresh_request_not_requests_type(self): + creds = self.make_credentials() + req = mock.Mock() + + with pytest.raises(exceptions.RefreshError) as excinfo: + creds.refresh(req) + assert excinfo.match( + "request must be a google.auth.transport.requests.Request object" + ) + + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_k8s_token_request", + autospec=True, + ) + @mock.patch( + "google.oauth2.gdch_credentials.ServiceAccountCredentials._make_ais_token_request", + autospec=True, + ) + def test_before_request(self, ais_token_request, k8s_token_request): + ais_token_request.return_value = ("ais_token", mock.Mock()) + + cred = self.make_credentials() + headers = {} + + cred.before_request(requests.Request(), "GET", "https://example.com", headers) + k8s_token_request.assert_called() + ais_token_request.assert_called() + assert headers["authorization"] == "Bearer ais_token" + assert headers["x-goog-user-project"] == "project_foo" diff --git a/tests/test__default.py b/tests/test__default.py index ed64bc723..93e93b8da 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -28,6 +28,7 @@ from google.auth import external_account from google.auth import identity_pool from google.auth import impersonated_credentials +from google.oauth2 import gdch_credentials from google.oauth2 import service_account import google.oauth2.credentials @@ -50,6 +51,8 @@ CLIENT_SECRETS_FILE = os.path.join(DATA_DIR, "client_secrets.json") +GDCH_SERVICE_ACCOUNT_FILE = os.path.join(DATA_DIR, "gdch_service_account.json") + with open(SERVICE_ACCOUNT_FILE) as fh: SERVICE_ACCOUNT_FILE_DATA = json.load(fh) @@ -637,6 +640,18 @@ def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_proj assert get_project_id.called +def test__get_gdch_service_account_credentials_no_format_version(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default._get_gdch_service_account_credentials({}) + assert excinfo.match("format_version is not provided or unsupported") + + +def test__get_gdch_service_account_credentials_invalid_format_version(): + with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: + _default._get_gdch_service_account_credentials({"format_version": "v2"}) + assert excinfo.match("format_version is not provided or unsupported") + + class _AppIdentityModule(object): """The interface of the App Idenity app engine module. See https://cloud.google.com/appengine/docs/standard/python/refdocs\ @@ -1140,3 +1155,23 @@ def test_default_impersonated_service_account_set_both_scopes_and_default_scopes credentials, _ = _default.default(scopes=scopes, default_scopes=default_scopes) assert credentials._target_scopes == scopes + + +@mock.patch( + "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True +) +def test_default_gdch_service_account_credentials(get_adc_path): + get_adc_path.return_value = GDCH_SERVICE_ACCOUNT_FILE + + credentials, _ = _default.default(quota_project_id="project-foo") + assert isinstance(credentials, gdch_credentials.ServiceAccountCredentials) + assert credentials._quota_project_id == "project-foo" + assert credentials._k8s_ca_cert_path == "./k8s_ca_cert.pem" + assert credentials._k8s_cert_path == "./k8s_cert.pem" + assert credentials._k8s_key_path == "./k8s_key.pem" + assert ( + credentials._k8s_token_endpoint + == "https://k8s_endpoint/api/v1/namespaces/sa-token-test/serviceaccounts/sa-token-user/token" + ) + assert credentials._ais_ca_cert_path == "./ais_ca_cert.pem" + assert credentials._ais_token_endpoint == "https://ais_endpoint/sts/v1beta/token" From 5ef70b949fcfed9d927e2764513ce62699c10154 Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Tue, 10 May 2022 01:55:06 -0700 Subject: [PATCH 2/4] address comments --- google/oauth2/_client.py | 18 +++++++++++++++--- google/oauth2/gdch_credentials.py | 10 +++++----- tests/oauth2/test__client.py | 4 ++++ tests/oauth2/test_gdch_credentials.py | 22 ++++++++++++++-------- tests/test__default.py | 8 ++++++-- 5 files changed, 44 insertions(+), 18 deletions(-) diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 473e92cbc..8831baf27 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -104,7 +104,13 @@ def _token_endpoint_request_no_throw( expected_status_code (Optional(int)): The expected the status code of the token response. The default value is 200. We may expect other status code like 201 for GDCH credentials. - kwargs: Additional arguments passed on to the request method. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. Returns: Tuple(bool, Mapping[str, str]): A boolean indicating if the request is @@ -151,7 +157,7 @@ def _token_endpoint_request_no_throw( continue except ValueError: response_data = response_body - return response.status == expected_status_code, response_data + return False, response_data return response.status == expected_status_code, response_data @@ -179,7 +185,13 @@ def _token_endpoint_request( expected_status_code (Optional(int)): The expected the status code of the token response. The default value is 200. We may expect other status code like 201 for GDCH credentials. - kwargs: Additional arguments passed on to the request method. + kwargs: Additional arguments passed on to the request method. The + kwargs will be passed to `requests.request` method, see: + https://docs.python-requests.org/en/latest/api/#requests.request. + For example, you can use `cert=("cert_pem_path", "key_pem_path")` + to set up client side SSL certificate, and use + `verify="ca_bundle_path"` to set up the CA certificates for sever + side SSL certificate verification. Returns: Mapping[str, str]: The JSON-decoded response data. diff --git a/google/oauth2/gdch_credentials.py b/google/oauth2/gdch_credentials.py index b71b623de..0d9e08c38 100644 --- a/google/oauth2/gdch_credentials.py +++ b/google/oauth2/gdch_credentials.py @@ -132,9 +132,9 @@ def _make_k8s_token_request(self, request): request, self._k8s_token_endpoint, k8s_request_body, - None, - True, - http_client.CREATED, + access_token=None, + use_json=True, + expected_status_code=http_client.CREATED, cert=(self._k8s_cert_path, self._k8s_key_path), verify=self._k8s_ca_cert_path, ) @@ -161,8 +161,8 @@ def _make_ais_token_request(self, k8s_token, request): request, self._ais_token_endpoint, ais_request_body, - None, - True, + access_token=None, + use_json=True, verify=self._ais_ca_cert_path, ) ais_token, _, ais_expiry, _ = _client._handle_refresh_grant_response( diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index f50ee465b..400582fc3 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -154,6 +154,8 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error_description": "internal_failure"} ) + # request should be called twice due to the retry + assert request.call_count == 2 request = make_request( {"error": "internal_failure"}, status=http_client.BAD_REQUEST @@ -163,6 +165,8 @@ def test__token_endpoint_request_internal_failure_error(): _client._token_endpoint_request( request, "http://example.com", {"error": "internal_failure"} ) + # request should be called twice due to the retry + assert request.call_count == 2 def test__token_endpoint_request_string_error(): diff --git a/tests/oauth2/test_gdch_credentials.py b/tests/oauth2/test_gdch_credentials.py index 075452988..0754f044c 100644 --- a/tests/oauth2/test_gdch_credentials.py +++ b/tests/oauth2/test_gdch_credentials.py @@ -27,6 +27,7 @@ class TestCredentials(object): K8S_CA_CERT_PATH = "./k8s_ca_cert.pem" K8S_CERT_PATH = "./k8s_cert.pem" K8S_KEY_PATH = "./k8s_key.pem" + K8S_TOKEN = "k8s_token" K8S_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" AIS_CA_CERT_PATH = "./ais_ca_cert.pem" AIS_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" @@ -67,11 +68,11 @@ def test__make_k8s_token_request(self, token_endpoint_request): token_endpoint_request.return_value = { "status": { - "token": "k8s_token", + "token": self.K8S_TOKEN, "expirationTimestamp": "2022-02-22T06:51:46Z", } } - assert creds._make_k8s_token_request(req) == "k8s_token" + assert creds._make_k8s_token_request(req) == self.K8S_TOKEN token_endpoint_request.assert_called_with( req, creds._k8s_token_endpoint, @@ -106,18 +107,23 @@ def test__make_ais_token_request(self, utcnow, token_endpoint_request): creds = self.make_credentials() req = requests.Request() + issue_time = datetime.datetime(2022, 1, 1, 0, 0, 0) + utcnow.return_value = issue_time + expires_in_seconds = 3599 + token_endpoint_request.return_value = { "access_token": "ais_token", - "expires_in": 3599, + "expires_in": expires_in_seconds, "issued_token_type": "urn:ietf:params:oauth:token-type:access_token", "token_type": "Bearer", } - utcnow.return_value = datetime.datetime(2022, 1, 1, 0, 0, 0) - k8s_token = "k8s_token" + k8s_token = self.K8S_TOKEN ais_token, ais_expiry = creds._make_ais_token_request(k8s_token, req) assert ais_token == "ais_token" - assert ais_expiry == datetime.datetime(2022, 1, 1, 0, 59, 59) + assert ais_expiry == issue_time + datetime.timedelta( + seconds=expires_in_seconds + ) token_endpoint_request.assert_called_with( req, creds._ais_token_endpoint, @@ -142,7 +148,7 @@ def test__make_ais_token_request(self, utcnow, token_endpoint_request): autospec=True, ) def test_refresh(self, ais_token_request, k8s_token_request): - k8s_token_request.return_value = "k8s_token" + k8s_token_request.return_value = self.K8S_TOKEN mock_expiry = mock.Mock() ais_token_request.return_value = ("ais_token", mock_expiry) @@ -151,7 +157,7 @@ def test_refresh(self, ais_token_request, k8s_token_request): creds.refresh(req) k8s_token_request.assert_called_with(creds, req) - ais_token_request.assert_called_with(creds, "k8s_token", req) + ais_token_request.assert_called_with(creds, self.K8S_TOKEN, req) assert creds.token == "ais_token" assert creds.expiry == mock_expiry diff --git a/tests/test__default.py b/tests/test__default.py index 93e93b8da..5177ce975 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -643,13 +643,17 @@ def test__get_gcloud_sdk_credentials_no_project_id(load, unused_isfile, get_proj def test__get_gdch_service_account_credentials_no_format_version(): with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: _default._get_gdch_service_account_credentials({}) - assert excinfo.match("format_version is not provided or unsupported") + assert excinfo.match( + "format_version is not provided or unsupported. Supported version is: v1" + ) def test__get_gdch_service_account_credentials_invalid_format_version(): with pytest.raises(exceptions.DefaultCredentialsError) as excinfo: _default._get_gdch_service_account_credentials({"format_version": "v2"}) - assert excinfo.match("format_version is not provided or unsupported") + assert excinfo.match( + "format_version is not provided or unsupported. Supported version is: v1" + ) class _AppIdentityModule(object): From e9106bfe9261032c3875a40aaa7c23cbb9ea2a87 Mon Sep 17 00:00:00 2001 From: Owl Bot Date: Tue, 10 May 2022 08:57:01 +0000 Subject: [PATCH 3/4] =?UTF-8?q?=F0=9F=A6=89=20Updates=20from=20OwlBot=20po?= =?UTF-8?q?st-processor?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md --- tests/oauth2/test_gdch_credentials.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/oauth2/test_gdch_credentials.py b/tests/oauth2/test_gdch_credentials.py index 0754f044c..035e157e0 100644 --- a/tests/oauth2/test_gdch_credentials.py +++ b/tests/oauth2/test_gdch_credentials.py @@ -121,9 +121,7 @@ def test__make_ais_token_request(self, utcnow, token_endpoint_request): k8s_token = self.K8S_TOKEN ais_token, ais_expiry = creds._make_ais_token_request(k8s_token, req) assert ais_token == "ais_token" - assert ais_expiry == issue_time + datetime.timedelta( - seconds=expires_in_seconds - ) + assert ais_expiry == issue_time + datetime.timedelta(seconds=expires_in_seconds) token_endpoint_request.assert_called_with( req, creds._ais_token_endpoint, From e0b5f3b13d8de04d1e2d1ae098eadb0ca89651cd Mon Sep 17 00:00:00 2001 From: Sijun Liu Date: Tue, 10 May 2022 12:01:11 -0700 Subject: [PATCH 4/4] remove quota project id --- google/auth/_default.py | 5 ++++- google/oauth2/gdch_credentials.py | 21 +-------------------- tests/oauth2/test_gdch_credentials.py | 10 ---------- tests/test__default.py | 9 +++++++-- 4 files changed, 12 insertions(+), 33 deletions(-) diff --git a/google/auth/_default.py b/google/auth/_default.py index fbf9d944b..fd346b102 100644 --- a/google/auth/_default.py +++ b/google/auth/_default.py @@ -136,6 +136,8 @@ def load_credentials_from_file( def _load_credentials_from_info( filename, info, scopes, default_scopes, quota_project_id, request ): + from google.auth.credentials import CredentialsWithQuotaProject + credential_type = info.get("type") if credential_type == _AUTHORIZED_USER_TYPE: @@ -169,7 +171,8 @@ def _load_credentials_from_info( file=filename, type=credential_type, valid_types=_VALID_TYPES ) ) - credentials = _apply_quota_project_id(credentials, quota_project_id) + if isinstance(credentials, CredentialsWithQuotaProject): + credentials = _apply_quota_project_id(credentials, quota_project_id) return credentials, project_id diff --git a/google/oauth2/gdch_credentials.py b/google/oauth2/gdch_credentials.py index 0d9e08c38..e0edbf039 100644 --- a/google/oauth2/gdch_credentials.py +++ b/google/oauth2/gdch_credentials.py @@ -30,7 +30,7 @@ SERVICE_ACCOUNT_TOKEN_TYPE = "urn:k8s:params:oauth:token-type:serviceaccount" -class ServiceAccountCredentials(credentials.CredentialsWithQuotaProject): +class ServiceAccountCredentials(credentials.Credentials): """Credentials for GDCH (`Google Distributed Cloud Hosted`_) for service account users. @@ -89,7 +89,6 @@ def __init__( ais_ca_cert_path, ais_token_endpoint, audience, - quota_project_id=None, ): """ Args: @@ -107,9 +106,6 @@ def __init__( ais_token_endpoint (str): AIS token endpoint url audience (str): The audience for the requested AIS token. For example, it could be a k8s cluster or API service. - quota_project_id (Optional[str]): The project ID used for quota - and billing. This project may be different from the project - used to create the credentials. """ super(ServiceAccountCredentials, self).__init__() self._k8s_ca_cert_path = k8s_ca_cert_path @@ -119,7 +115,6 @@ def __init__( self._ais_ca_cert_path = ais_ca_cert_path self._ais_token_endpoint = ais_token_endpoint self._audience = audience - self._quota_project_id = quota_project_id def _make_k8s_token_request(self, request): k8s_request_body = { @@ -196,18 +191,4 @@ def with_audience(self, audience): self._ais_ca_cert_path, self._ais_token_endpoint, audience, - self._quota_project_id, - ) - - @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) - def with_quota_project(self, quota_project_id): - return self.__class__( - self._k8s_ca_cert_path, - self._k8s_cert_path, - self._k8s_key_path, - self._k8s_token_endpoint, - self._ais_ca_cert_path, - self._ais_token_endpoint, - self._audience, - quota_project_id, ) diff --git a/tests/oauth2/test_gdch_credentials.py b/tests/oauth2/test_gdch_credentials.py index 035e157e0..41aa399af 100644 --- a/tests/oauth2/test_gdch_credentials.py +++ b/tests/oauth2/test_gdch_credentials.py @@ -32,7 +32,6 @@ class TestCredentials(object): AIS_CA_CERT_PATH = "./ais_ca_cert.pem" AIS_TOKEN_ENDPOINT = "https://k8s_endpoint/v1/token" AUDIENCE = "audience_foo" - QUOTA_PROJECT = "project_foo" @classmethod def make_credentials(cls): @@ -44,7 +43,6 @@ def make_credentials(cls): cls.AIS_CA_CERT_PATH, cls.AIS_TOKEN_ENDPOINT, cls.AUDIENCE, - cls.QUOTA_PROJECT, ) def test_with_audience(self): @@ -54,13 +52,6 @@ def test_with_audience(self): new_creds = creds.with_audience("bar") assert new_creds._audience == "bar" - def test_with_quota_project(self): - creds = self.make_credentials() - assert creds.quota_project_id == self.QUOTA_PROJECT - - new_creds = creds.with_quota_project("project_bar") - assert new_creds._quota_project_id == "project_bar" - @mock.patch("google.oauth2._client._token_endpoint_request", autospec=True) def test__make_k8s_token_request(self, token_endpoint_request): creds = self.make_credentials() @@ -187,4 +178,3 @@ def test_before_request(self, ais_token_request, k8s_token_request): k8s_token_request.assert_called() ais_token_request.assert_called() assert headers["authorization"] == "Bearer ais_token" - assert headers["x-goog-user-project"] == "project_foo" diff --git a/tests/test__default.py b/tests/test__default.py index 5177ce975..ab8bad72e 100644 --- a/tests/test__default.py +++ b/tests/test__default.py @@ -1164,12 +1164,17 @@ def test_default_impersonated_service_account_set_both_scopes_and_default_scopes @mock.patch( "google.auth._cloud_sdk.get_application_default_credentials_path", autospec=True ) -def test_default_gdch_service_account_credentials(get_adc_path): +@mock.patch("google.auth._default._apply_quota_project_id", autospec=True) +def test_default_gdch_service_account_credentials(apply_quota_project_id, get_adc_path): get_adc_path.return_value = GDCH_SERVICE_ACCOUNT_FILE credentials, _ = _default.default(quota_project_id="project-foo") + + # make sure _apply_quota_project_id is not called since GDCH service account + # credential doesn't inheirt from CredentialsWithQuotaProject. + apply_quota_project_id.assert_not_called() + assert isinstance(credentials, gdch_credentials.ServiceAccountCredentials) - assert credentials._quota_project_id == "project-foo" assert credentials._k8s_ca_cert_path == "./k8s_ca_cert.pem" assert credentials._k8s_cert_path == "./k8s_cert.pem" assert credentials._k8s_key_path == "./k8s_key.pem"