Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion packages/bigframes/bigframes/_config/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@

from __future__ import annotations

import os
import threading
from typing import Optional

import google.auth.credentials
import google.auth.transport.requests
import pydata_google_auth

import bigframes._config.bigquery_options as bigquery_options
from bigframes._config import options

_SCOPES = ["https://www.googleapis.com/auth/cloud-platform"]

# Put the lock here rather than in BigQueryOptions so that BigQueryOptions
Expand All @@ -30,7 +34,33 @@
_cached_project_default: Optional[str] = None


def get_default_credentials_with_project() -> tuple[
# Name of the environment variable consulted when no project is configured
# explicitly on the options object.
_GOOGLE_CLOUD_PROJECT = "GOOGLE_CLOUD_PROJECT"


def resolve_credentials_and_project(
    options: bigquery_options.BigQueryOptions,
) -> tuple[google.auth.credentials.Credentials, str]:
    """Resolve the credentials and project used to build BigQuery clients.

    The project is chosen in this order of preference:

    1. ``options.project``, when set to a non-empty value.
    2. The ``GOOGLE_CLOUD_PROJECT`` environment variable, when non-empty.
    3. The project returned alongside the default credentials. This is only
       available when ``options.credentials`` is ``None``, because that is
       the only case in which default credentials are looked up.

    Args:
        options: Object exposing ``project`` and ``credentials`` attributes.
            NOTE: inside this function the parameter name shadows the
            module-level ``options`` import.

    Returns:
        A ``(credentials, project)`` pair. ``credentials`` is
        ``options.credentials`` when provided, otherwise the default
        credentials.

    Raises:
        ValueError: If no non-empty project can be determined from any of
            the sources above.
    """
    credentials = options.credentials

    # Use truthiness rather than `is None`: an empty string (for example
    # `GOOGLE_CLOUD_PROJECT=` exported with no value) is not a usable project
    # and must not block the fallbacks below.
    project = options.project or os.getenv(_GOOGLE_CLOUD_PROJECT)

    if credentials is None:
        credentials, cred_project = _get_default_credentials_with_project()
        # The project tied to the default credentials is the last resort; a
        # project from the options or the environment takes precedence even
        # when the two disagree.
        project = project or cred_project

    if not project:
        raise ValueError(
            "Project must be set to initialize BigQuery client. "
            "Try setting `bigframes.options.bigquery.project` first."
        )
    return credentials, project


def _get_default_credentials_with_project() -> tuple[
google.auth.credentials.Credentials, Optional[str]
]:
global _AUTH_LOCK, _cached_credentials, _cached_project_default
Expand Down
11 changes: 8 additions & 3 deletions packages/bigframes/bigframes/pandas/io/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,18 +655,23 @@ def _get_bqclient_and_project() -> Tuple[bigquery.Client, str]:
# Address circular imports in doctest due to bigframes/session/__init__.py
# containing a lot of logic and samples.
from bigframes.session import clients
import bigframes._config.auth

credentials, project = bigframes._config.auth.resolve_credentials_and_project(
config.options.bigquery
)

clients_provider = clients.ClientsProvider(
project=config.options.bigquery.project,
project=project,
location=config.options.bigquery.location,
use_regional_endpoints=config.options.bigquery.use_regional_endpoints,
credentials=config.options.bigquery.credentials,
credentials=credentials,
application_name=config.options.bigquery.application_name,
bq_kms_key_name=config.options.bigquery.kms_key_name,
client_endpoints_override=config.options.bigquery.client_endpoints_override,
requests_transport_adapters=config.options.bigquery.requests_transport_adapters,
)
return clients_provider.bqclient, clients_provider._project
return clients_provider.bqclient, project


def _dry_run(query, bqclient) -> bigquery.QueryJob:
Expand Down
9 changes: 7 additions & 2 deletions packages/bigframes/bigframes/session/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
)

