Skip to content

Commit 7d164a4

Browse files
BenjaminKazemi and copybara-github
authored and committed
feat: GenAI SDK client(multimodal) - Support Assemble feature on the multimodal datasets.
PiperOrigin-RevId: 824635300
1 parent 6737a70 commit 7d164a4

File tree

7 files changed

+728
-174
lines changed

7 files changed

+728
-174
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
16+
17+
from tests.unit.vertexai.genai.replays import pytest_helper
18+
from vertexai._genai import types
19+
20+
import pytest
21+
22+
# Schema URI identifying multimodal datasets in Vertex AI Dataset metadata.
METADATA_SCHEMA_URI = (
    "gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
)
# BigQuery table backing the test dataset, in project.dataset.table form
# (the "bq://" prefix is added where a BigQuery URI is expected).
BIGQUERY_TABLE_NAME = "vertex-sdk-dev.multimodal_dataset.test-table"
# Resource ID of a pre-existing multimodal dataset used by the replay tests.
DATASET = "8810841321427173376"
28+
29+
def test_assemple_dataset(client):
    # NOTE(review): "assemple" looks like a typo for "assemble"; kept as-is
    # because renaming would desync this test from its recorded replay.
    """Assembles a multimodal dataset via the private API and checks the LRO type."""
    read_config = {
        "template_config": {
            "field_mapping": {"question": "questionColumn"},
        },
    }
    operation = client.datasets._assemble_multimodal_dataset(
        name=DATASET, gemini_request_read_config=read_config
    )
    assert isinstance(operation, types.MultimodalDatasetOperation)
39+
40+
41+
def test_assemple_dataset_public(client):
    # NOTE(review): "assemple" looks like a typo for "assemble"; kept as-is
    # because renaming would desync this test from its recorded replay.
    """Assembles a dataset via the public API and checks the result destination."""
    example = types.GeminiExample(
        model="gemini-1.5-flash",
        contents=[
            {
                "role": "user",
                "parts": [{"text": "What is the capital of {name}?"}],
            }
        ],
    )
    template = types.GeminiTemplateConfig(gemini_example=example)

    assemble_dataset = client.datasets.assemble(
        name=DATASET, template_config=template
    )

    assert isinstance(assemble_dataset, types.AssembleDataset)
    # The assembled rows land in a BigQuery table derived from the source table.
    assert assemble_dataset.bigquery_destination.startswith(
        f"bq://{BIGQUERY_TABLE_NAME}"
    )
60+
61+
62+
# Wires this module into the replay-test harness: the helper injects the
# fixtures (including `client`) that the tests above consume.
pytestmark = pytest_helper.setup(
    file=__file__,
    globals_for_file=globals(),
)

# Enables the @pytest.mark.asyncio tests below.
pytest_plugins = ("pytest_asyncio",)
68+
69+
70+
@pytest.mark.asyncio
async def test_assemple_dataset_async(client):
    # NOTE(review): "assemple" looks like a typo for "assemble"; kept as-is
    # because renaming would desync this test from its recorded replay.
    """Async variant: assembles via the private API and checks the LRO type."""
    read_config = {
        "template_config": {
            "field_mapping": {"question": "questionColumn"},
        },
    }
    operation = await client.aio.datasets._assemble_multimodal_dataset(
        name=DATASET, gemini_request_read_config=read_config
    )
    assert isinstance(operation, types.MultimodalDatasetOperation)
81+
82+
83+
@pytest.mark.asyncio
async def test_assemple_dataset_public_async(client):
    # NOTE(review): "assemple" looks like a typo for "assemble"; kept as-is
    # because renaming would desync this test from its recorded replay.
    """Async variant of the public assemble test; checks the result destination."""
    example = types.GeminiExample(
        model="gemini-1.5-flash",
        contents=[
            {
                "role": "user",
                "parts": [{"text": "What is the capital of {name}?"}],
            }
        ],
    )
    template = types.GeminiTemplateConfig(gemini_example=example)

    assemble_dataset = await client.aio.datasets.assemble(
        name=DATASET, template_config=template
    )

    assert isinstance(assemble_dataset, types.AssembleDataset)
    # The assembled rows land in a BigQuery table derived from the source table.
    assert assemble_dataset.bigquery_destination.startswith(
        f"bq://{BIGQUERY_TABLE_NAME}"
    )

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ def test_create_dataset_from_bigquery(client):
5454
)
5555
assert isinstance(dataset, types.MultimodalDataset)
5656
assert dataset.display_name == "test-from-bigquery"
57+
assert dataset.metadata.input_config.bigquery_source.uri == (
58+
f"bq://{BIGQUERY_TABLE_NAME}"
59+
)
5760

