Skip to content

Commit 59f2f0a

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Gen AI SDK client - Fix bug in GCS bucket creation for new agent engines.
PiperOrigin-RevId: 834146214
1 parent bc26160 commit 59f2f0a

File tree

3 files changed

+157
-1
lines changed

3 files changed

+157
-1
lines changed

tests/unit/vertexai/genai/test_agent_engines.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,62 @@ def test_create_base64_encoded_tarball_outside_project_dir_raises(self):
14501450
finally:
14511451
os.chdir(origin_dir)
14521452

1453+
@mock.patch.object(_agent_engines_utils, "_upload_requirements")
1454+
@mock.patch.object(_agent_engines_utils, "_upload_extra_packages")
1455+
@mock.patch.object(_agent_engines_utils, "_upload_agent_engine")
1456+
@mock.patch.object(_agent_engines_utils, "_scan_requirements")
1457+
@mock.patch.object(_agent_engines_utils, "_get_gcs_bucket")
1458+
def test_prepare_with_creds(self, mock_get_gcs_bucket, mock_scan_requirements, mock_upload_agent_engine, mock_upload_extra_packages, mock_upload_requirements):
1459+
mock_scan_requirements.return_value = {}
1460+
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
1461+
mock_creds.universe_domain = "googleapis.com"
1462+
_agent_engines_utils._prepare(
1463+
agent=self.test_agent,
1464+
project=_TEST_PROJECT,
1465+
location=_TEST_LOCATION,
1466+
staging_bucket=_TEST_STAGING_BUCKET,
1467+
credentials=mock_creds,
1468+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1469+
requirements=[],
1470+
extra_packages=[],
1471+
)
1472+
mock_upload_agent_engine.assert_called_once_with(
1473+
agent=self.test_agent,
1474+
gcs_bucket=mock.ANY,
1475+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1476+
)
1477+
1478+
@mock.patch.object(_agent_engines_utils, "_upload_requirements")
1479+
@mock.patch.object(_agent_engines_utils, "_upload_extra_packages")
1480+
@mock.patch.object(_agent_engines_utils, "_upload_agent_engine")
1481+
@mock.patch.object(_agent_engines_utils, "_scan_requirements")
1482+
@mock.patch("google.auth.default")
1483+
@mock.patch.object(_agent_engines_utils, "_get_gcs_bucket")
1484+
def test_prepare_without_creds(self, mock_get_gcs_bucket, mock_auth_default, mock_scan_requirements, mock_upload_agent_engine, mock_upload_extra_packages, mock_upload_requirements):
1485+
mock_scan_requirements.return_value = {}
1486+
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
1487+
mock_auth_default.return_value = (mock_creds, _TEST_PROJECT)
1488+
_agent_engines_utils._prepare(
1489+
agent=self.test_agent,
1490+
project=_TEST_PROJECT,
1491+
location=_TEST_LOCATION,
1492+
staging_bucket=_TEST_STAGING_BUCKET,
1493+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1494+
requirements=[],
1495+
extra_packages=[],
1496+
)
1497+
mock_get_gcs_bucket.assert_called_once_with(
1498+
project=_TEST_PROJECT,
1499+
location=_TEST_LOCATION,
1500+
staging_bucket=_TEST_STAGING_BUCKET,
1501+
credentials=None,
1502+
)
1503+
mock_upload_agent_engine.assert_called_once_with(
1504+
agent=self.test_agent,
1505+
gcs_bucket=mock.ANY,
1506+
gcs_dir_name=_TEST_GCS_DIR_NAME,
1507+
)
1508+
14531509

