Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion google/cloud/aiplatform/utils/source_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,11 @@ def make_package(self, package_directory: str) -> str:
fp.write(setup_py_output)

if os.path.isdir(self.script_path):
shutil.copytree(self.script_path, trainer_path, dirs_exist_ok=True)
# Remove destination path if it already exists
shutil.rmtree(trainer_path)

# Copy folder recursively
shutil.copytree(src=self.script_path, dst=trainer_path)
else:
# The module that will contain the script
script_out_path = trainer_path / f"{self.task_module_name}.py"
Expand Down
77 changes: 77 additions & 0 deletions tests/unit/aiplatform/test_training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
#

from importlib import reload
import filecmp
import json
import os
import pytest
import tempfile

from google.cloud.aiplatform.training_utils import environment_variables
from google.cloud.aiplatform.utils import source_utils
from unittest import mock

_TEST_TRAINING_DATA_URI = "gs://training-data-uri"
Expand Down Expand Up @@ -203,3 +206,77 @@ def test_http_handler_port(self):
def test_http_handler_port_none(self):
reload(environment_variables)
assert environment_variables.http_handler_port is None

@pytest.fixture()
def mock_temp_file_name(self):
# Create random files
# tmpdirname = tempfile.TemporaryDirectory()
file = tempfile.NamedTemporaryFile()

with open(file.name, "w") as handle:
handle.write("test")

yield file.name

file.close()

def test_package_file(self, mock_temp_file_name):
# Test that the packager properly copies the source file to the destination file

packager = source_utils._TrainingScriptPythonPackager(
script_path=mock_temp_file_name
)

with tempfile.TemporaryDirectory() as destination_directory_name:
_ = packager.make_package(package_directory=destination_directory_name)

# Check that contents of source_distribution_path is the same as destination_directory_name
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}/{packager.task_module_name}.py"

assert filecmp.cmp(
mock_temp_file_name, destination_inner_path, shallow=False
)

@pytest.fixture()
def mock_temp_folder_name(self):
# Create random folder
folder = tempfile.TemporaryDirectory()

file = tempfile.NamedTemporaryFile(dir=folder.name)

# Create random file in the folder
with open(file.name, "w") as handle:
handle.write("test")

yield folder.name

file.close()

folder.cleanup()

def test_package_folder(self, mock_temp_folder_name):
# Test that the packager properly copies the source folder to the destination folder

packager = source_utils._TrainingScriptPythonPackager(
script_path=mock_temp_folder_name
)

with tempfile.TemporaryDirectory() as destination_directory_name:
# Add an existing file into the destination directory to check if it gets deleted
existing_file = tempfile.NamedTemporaryFile(dir=destination_directory_name)

with open(existing_file.name, "w") as handle:
handle.write("existing")

_ = packager.make_package(package_directory=destination_directory_name)

# Check that contents of source_distribution_path is the same as destination_directory_name
destination_inner_path = f"{destination_directory_name}/{packager._TRAINER_FOLDER}/{packager._ROOT_MODULE}"

dcmp = filecmp.dircmp(mock_temp_folder_name, destination_inner_path)

assert len(dcmp.diff_files) == 0
assert len(dcmp.left_only) == 0
assert len(dcmp.right_only) == 0

existing_file.close()