5861

5962
def test_create_dataset_from_bigquery_without_bq_prefix(client):
@@ -70,6 +73,9 @@ def test_create_dataset_from_bigquery_without_bq_prefix(client):
7073
)
7174
assert isinstance(dataset, types.MultimodalDataset)
7275
assert dataset.display_name == "test-from-bigquery"
76+
assert dataset.metadata.input_config.bigquery_source.uri == (
77+
f"bq://{BIGQUERY_TABLE_NAME}"
78+
)
7379

7480

7581
pytestmark = pytest_helper.setup(
@@ -111,6 +117,9 @@ async def test_create_dataset_from_bigquery_async(client):
111117
)
112118
assert isinstance(dataset, types.MultimodalDataset)
113119
assert dataset.display_name == "test-from-bigquery"
120+
assert dataset.metadata.input_config.bigquery_source.uri == (
121+
f"bq://{BIGQUERY_TABLE_NAME}"
122+
)
114123

115124

116125
@pytest.mark.asyncio
@@ -129,6 +138,9 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client):
129138
)
130139
assert isinstance(dataset, types.MultimodalDataset)
131140
assert dataset.display_name == "test-from-bigquery"
141+
assert dataset.metadata.input_config.bigquery_source.uri == (
142+
f"bq://{BIGQUERY_TABLE_NAME}"
143+
)
132144

133145

134146
@pytest.mark.asyncio
@@ -146,3 +158,6 @@ async def test_create_dataset_from_bigquery_async_without_bq_prefix(client):
146158
)
147159
assert isinstance(dataset, types.MultimodalDataset)
148160
assert dataset.display_name == "test-from-bigquery"
161+
assert dataset.metadata.input_config.bigquery_source.uri == (
162+
f"bq://{BIGQUERY_TABLE_NAME}"
163+
)

tests/unit/vertexai/genai/test_evals.py

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,9 +1070,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
10701070
)
10711071

