diff --git a/google/cloud/documentai_toolbox/utilities/gcs_utilities.py b/google/cloud/documentai_toolbox/utilities/gcs_utilities.py
index 8583a779..0c5ff3e2 100644
--- a/google/cloud/documentai_toolbox/utilities/gcs_utilities.py
+++ b/google/cloud/documentai_toolbox/utilities/gcs_utilities.py
@@ -136,17 +136,13 @@ def get_blob(
         module (Optional[str]):
             Optional. The module for a custom user agent header.
     Returns:
-        List[storage.blob.Blob]:
-            A list of the blobs in the Cloud Storage path.
+        storage.blob.Blob:
+            The blob in the Cloud Storage path.
 
     """
-    gcs_bucket_name, gcs_file_name = split_gcs_uri(gcs_uri)
-
-    if not re.match(constants.FILE_CHECK_REGEX, gcs_file_name):
+    if not re.match(constants.FILE_CHECK_REGEX, gcs_uri):
         raise ValueError("gcs_uri must link to a single file.")
 
-    storage_client = _get_storage_client(module=module)
-    bucket = storage_client.bucket(bucket_name=gcs_bucket_name)
-    return bucket.get_blob(gcs_file_name)
+    return storage.Blob.from_string(gcs_uri, _get_storage_client(module=module))
 
 
 def split_gcs_uri(gcs_uri: str) -> Tuple[str, str]:
diff --git a/google/cloud/documentai_toolbox/wrappers/document.py b/google/cloud/documentai_toolbox/wrappers/document.py
index 6df97312..6a49ed49 100644
--- a/google/cloud/documentai_toolbox/wrappers/document.py
+++ b/google/cloud/documentai_toolbox/wrappers/document.py
@@ -366,6 +366,7 @@ class Document:
     shards: List[documentai.Document] = dataclasses.field(repr=False)
     gcs_bucket_name: Optional[str] = dataclasses.field(default=None, repr=False)
     gcs_prefix: Optional[str] = dataclasses.field(default=None, repr=False)
+    gcs_uri: Optional[str] = dataclasses.field(default=None, repr=False)
     gcs_input_uri: Optional[str] = dataclasses.field(default=None, repr=False)
 
     _pages: Optional[List[Page]] = dataclasses.field(
@@ -463,7 +464,7 @@ def from_gcs(
         gcs_prefix: str,
         gcs_input_uri: Optional[str] = None,
     ) -> "Document":
-        r"""Loads Document from Cloud Storage.
+        r"""Loads a Document from a Cloud Storage directory.
 
         Args:
             gcs_bucket_name (str):
@@ -490,6 +491,40 @@ def from_gcs(
             gcs_input_uri=gcs_input_uri,
         )
 
+    @classmethod
+    def from_gcs_uri(
+        cls: Type["Document"],
+        gcs_uri: str,
+        gcs_input_uri: Optional[str] = None,
+    ) -> "Document":
+        r"""Loads a Document from a Cloud Storage uri.
+
+        Args:
+            gcs_uri (str):
+                Required. The full GCS uri to a Document JSON file.
+
+                Example: `gs://{bucket_name}/{optional_folder}/{target_file}.json`.
+            gcs_input_uri (Optional[str]):
+                Optional. The gcs uri to the original input file.
+
+                Format: `gs://{bucket_name}/{optional_folder}/{target_folder}/{file_name}.pdf`
+        Returns:
+            Document:
+                A document from gcs.
+        """
+        blob = gcs_utilities.get_blob(gcs_uri=gcs_uri, module="get-document")
+        shards = [
+            documentai.Document.from_json(
+                blob.download_as_bytes(),
+                ignore_unknown_fields=True,
+            )
+        ]
+        return cls(
+            shards=shards,
+            gcs_uri=gcs_uri,
+            gcs_input_uri=gcs_input_uri,
+        )
+
     @classmethod
     def from_batch_process_metadata(
         cls: Type["Document"], metadata: documentai.BatchProcessMetadata
diff --git a/tests/unit/test_document.py b/tests/unit/test_document.py
index 31ac799e..19c36ca4 100644
--- a/tests/unit/test_document.py
+++ b/tests/unit/test_document.py
@@ -105,6 +105,17 @@ def get_bytes_missing_shard_mock():
         yield byte_factory
 
 
+@pytest.fixture
+def get_blob_mock():
+    with mock.patch.object(gcs_utilities, "get_blob") as blob_factory:
+        mock_blob = mock.Mock()
+        mock_blob.download_as_bytes.return_value = get_bytes("tests/unit/resources/0")[
+            0
+        ]
+        blob_factory.return_value = mock_blob
+        yield blob_factory
+
+
 def create_document_with_images_without_bbox(get_bytes_images_mock):
     doc = document.Document.from_gcs(
         gcs_bucket_name="test-directory", gcs_prefix="documentai/output/123456789/0"
@@ -394,6 +405,25 @@ def test_document_from_gcs_with_unordered_shards(get_bytes_unordered_files_mock)
         assert page.page_number == page_index + 1
 
 
+def test_document_from_gcs_uri(get_blob_mock):
+    actual = document.Document.from_gcs_uri(
+        gcs_uri="gs://test-directory/documentai/output/123456789/0/document.json"
+    )
+
+    get_blob_mock.assert_called_once()
+
+    assert (
+        actual.gcs_uri
+        == "gs://test-directory/documentai/output/123456789/0/document.json"
+    )
+    assert len(actual.pages) == 1
+    # checking cached value
+    assert len(actual.pages) == 1
+
+    assert len(actual.text) > 0
+    assert len(actual.text) > 0
+
+
 def test_document_from_batch_process_metadata_with_multiple_input_files(
     get_bytes_multiple_directories_mock,
 ):