diff --git a/durabletask/internal/proto_task_hub_sidecar_service_stub.py b/durabletask/internal/proto_task_hub_sidecar_service_stub.py new file mode 100644 index 0000000..8f51123 --- /dev/null +++ b/durabletask/internal/proto_task_hub_sidecar_service_stub.py @@ -0,0 +1,33 @@ +from typing import Any, Callable, Protocol + + +class ProtoTaskHubSidecarServiceStub(Protocol): + """A stub class matching the TaskHubSidecarServiceStub generated from the .proto file. + Allows the use of TaskHubGrpcWorker methods when a real sidecar stub is not available. + """ + Hello: Callable[..., Any] + StartInstance: Callable[..., Any] + GetInstance: Callable[..., Any] + RewindInstance: Callable[..., Any] + WaitForInstanceStart: Callable[..., Any] + WaitForInstanceCompletion: Callable[..., Any] + RaiseEvent: Callable[..., Any] + TerminateInstance: Callable[..., Any] + SuspendInstance: Callable[..., Any] + ResumeInstance: Callable[..., Any] + QueryInstances: Callable[..., Any] + PurgeInstances: Callable[..., Any] + GetWorkItems: Callable[..., Any] + CompleteActivityTask: Callable[..., Any] + CompleteOrchestratorTask: Callable[..., Any] + CompleteEntityTask: Callable[..., Any] + StreamInstanceHistory: Callable[..., Any] + CreateTaskHub: Callable[..., Any] + DeleteTaskHub: Callable[..., Any] + SignalEntity: Callable[..., Any] + GetEntity: Callable[..., Any] + QueryEntities: Callable[..., Any] + CleanEntityStorage: Callable[..., Any] + AbandonTaskActivityWorkItem: Callable[..., Any] + AbandonTaskOrchestratorWorkItem: Callable[..., Any] + AbandonTaskEntityWorkItem: Callable[..., Any] diff --git a/durabletask/worker.py b/durabletask/worker.py index 56687bb..a4222dd 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -24,6 +24,7 @@ from durabletask.internal.helpers import new_timestamp from durabletask.entities import DurableEntity, EntityLock, EntityInstanceId, EntityContext from durabletask.internal.orchestration_entity_context import OrchestrationEntityContext +from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub import durabletask.internal.helpers as ph import durabletask.internal.exceptions as pe import durabletask.internal.orchestrator_service_pb2 as pb @@ -631,7 +632,7 @@ def stop(self): def _execute_orchestrator( self, req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): try: @@ -679,7 +680,7 @@ def _execute_orchestrator( def _cancel_orchestrator( self, req: pb.OrchestratorRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): stub.AbandonTaskOrchestratorWorkItem( @@ -692,7 +693,7 @@ def _cancel_orchestrator( def _execute_activity( self, req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): instance_id = req.orchestrationInstance.instanceId @@ -725,7 +726,7 @@ def _execute_activity( def _cancel_activity( self, req: pb.ActivityRequest, - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): stub.AbandonTaskActivityWorkItem( @@ -738,7 +739,7 @@ def _cancel_activity( def _execute_entity_batch( self, req: Union[pb.EntityBatchRequest, pb.EntityRequest], - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): if isinstance(req, pb.EntityRequest): @@ -807,7 +808,7 @@ def _execute_entity_batch( def _cancel_entity_batch( self, req: Union[pb.EntityBatchRequest, pb.EntityRequest], - stub: stubs.TaskHubSidecarServiceStub, + stub: Union[stubs.TaskHubSidecarServiceStub, ProtoTaskHubSidecarServiceStub], completionToken, ): stub.AbandonTaskEntityWorkItem( diff --git a/tests/durabletask/test_proto_task_hub_shim.py b/tests/durabletask/test_proto_task_hub_shim.py new file mode 100644 index 0000000..8bd3a65 --- /dev/null +++ b/tests/durabletask/test_proto_task_hub_shim.py @@ -0,0 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + + +from typing import get_type_hints + +from durabletask.internal.orchestrator_service_pb2_grpc import TaskHubSidecarServiceStub +from durabletask.internal.proto_task_hub_sidecar_service_stub import ProtoTaskHubSidecarServiceStub + + +def test_proto_task_hub_shim_is_compatible(): + """Test that ProtoTaskHubSidecarServiceStub is compatible with TaskHubSidecarServiceStub.""" + protocol_attrs = set(get_type_hints(ProtoTaskHubSidecarServiceStub).keys()) + + # Instantiate TaskHubSidecarServiceStub with a dummy channel to get its attributes + class TestChannel(): + def unary_unary(self, *args, **kwargs): + pass + + def unary_stream(self, *args, **kwargs): + pass + impl_attrs = TaskHubSidecarServiceStub(TestChannel()).__dict__.keys() + + # Check missing + assert protocol_attrs == impl_attrs