diff --git a/temporalio/testing/_workflow.py b/temporalio/testing/_workflow.py index 9966df76f..979222dea 100644 --- a/temporalio/testing/_workflow.py +++ b/temporalio/testing/_workflow.py @@ -15,6 +15,8 @@ import google.protobuf.empty_pb2 from typing_extensions import Self +import temporalio.api.nexus.v1 +import temporalio.api.operatorservice.v1 import temporalio.api.testservice.v1 import temporalio.bridge.testing import temporalio.client @@ -401,6 +403,48 @@ def supports_time_skipping(self) -> bool: """Whether this environment supports time skipping.""" return False + async def create_nexus_endpoint( + self, endpoint_name: str, task_queue: str + ) -> temporalio.api.nexus.v1.Endpoint: + """Create a Nexus endpoint with the given name and task queue. + + Args: + endpoint_name: The name of the Nexus endpoint to create. + task_queue: The task queue to associate with the endpoint. + + Returns: + The created Nexus endpoint. + """ + response = await self._client.operator_service.create_nexus_endpoint( + temporalio.api.operatorservice.v1.CreateNexusEndpointRequest( + spec=temporalio.api.nexus.v1.EndpointSpec( + name=endpoint_name, + target=temporalio.api.nexus.v1.EndpointTarget( + worker=temporalio.api.nexus.v1.EndpointTarget.Worker( + namespace=self._client.namespace, + task_queue=task_queue, + ) + ), + ) + ) + ) + return response.endpoint + + async def delete_nexus_endpoint( + self, endpoint: temporalio.api.nexus.v1.Endpoint + ) -> None: + """Delete a Nexus endpoint. + + Args: + endpoint: The Nexus endpoint to delete. + """ + await self._client.operator_service.delete_nexus_endpoint( + temporalio.api.operatorservice.v1.DeleteNexusEndpointRequest( + id=endpoint.id, + version=endpoint.version, + ) + ) + @contextmanager def auto_time_skipping_disabled(self) -> Iterator[None]: """Disable any automatic time skipping if this is a time-skipping diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 3eaf8de29..c6cb6f2e9 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -68,7 +68,7 @@ ) from tests.helpers import find_free_port, new_worker from tests.helpers.metrics import PromMetricMatcher -from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name +from tests.helpers.nexus import make_nexus_endpoint_name # TODO(nexus-preview): test worker shutdown, wait_all_completed, drain etc @@ -603,7 +603,8 @@ async def test_sync_operation_happy_path(client: Client, env: WorkflowEnvironmen task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_output = await client.execute_workflow( CallerWorkflow.run, args=[ @@ -640,7 +641,8 @@ async def test_workflow_run_operation_happy_path( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_output = await client.execute_workflow( CallerWorkflow.run, args=[ @@ -930,7 +932,8 @@ async def test_sync_response( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( CallerWorkflow.run, args=[ @@ -1004,6 +1007,7 @@ async def test_async_response( workflow_failure_exception_types=[Exception], ): caller_wf_handle, handler_wf_handle = await _start_wf_and_nexus_op( + env, client, task_queue, exception_in_operation_start, @@ -1076,6 +1080,7 @@ async def test_async_response( async def _start_wf_and_nexus_op( + env: WorkflowEnvironment, client: Client, task_queue: str, exception_in_operation_start: bool, @@ -1089,7 +1094,8 @@ async def _start_wf_and_nexus_op( """ Start the caller workflow and wait until the Nexus operation has started. """ - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) operation_workflow_id = str(uuid.uuid4()) # Start the caller workflow and wait until it confirms the Nexus operation has started. @@ -1174,7 +1180,8 @@ async def test_untyped_caller( op_definition_type=op_definition_type, exception_in_operation_start=exception_in_operation_start, ) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) caller_wf_handle = await client.start_workflow( UntypedCallerWorkflow.run, args=[ @@ -1335,7 +1342,8 @@ async def test_service_interface_and_implementation_names( task_queue=task_queue, workflow_failure_exception_types=[Exception], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) assert await client.execute_workflow( ServiceInterfaceAndImplCallerWorkflow.run, args=(CallerReference.INTERFACE, NameOverride.YES, task_queue), @@ -1451,7 +1459,8 @@ async def test_workflow_run_operation_can_execute_workflow_before_starting_backi ], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) result = await client.execute_workflow( WorkflowCallingNexusOperationThatExecutesWorkflowBeforeStartingBackingWorkflow.run, args=("result-1", task_queue), @@ -1503,7 +1512,8 @@ async def test_nexus_operation_summary( ], task_queue=task_queue, ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) wf_id = f"wf-{uuid.uuid4()}" handle = await client.start_workflow( ExecuteNexusOperationWithSummaryWorkflow.run, @@ -1799,7 +1809,8 @@ async def test_workflow_run_operation_overloads( ], nexus_service_handlers=[OverloadTestServiceHandler()], ): - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) res = await client.execute_workflow( OverloadTestCallerWorkflow.run, args=[op, OverloadTestValue(value=2)], @@ -1859,7 +1870,8 @@ async def test_workflow_caller_custom_metrics(client: Client, env: WorkflowEnvir pytest.skip("Nexus tests don't work with time-skipping server") task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) # Create new runtime with Prom server prom_addr = f"127.0.0.1:{find_free_port()}" @@ -1952,7 +1964,8 @@ async def test_workflow_caller_buffered_metrics( runtime=runtime, ) task_queue = str(uuid.uuid4()) - await create_nexus_endpoint(task_queue, client) + endpoint_name = make_nexus_endpoint_name(task_queue) + await env.create_nexus_endpoint(endpoint_name, task_queue) async with new_worker( client, CustomMetricsWorkflow,