10721072
mock_agent_engine = mock.Mock()
1073-
mock_agent_engine.async_create_session = mock.AsyncMock(
1074-
return_value={"id": "session1"}
1075-
)
1073+
mock_agent_engine.create_session.return_value = {"id": "session1"}
10761074
stream_query_return_value = [
10771075
{
10781076
"id": "1",
@@ -1088,13 +1086,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_dict(
10881086
},
10891087
]
10901088

1091-
async def _async_iterator(iterable):
1092-
for item in iterable:
1093-
yield item
1094-
1095-
mock_agent_engine.async_stream_query.return_value = _async_iterator(
1096-
stream_query_return_value
1097-
)
1089+
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
10981090
mock_vertexai_client.return_value.agent_engines.get.return_value = (
10991091
mock_agent_engine
11001092
)
@@ -1108,10 +1100,10 @@ async def _async_iterator(iterable):
11081100
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
11091101
name="projects/test-project/locations/us-central1/reasoningEngines/123"
11101102
)
1111-
mock_agent_engine.async_create_session.assert_called_once_with(
1103+
mock_agent_engine.create_session.assert_called_once_with(
11121104
user_id="123", state={"a": "1"}
11131105
)
1114-
mock_agent_engine.async_stream_query.assert_called_once_with(
1106+
mock_agent_engine.stream_query.assert_called_once_with(
11151107
user_id="123", session_id="session1", message="agent prompt"
11161108
)
11171109

@@ -1162,9 +1154,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11621154
)
11631155

11641156
mock_agent_engine = mock.Mock()
1165-
mock_agent_engine.async_create_session = mock.AsyncMock(
1166-
return_value={"id": "session1"}
1167-
)
1157+
mock_agent_engine.create_session.return_value = {"id": "session1"}
11681158
stream_query_return_value = [
11691159
{
11701160
"id": "1",
@@ -1180,13 +1170,7 @@ def test_run_inference_with_agent_engine_and_session_inputs_literal_string(
11801170
},
11811171
]
11821172

1183-
async def _async_iterator(iterable):
1184-
for item in iterable:
1185-
yield item
1186-
1187-
mock_agent_engine.async_stream_query.return_value = _async_iterator(
1188-
stream_query_return_value
1189-
)
1173+
mock_agent_engine.stream_query.return_value = iter(stream_query_return_value)
11901174
mock_vertexai_client.return_value.agent_engines.get.return_value = (
11911175
mock_agent_engine
11921176
)
@@ -1200,10 +1184,10 @@ async def _async_iterator(iterable):
12001184
mock_vertexai_client.return_value.agent_engines.get.assert_called_once_with(
12011185
name="projects/test-project/locations/us-central1/reasoningEngines/123"
12021186
)
1203-
mock_agent_engine.async_create_session.assert_called_once_with(
1187+
mock_agent_engine.create_session.assert_called_once_with(
12041188
user_id="123", state={"a": "1"}
12051189
)
1206-
mock_agent_engine.async_stream_query.assert_called_once_with(
1190+
mock_agent_engine.stream_query.assert_called_once_with(
12071191
user_id="123", session_id="session1", message="agent prompt"
12081192
)
12091193

vertexai/_genai/_evals_common.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -278,12 +278,10 @@ def agent_run_wrapper(
278278
and type(agent_engine).__name__ == "AgentEngine"
279279
):
280280
agent_engine_instance = agent_engine
281-
return asyncio.run(
282-
inference_fn_arg(
283-
row=row_arg,
284-
contents=contents_arg,
285-
agent_engine=agent_engine_instance,
286-
)
281+
return inference_fn_arg(
282+
row=row_arg,
283+
contents=contents_arg,
284+
agent_engine=agent_engine_instance,
287285
)
288286

289287
future = executor.submit(
@@ -1265,7 +1263,7 @@ def _run_agent(
12651263
)
12661264

12671265

1268-
async def _execute_agent_run_with_retry(
1266+
def _execute_agent_run_with_retry(
12691267
row: pd.Series,
12701268
contents: Union[genai_types.ContentListUnion, genai_types.ContentListUnionDict],
12711269
agent_engine: types.AgentEngine,
@@ -1287,7 +1285,7 @@ async def _execute_agent_run_with_retry(
12871285
)
12881286
user_id = session_inputs.user_id
12891287
session_state = session_inputs.state
1290-
session = await agent_engine.async_create_session(
1288+
session = agent_engine.create_session(
12911289
user_id=user_id,
12921290
state=session_state,
12931291
)
@@ -1298,7 +1296,7 @@ async def _execute_agent_run_with_retry(
12981296
for attempt in range(max_retries):
12991297
try:
13001298
responses = []
1301-
async for event in agent_engine.async_stream_query(
1299+
for event in agent_engine.stream_query(
13021300
user_id=user_id,
13031301
session_id=session["id"],
13041302
message=contents,
@@ -1317,7 +1315,7 @@ async def _execute_agent_run_with_retry(
13171315
)
13181316
if attempt == max_retries - 1:
13191317
return {"error": f"Resource exhausted after retries: {e}"}
1320-
await asyncio.sleep(2**attempt)
1318+
time.sleep(2**attempt)
13211319
except Exception as e: # pylint: disable=broad-exception-caught
13221320
logger.error(
13231321
"Unexpected error during generate_content on attempt %d/%d: %s",
@@ -1328,7 +1326,7 @@ async def _execute_agent_run_with_retry(
13281326

13291327
if attempt == max_retries - 1:
13301328
return {"error": f"Failed after retries: {e}"}
1331-
await asyncio.sleep(1)
1329+
time.sleep(1)
13321330
return {"error": f"Failed to get agent run results after {max_retries} retries"}
13331331

13341332

0 commit comments

Comments
 (0)