14541510
@pytest.mark.usefixtures("google_auth_mock")
14551511
class TestAgentEngine:
@@ -2623,6 +2679,101 @@ def test_operation_schemas(
26232679
assert test_agent_engine.operation_schemas() == want_operation_schemas
26242680

26252681

2682+
@mock.patch.object(_agent_engines_utils, "_prepare")
2683+
@mock.patch.object(agent_engines.AgentEngines, "_create")
2684+
@mock.patch.object(_agent_engines_utils, "_await_operation")
2685+
def test_create_agent_engine_with_creds(self, mock_await_operation, mock_create, mock_prepare):
2686+
mock_operation = mock.Mock()
2687+
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
2688+
mock_create.return_value = mock_operation
2689+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
2690+
response=_genai_types.ReasoningEngine(
2691+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2692+
spec=_TEST_AGENT_ENGINE_SPEC,
2693+
)
2694+
)
2695+
self.client.agent_engines.create(
2696+
agent=self.test_agent,
2697+
config=_genai_types.AgentEngineConfig(
2698+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2699+
staging_bucket=_TEST_STAGING_BUCKET,
2700+
),
2701+
)
2702+
mock_args, mock_kwargs = mock_prepare.call_args
2703+
assert mock_kwargs['agent'] == self.test_agent
2704+
assert mock_kwargs['extra_packages'] == []
2705+
assert mock_kwargs['project'] == _TEST_PROJECT
2706+
assert mock_kwargs['location'] == _TEST_LOCATION
2707+
assert mock_kwargs['staging_bucket'] == _TEST_STAGING_BUCKET
2708+
assert mock_kwargs['credentials'] == _TEST_CREDENTIALS
2709+
assert mock_kwargs['gcs_dir_name'] == 'agent_engine'
2710+
2711+
@mock.patch.object(_agent_engines_utils, "_prepare")
2712+
@mock.patch.object(agent_engines.AgentEngines, "_create")
2713+
@mock.patch("google.auth.default")
2714+
@mock.patch.object(_agent_engines_utils, "_await_operation")
2715+
def test_create_agent_engine_without_creds(self, mock_await_operation, mock_auth_default, mock_create, mock_prepare):
2716+
mock_operation = mock.Mock()
2717+
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
2718+
mock_create.return_value = mock_operation
2719+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
2720+
response=_genai_types.ReasoningEngine(
2721+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2722+
spec=_TEST_AGENT_ENGINE_SPEC,
2723+
)
2724+
)
2725+
mock_creds = mock.Mock(spec=auth_credentials.AnonymousCredentials())
2726+
mock_creds.quota_project_id = _TEST_PROJECT
2727+
mock_auth_default.return_value = (mock_creds, _TEST_PROJECT)
2728+
client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=mock_creds)
2729+
client.agent_engines.create(
2730+
agent=self.test_agent,
2731+
config=_genai_types.AgentEngineConfig(
2732+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2733+
staging_bucket=_TEST_STAGING_BUCKET,
2734+
),
2735+
)
2736+
mock_args, mock_kwargs = mock_prepare.call_args
2737+
assert mock_kwargs['agent'] == self.test_agent
2738+
assert mock_kwargs['extra_packages'] == []
2739+
assert mock_kwargs['project'] == _TEST_PROJECT
2740+
assert mock_kwargs['location'] == _TEST_LOCATION
2741+
assert mock_kwargs['staging_bucket'] == _TEST_STAGING_BUCKET
2742+
assert mock_kwargs['credentials'] == mock_creds
2743+
assert mock_kwargs['gcs_dir_name'] == 'agent_engine'
2744+
2745+
@mock.patch.object(_agent_engines_utils, "_prepare")
2746+
@mock.patch.object(agent_engines.AgentEngines, "_create")
2747+
@mock.patch.object(_agent_engines_utils, "_await_operation")
2748+
def test_create_agent_engine_with_no_creds_in_client(self, mock_await_operation, mock_create, mock_prepare):
2749+
mock_operation = mock.Mock()
2750+
mock_operation.name = _TEST_AGENT_ENGINE_OPERATION_NAME
2751+
mock_create.return_value = mock_operation
2752+
mock_await_operation.return_value = _genai_types.AgentEngineOperation(
2753+
response=_genai_types.ReasoningEngine(
2754+
name=_TEST_AGENT_ENGINE_RESOURCE_NAME,
2755+
spec=_TEST_AGENT_ENGINE_SPEC,
2756+
)
2757+
)
2758+
client = vertexai.Client(project=_TEST_PROJECT, location=_TEST_LOCATION, credentials=None)
2759+
client.agent_engines.create(
2760+
agent=self.test_agent,
2761+
config=_genai_types.AgentEngineConfig(
2762+
display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME,
2763+
staging_bucket=_TEST_STAGING_BUCKET,
2764+
),
2765+
)
2766+
mock_args, mock_kwargs = mock_prepare.call_args
2767+
assert mock_kwargs['agent'] == self.test_agent
2768+
assert mock_kwargs['extra_packages'] == []
2769+
assert mock_kwargs['project'] == _TEST_PROJECT
2770+
assert mock_kwargs['location'] == _TEST_LOCATION
2771+
assert mock_kwargs['staging_bucket'] == _TEST_STAGING_BUCKET
2772+
assert mock_kwargs['credentials'] == None
2773+
assert mock_kwargs['gcs_dir_name'] == 'agent_engine'
2774+
2775+
2776+
26262777
@pytest.mark.usefixtures("google_auth_mock")
26272778
class TestAgentEngineErrors:
26282779
def setup_method(self):

vertexai/_genai/_agent_engines_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -772,10 +772,11 @@ def _get_gcs_bucket(
772772
project: str,
773773
location: str,
774774
staging_bucket: str,
775+
credentials: Optional[Any] = None,
775776
) -> _StorageBucket:
776777
"""Gets or creates the GCS bucket."""
777778
storage = _import_cloud_storage_or_raise()
778-
storage_client = storage.Client(project=project)
779+
storage_client = storage.Client(project=project, credentials=credentials)
779780
staging_bucket = staging_bucket.replace("gs://", "")
780781
try:
781782
gcs_bucket = storage_client.get_bucket(staging_bucket)
@@ -910,6 +911,7 @@ def _prepare(
910911
location: str,
911912
staging_bucket: str,
912913
gcs_dir_name: str,
914+
credentials: Optional[Any] = None,
913915
) -> None:
914916
"""Prepares the agent engine for creation or updates in Vertex AI.
915917
@@ -928,13 +930,15 @@ def _prepare(
928930
staging_bucket (str): The staging bucket name in the form "gs://...".
929931
gcs_dir_name (str): The GCS bucket directory under `staging_bucket` to
930932
use for staging the artifacts needed.
933+
credentials: The credentials to use for the storage client.
931934
"""
932935
if agent is None:
933936
return
934937
gcs_bucket = _get_gcs_bucket(
935938
project=project,
936939
location=location,
937940
staging_bucket=staging_bucket,
941+
credentials=credentials,
938942
)
939943
_upload_agent_engine(
940944
agent=agent,

vertexai/_genai/agent_engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1082,6 +1082,7 @@ def _create_config(
10821082
staging_bucket=staging_bucket,
10831083
gcs_dir_name=gcs_dir_name,
10841084
extra_packages=extra_packages,
1085+
credentials=self._api_client._credentials,
10851086
)
10861087
# Update the package spec.
10871088
update_masks.append("spec.package_spec.pickle_object_gcs_uri")

0 commit comments

Comments
 (0)