import bigframes._config
import bigframes._config.auth
import bigframes._config.bigquery_options as bigquery_options
import bigframes.clients
import bigframes.constants
Expand Down Expand Up @@ -217,11 +218,15 @@ def __init__(
if clients_provider:
self._clients_provider = clients_provider
else:
credentials, project = (
bigframes._config.auth.resolve_credentials_and_project(context)
)

self._clients_provider = clients.ClientsProvider(
project=context.project,
project=project,
credentials=credentials,
location=self._location,
use_regional_endpoints=context.use_regional_endpoints,
credentials=context.credentials,
application_name=context.application_name,
bq_kms_key_name=self._bq_kms_key_name,
client_endpoints_override=context.client_endpoints_override,
Expand Down
30 changes: 2 additions & 28 deletions packages/bigframes/bigframes/session/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@
import google.cloud.storage # type: ignore
import requests

import bigframes._config.auth
import bigframes.constants
import bigframes.version

from . import environment

_ENV_DEFAULT_PROJECT = "GOOGLE_CLOUD_PROJECT"
_APPLICATION_NAME = f"bigframes/{bigframes.version.__version__} ibis/9.2.0"


Expand All @@ -50,10 +48,6 @@
_BIGQUERYSTORAGE_REGIONAL_ENDPOINT = "bigquerystorage.{location}.rep.googleapis.com"


def _get_default_credentials_with_project():
return bigframes._config.auth.get_default_credentials_with_project()


def _get_application_names():
apps = [_APPLICATION_NAME]

Expand All @@ -74,10 +68,10 @@ class ClientsProvider:

def __init__(
self,
project: Optional[str] = None,
project: str,
credentials: google.auth.credentials.Credentials,
location: Optional[str] = None,
use_regional_endpoints: Optional[bool] = None,
credentials: Optional[google.auth.credentials.Credentials] = None,
application_name: Optional[str] = None,
bq_kms_key_name: Optional[str] = None,
client_endpoints_override: dict = {},
Expand All @@ -86,26 +80,6 @@ def __init__(
Tuple[str, requests.adapters.BaseAdapter]
] = (),
):
credentials_project = None
if credentials is None:
credentials, credentials_project = _get_default_credentials_with_project()

# Prefer the project in this order:
# 1. Project explicitly specified by the user
# 2. Project set in the environment
# 3. Project associated with the default credentials
project = (
project
or os.getenv(_ENV_DEFAULT_PROJECT)
or typing.cast(Optional[str], credentials_project)
)

if not project:
raise ValueError(
"Project must be set to initialize BigQuery client. "
"Try setting `bigframes.options.bigquery.project` first."
)

self._application_name = (
f"{_get_application_names()} {application_name}"
if application_name
Expand Down
14 changes: 12 additions & 2 deletions packages/bigframes/tests/system/large/test_location.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import typing
import unittest.mock as mock

import google.auth.credentials
import pandas
import pandas.testing
import pytest
Expand Down Expand Up @@ -176,8 +178,12 @@ def test_bq_rep_endpoints(bigquery_location):


def test_clients_provider_no_location():
credentials = mock.create_autospec(google.auth.credentials.Credentials)

with pytest.raises(ValueError, match="Must set location to use regional endpoints"):
bigframes.session.clients.ClientsProvider(use_regional_endpoints=True)
bigframes.session.clients.ClientsProvider(
project="", credentials=credentials, use_regional_endpoints=True
)


@pytest.mark.parametrize(
Expand All @@ -186,12 +192,16 @@ def test_clients_provider_no_location():
sorted(bigframes.constants.REP_NOT_ENABLED_BIGQUERY_LOCATIONS),
)
def test_clients_provider_use_regional_endpoints_non_rep_locations(bigquery_location):
credentials = mock.create_autospec(google.auth.credentials.Credentials)
with pytest.raises(
ValueError,
match=f"not .*available in the location {bigquery_location}",
):
bigframes.session.clients.ClientsProvider(
location=bigquery_location, use_regional_endpoints=True
project="",
credentials=credentials,
location=bigquery_location,
use_regional_endpoints=True,
)


Expand Down
Loading