From 71cff0d36c95d23519124d103ee90f648c190f2d Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Wed, 25 Feb 2026 14:53:09 -0800 Subject: [PATCH 01/12] External payload storage --- README.md | 119 +++- temporalio/bridge/worker.py | 7 +- temporalio/converter.py | 100 ++- temporalio/extstore.py | 656 +++++++++++++++++++ temporalio/worker/_activity.py | 2 +- temporalio/worker/_workflow.py | 51 +- temporalio/worker/_workflow_instance.py | 9 +- tests/test_extstore.py | 803 ++++++++++++++++++++++++ tests/worker/test_extstore.py | 427 +++++++++++++ 9 files changed, 2131 insertions(+), 43 deletions(-) create mode 100644 temporalio/extstore.py create mode 100644 tests/test_extstore.py create mode 100644 tests/worker/test_extstore.py diff --git a/README.md b/README.md index f26c7b837..096c138bc 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,7 @@ informal introduction to the features and their implementation. - [Data Conversion](#data-conversion) - [Pydantic Support](#pydantic-support) - [Custom Type Data Conversion](#custom-type-data-conversion) + - [External Payload Storage](#external-payload-storage) - [Workers](#workers) - [Workflows](#workflows) - [Definition](#definition) @@ -309,8 +310,9 @@ other_ns_client = Client(**config) Data converters are used to convert raw Temporal payloads to/from actual Python types. A custom data converter of type `temporalio.converter.DataConverter` can be set via the `data_converter` parameter of the `Client` constructor. Data -converters are a combination of payload converters, payload codecs, and failure converters. Payload converters convert -Python values to/from serialized bytes. Payload codecs convert bytes to bytes (e.g. for compression or encryption). +converters are a combination of payload converters, external payload storage, payload codecs, and failure converters. Payload +converters convert Python values to/from serialized bytes. 
External payload storage optionally stores and retrieves payloads +to/from external storage services using drivers. Payload codecs convert bytes to bytes (e.g. for compression or encryption). Failure converters convert exceptions to/from serialized failures. The default data converter supports converting multiple types including: @@ -455,6 +457,119 @@ my_data_converter = dataclasses.replace( Now `IPv4Address` can be used in type hints including collections, optionals, etc. +##### External Payload Storage + +⚠️ **External payload storage support is currently at an experimental release stage.** ⚠️ + +External payload storage allows large payloads to be offloaded to an external storage service (such as Amazon S3) rather than stored inline in workflow history. This is useful when workflows or activities work with data that would otherwise exceed Temporal's payload size limits. + +External payload storage is configured via the `external_storage` parameter on `DataConverter`, which accepts a `temporalio.extstore.Options` instance. Any driver used to store payloads must also be configured on the component that retrieves them — for example, if the client stores workflow inputs using a driver, the worker must include that driver in its `Options.drivers` list to retrieve them. + +The simplest setup uses a single storage driver: + +```python +import dataclasses +from temporalio.client import Client +from temporalio.converter import DataConverter +from temporalio.extstore import Options + +driver = MyDriver() + +client = await Client.connect( + "localhost:7233", + data_converter=dataclasses.replace( + DataConverter.default, + external_storage=Options(drivers=[driver]), + ), +) +``` + +Some things to note about external payload storage: + +* Only payloads that meet or exceed `Options.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. 
+* External payload storage applies transparently to workflow inputs/outputs, activity inputs/outputs, signals, updates, queries, and failure details. +* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `Options.external_converter`. +* Setting `Options.payload_size_threshold` to `None` causes every payload to be considered for external payload storage regardless of size. + +###### Multiple Drivers and Driver Selection + +When multiple storage backends are needed, list all drivers in `Options.drivers` and provide a `driver_selector` to control which driver stores new payloads. Any driver in the list not chosen for storing is still available for retrieval, which is useful when migrating between storage backends. + +```python +from temporalio.extstore import Options + +options = Options( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: ( + hot_driver if payload.ByteSize() < 5 * 1024 * 1024 else cold_driver + ), +) +``` + +For stateful or class-based selection logic, implement `temporalio.extstore.DriverSelector`: + +```python +from temporalio.extstore import Driver, DriverContext, DriverSelector +from temporalio.api.common.v1 import Payload + +class MyDriverSelector(DriverSelector): + def select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: + # Return None to store the payload inline rather than externally + if payload.ByteSize() < 256 * 1024: + return None + return hot_driver +``` + +Some things to note about driver selection: + +* When no `driver_selector` is set, the first driver in `Options.drivers` is always used for storing. +* Returning `None` from a selector leaves the payload stored inline in workflow history rather than offloading it. +* The driver returned by the selector must be registered in `Options.drivers`. 
If it is not, a `DriverNotFoundError` is raised. + +###### Custom Drivers + +Implement `temporalio.extstore.Driver` to integrate with any external payload storage system: + +```python +from collections.abc import Sequence +from temporalio.extstore import Driver, DriverClaim, DriverContext +from temporalio.api.common.v1 import Payload + +class MyDriver(Driver): + def __init__(self, driver_name: str | None = None): + self._driver_name = driver_name or "my-org:driver:my-driver" + + def name(self) -> str: + return self._driver_name + + async def store( + self, context: DriverContext, payloads: Sequence[Payload] + ) -> list[DriverClaim]: + claims = [] + for payload in payloads: + key = await my_storage.put(payload.SerializeToString()) + claims.append(DriverClaim(data={"key": key})) + return claims + + async def retrieve( + self, context: DriverContext, claims: Sequence[DriverClaim] + ) -> list[Payload]: + payloads = [] + for claim in claims: + data = await my_storage.get(claim.data["key"]) + p = Payload() + p.ParseFromString(data) + payloads.append(p) + return payloads +``` + +Some things to note about implementing a custom driver: + +* `store` and `retrieve` must return lists of the same length as their respective input sequences. +* `Driver.name()` must return a string that is unique among all drivers in `Options.drivers`. This name is embedded in the reference payload stored in workflow history and used to look up the correct driver during retrieval — changing it after payloads have been stored will break retrieval. +* `Driver.type()` is automatically implemented to return the name of the class. This can be overridden in subclasses but must remain consistent across all instances of the subclass. +* Implement `temporalio.converter.WithSerializationContext` on your driver to receive workflow or activity context (namespace, workflow ID, activity ID, etc.) at serialization time. + ### Workers Workers host workflows and/or activities. 
Here's how to run a worker: diff --git a/temporalio/bridge/worker.py b/temporalio/bridge/worker.py index c98afefca..c2e426d28 100644 --- a/temporalio/bridge/worker.py +++ b/temporalio/bridge/worker.py @@ -303,10 +303,9 @@ async def decode_activation( decode_headers: bool, ) -> None: """Decode all payloads in the activation.""" - if data_converter._decode_payload_has_effect: - await CommandAwarePayloadVisitor( - skip_search_attributes=True, skip_headers=not decode_headers - ).visit(_Visitor(data_converter._decode_payload_sequence), activation) + await CommandAwarePayloadVisitor( + skip_search_attributes=True, skip_headers=not decode_headers + ).visit(_Visitor(data_converter._decode_payload_sequence), activation) async def encode_completion( diff --git a/temporalio/converter.py b/temporalio/converter.py index dc37f5039..b67b61873 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -21,6 +21,7 @@ from itertools import zip_longest from logging import getLogger from typing import ( + TYPE_CHECKING, Any, ClassVar, Literal, @@ -44,6 +45,9 @@ import temporalio.exceptions import temporalio.types +if TYPE_CHECKING: + from temporalio.extstore import StorageOptions, _ExternalStorageMiddleware + if sys.version_info < (3, 11): # Python's datetime.fromisoformat doesn't support certain formats pre-3.11 from dateutil import parser # type: ignore @@ -924,11 +928,42 @@ def to_failure( failure: temporalio.api.failure.v1.Failure, ) -> None: """See base class.""" + from temporalio.extstore import ( + DriverError, + PayloadNotFoundError, + ) + # If already a failure error, use that if isinstance(exception, temporalio.exceptions.FailureError): self._error_to_failure(exception, payload_converter, failure) elif isinstance(exception, nexusrpc.HandlerError): self._nexus_handler_error_to_failure(exception, payload_converter, failure) + elif isinstance(exception, PayloadNotFoundError): + # Convert to failure error + failure_error = temporalio.exceptions.ApplicationError( + 
str(exception), + { + "driver_name": exception.driver_name, + "driver_claim": exception.driver_claim, + }, + type=exception.__class__.__name__, + non_retryable=True, + ) + failure_error.__traceback__ = exception.__traceback__ + failure_error.__cause__ = exception.__cause__ + self._error_to_failure(failure_error, payload_converter, failure) + elif isinstance(exception, DriverError): + # Convert to failure error + failure_error = temporalio.exceptions.ApplicationError( + str(exception), + { + "driver_name": exception.driver_name, + }, + type=exception.__class__.__name__, + ) + failure_error.__traceback__ = exception.__traceback__ + failure_error.__cause__ = exception.__cause__ + self._error_to_failure(failure_error, payload_converter, failure) else: # Convert to failure error failure_error = temporalio.exceptions.ApplicationError( @@ -1359,15 +1394,27 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" + external_storage: StorageOptions | None = None + """Options for external storage. If None, external storage is disabled. + + .. warning:: + This API is experimental. 
+ """ + default: ClassVar[DataConverter] """Singleton default data converter.""" + _external_storage_middleware: "_ExternalStorageMiddleware" = dataclasses.field( + init=False + ) + _payload_error_limits: _ServerPayloadErrorLimits | None = None """Server-reported limits for payloads.""" def __post_init__(self) -> None: # noqa: D105 object.__setattr__(self, "payload_converter", self.payload_converter_class()) object.__setattr__(self, "failure_converter", self.failure_converter_class()) + self._reset_external_storage_middleware() async def encode( self, values: Sequence[Any] @@ -1445,18 +1492,22 @@ def with_context(self, context: SerializationContext) -> Self: payload_converter = self.payload_converter payload_codec = self.payload_codec failure_converter = self.failure_converter + external_storage = self.external_storage if isinstance(payload_converter, WithSerializationContext): payload_converter = payload_converter.with_context(context) if isinstance(payload_codec, WithSerializationContext): payload_codec = payload_codec.with_context(context) if isinstance(failure_converter, WithSerializationContext): failure_converter = failure_converter.with_context(context) + if isinstance(external_storage, WithSerializationContext): + external_storage = external_storage.with_context(context) if all( new is orig for new, orig in [ (payload_converter, self.payload_converter), (payload_codec, self.payload_codec), (failure_converter, self.failure_converter), + (external_storage, self.external_storage), ] ): return self @@ -1464,8 +1515,22 @@ def with_context(self, context: SerializationContext) -> Self: object.__setattr__(cloned, "payload_converter", payload_converter) object.__setattr__(cloned, "payload_codec", payload_codec) object.__setattr__(cloned, "failure_converter", failure_converter) + object.__setattr__(cloned, "external_storage", external_storage) + cloned._reset_external_storage_middleware(context) return cloned + def _reset_external_storage_middleware( + self, context: 
SerializationContext | None = None + ) -> None: + # Lazy import to avoid circular dependency + from temporalio.extstore import _ExternalStorageMiddleware + + object.__setattr__( + self, + "_external_storage_middleware", + _ExternalStorageMiddleware(self.external_storage, context), + ) + def _with_payload_error_limits( self, limits: _ServerPayloadErrorLimits | None ) -> DataConverter: @@ -1523,48 +1588,47 @@ async def _encode_memo_existing( async def _encode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: + payload = await self._external_storage_middleware.store_payload(payload) if self.payload_codec: payload = (await self.payload_codec.encode([payload]))[0] self._validate_payload_limits([payload]) return payload async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): - if self.payload_codec: - await self.payload_codec.encode_wrapper(payloads) - self._validate_payload_limits(payloads.payloads) + encoded_payloads = await self._encode_payload_sequence(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(encoded_payloads) async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - encoded_payloads = list(payloads) + result = await self._external_storage_middleware.store_payloads(payloads) if self.payload_codec: - encoded_payloads = await self.payload_codec.encode(encoded_payloads) - self._validate_payload_limits(encoded_payloads) - return encoded_payloads + result = await self.payload_codec.encode(result) + self._validate_payload_limits(result) + return result async def _decode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: if self.payload_codec: payload = (await self.payload_codec.decode([payload]))[0] + payload = await self._external_storage_middleware.retrieve_payload(payload) return payload async def _decode_payloads(self, payloads: 
temporalio.api.common.v1.Payloads): - if self.payload_codec: - await self.payload_codec.decode_wrapper(payloads) + decoded_payloads = await self._decode_payload_sequence(payloads.payloads) + del payloads.payloads[:] + payloads.payloads.extend(decoded_payloads) async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - if not self.payload_codec: - return list(payloads) - return await self.payload_codec.decode(payloads) - - # Temporary shortcircuit detection while the _decode_* methods may no-op if - # a payload codec is not configured. Remove once those paths have more to them. - @property - def _decode_payload_has_effect(self) -> bool: - return self.payload_codec is not None + result = list(payloads) + if self.payload_codec: + result = await self.payload_codec.decode(result) + result = await self._external_storage_middleware.retrieve_payloads(result) + return result @staticmethod async def _apply_to_failure_payloads( diff --git a/temporalio/extstore.py b/temporalio/extstore.py new file mode 100644 index 000000000..5e9c3e1f5 --- /dev/null +++ b/temporalio/extstore.py @@ -0,0 +1,656 @@ +"""External payload storage support for offloading payloads to external storage systems.""" + +from __future__ import annotations + +import asyncio +import dataclasses +import warnings +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload +from temporalio.converter import ( + JSONPlainPayloadConverter, + SerializationContext, + WithSerializationContext, +) +from temporalio.exceptions import TemporalError + +if TYPE_CHECKING: + from temporalio.converter import PayloadCodec + + +@dataclass(frozen=True) +class DriverClaim: + """Claim for an externally stored payload. + + .. warning:: + This API is experimental. 
+ """ + + data: dict[str, str] + """Driver-defined data for identifying and retrieving an externally stored payload.""" + + +@dataclass(frozen=True) +class DriverContext: + """Context passed to :class:`Driver` and :class:`DriverSelector` calls. + + .. warning:: + This API is experimental. + """ + + serialization_context: SerializationContext | None = None + """The serialization context active when this driver operation was initiated, + or ``None`` if no context has been set. + """ + + +class Driver(ABC): + """Base driver for storing and retrieving payloads from external storage systems. + + .. warning:: + This API is experimental. + """ + + @abstractmethod + def name(self) -> str: + """Returns the name of this driver instance. A driver may allow its name + to be parameterized at construction time so that multiple instances of + the same driver class can coexist in :attr:`StorageOptions.drivers` with + distinct names. + """ + raise NotImplementedError + + def type(self) -> str: + """Returns the type of the storage driver. This string should be the same + across all instantiations of the same driver class. This allows the equivalent + driver implementation in different languages to be named the same. + + Defaults to the class name. Subclasses may override this to return a + stable, language-agnostic identifier. + """ + return type(self).__name__ + + @abstractmethod + async def store( + self, + context: DriverContext, + payloads: Sequence[Payload], + ) -> list[DriverClaim]: + """Stores payloads in external storage and returns a :class:`DriverClaim` + for each one. The returned list must be the same length as ``payloads``. + """ + raise NotImplementedError + + @abstractmethod + async def retrieve( + self, + context: DriverContext, + claims: Sequence[DriverClaim], + ) -> list[Payload]: + """Retrieves payloads from external storage for the given :class:`DriverClaim` + list. The returned list must be the same length as ``claims``. 
+ + Raise :class:`PayloadNotFoundError` when a retrieval attempt confirms + that a payload is absent from storage. This signals an unrecoverable + condition that will fail the workflow rather than retrying the workflow + task. + """ + raise NotImplementedError + + +class DriverSelector(ABC): + """Determines which :class:`Driver` stores a given payload. + + Implement this class and set it as :attr:`StorageOptions.driver_selector` when you + need stateful or class-based selection logic. For simple cases a plain + callable ``(DriverContext, Payload) -> Driver | None`` can be used instead. + + .. warning:: + This API is experimental. + """ + + @abstractmethod + def select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: + """Returns the driver to use to externally store the payload, or None to decline to + externally store the payload. + """ + pass + + +@dataclass(frozen=True) +class StorageConverter(WithSerializationContext): + """Converters for converting and encoding external payloads to/from Python values. + + .. warning:: + This API is experimental. + """ + + payload_codec: PayloadCodec | None + """Optional codec applied to payloads before they are handed to a + :class:`Driver` for storage, and after they are retrieved. When ``None``, + payloads are stored as-is by the driver. + """ + + def with_context(self, context: SerializationContext) -> Self: + """Return a copy of this converter with the serialization context applied. + + If :attr:`payload_codec` implements :class:`WithSerializationContext`, + a new instance is created with the context propagated to it. If nothing + changed, ``self`` is returned unchanged. 
+ """ + payload_codec = self.payload_codec + if isinstance(payload_codec, WithSerializationContext): + payload_codec = payload_codec.with_context(context) + if payload_codec == self.payload_codec: + return self + cloned = dataclasses.replace(self) + object.__setattr__(cloned, "payload_codec", payload_codec) + return cloned + + +@dataclass(frozen=True) +class StorageOptions(WithSerializationContext): + """Configuration for external storage behavior. + + .. warning:: + This API is experimental. + """ + + drivers: Sequence[Driver] + """Drivers available for storing and retrieving payloads. At least one + driver must be provided. + + When no :attr:`driver_selector` is set, the first driver in this list is + used for all store operations. Additional drivers may be included solely to + support retrieval — for example, to download payloads that remote callers + uploaded to an external storage system that is not your primary store + driver. Drivers in this list are looked up by :meth:`Driver.name` during + retrieval, so each driver must have a unique name. + """ + + driver_selector: ( + DriverSelector | Callable[[DriverContext, Payload], Driver | None] | None + ) = None + """Controls which driver stores a given payload. Accepts either a + :class:`DriverSelector` instance or a callable of the form + ``(DriverContext, Payload) -> Driver | None``. + + When ``None``, the first driver in :attr:`drivers` is used for all store + operations. Returning ``None`` from the selector leaves the payload stored + inline rather than offloading it to external storage. + """ + + payload_size_threshold: int | None = 256 * 1024 + """Minimum payload size in bytes before external storage is considered. + Defaults to 256 KiB. Set to ``None`` to consider every payload for + external storage regardless of size. + """ + + external_converter: StorageConverter | None = None + """Converter applied to payload bytes before they are passed to a driver + for storage, and after they are retrieved. 
When ``None``, payload bytes are + handed to the driver without any additional encoding. Note that the + ``DataConverter``'s ``payload_codec`` is applied to the reference payload + that replaces the original in workflow history, not to the externally stored + bytes themselves. + """ + + def with_context(self, context: SerializationContext) -> Self: + """Return a copy of these options with the serialization context applied. + + Propagates *context* to any drivers, the driver selector, and the + external converter that implement :class:`WithSerializationContext`. + If none of those fields changed, ``self`` is returned unchanged. + """ + drivers = list(self.drivers) + for index, driver in enumerate(drivers): + if isinstance(driver, WithSerializationContext): + drivers[index] = driver.with_context(context) + driver_selector = self.driver_selector + if isinstance(driver_selector, WithSerializationContext): + driver_selector = driver_selector.with_context(context) + external_converter = self.external_converter + if isinstance(external_converter, WithSerializationContext): + external_converter = external_converter.with_context(context) + if all( + new is orig + for new, orig in [ + (drivers, self.drivers), + (driver_selector, self.driver_selector), + (external_converter, self.external_converter), + ] + ): + return self + cloned = dataclasses.replace(self) + object.__setattr__(cloned, "drivers", drivers) + object.__setattr__(cloned, "driver_selector", driver_selector) + object.__setattr__(cloned, "external_converter", external_converter) + return cloned + + +class DriverError(TemporalError): + """Raised when an error occurs related to a specific driver. + + .. warning:: + This API is experimental. 
+ """ + + def __init__(self, message: str, driver_name: str) -> None: + """Initialize with an error message and the name of the driver that failed.""" + super().__init__(message) + self._driver_name = driver_name + + @property + def driver_name(self) -> str: + """Name of the driver that caused this error.""" + return self._driver_name + + +class DriverNotFoundError(DriverError): + """Raised when a driver name cannot be resolved to a driver in + :attr:`StorageOptions.drivers`. This can occur during retrieval when a + :class:`DriverClaim` references a driver name that is not present, or + during storage when the :attr:`StorageOptions.driver_selector` returns a + :class:`Driver` whose :meth:`Driver.name` is not registered. + + .. warning:: + This API is experimental. + """ + + def __init__(self, driver_name: str) -> None: + """Initialize with the name of the driver that could not be resolved.""" + super().__init__( + f"No driver found with name '{driver_name}'", driver_name=driver_name + ) + + +class PayloadNotFoundError(TemporalError): + """Raised when a payload cannot be retrieved because it does not exist + at the location indicated by its :class:`DriverClaim`. + + When raised during workflow execution this error fails the **workflow** + rather than the workflow task. Drivers should raise this when a retrieval + attempt confirms the payload is absent. + + This error is intentionally not a subclass of :class:`DriverError` to + avoid accidentally handling it and treating it as a workflow task failure. + + .. warning:: + This API is experimental. 
+ """ + + def __init__( + self, + message: str | None = None, + *, + driver_claim: DriverClaim, + driver_name: str, + ) -> None: + """Initialize a payload not found error.""" + super().__init__(message or f"Payload not found for driver '{driver_name}'") + self._driver_claim = driver_claim + self._driver_name = driver_name + + @property + def driver_claim(self) -> DriverClaim: + """The :class:`DriverClaim` for the payload that could not be found.""" + return self._driver_claim + + @property + def driver_name(self) -> str: + """Name of the driver that reported the payload as not found.""" + return self._driver_name + + +class StorageWarning(RuntimeWarning): + """Warning for external storage issues. + + .. warning:: + This API is experimental. + """ + + +@dataclass(frozen=True) +class _StorageReference: + driver_name: str + driver_claim: DriverClaim + + +class _ExternalStorageMiddleware: # type:ignore[reportUnusedClass] + # Claim payload encoding is fixed and independent of any user configuration. + _claim_converter: JSONPlainPayloadConverter = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + + def __init__( + self, + options: StorageOptions | None, + context: SerializationContext | None = None, + payload_codec: PayloadCodec | None = None, + ): + self._options = options + self._context = context + self._payload_codec = ( + options.external_converter.payload_codec + if options and options.external_converter + else payload_codec + ) + self._driver_map: dict[str, Driver] = {} + if options is not None: + for driver in options.drivers: + name = driver.name() + if name in self._driver_map: + warnings.warn( + f"StorageOptions.drivers contains multiple drivers with name '{name}'. 
" + "The first one will be used.", + category=StorageWarning, + ) + else: + self._driver_map[name] = driver + + def _select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: + """Returns the driver to use for this payload, or None to pass through.""" + assert self._options is not None + selector = self._options.driver_selector + if selector is None: + return self._options.drivers[0] if self._options.drivers else None + elif isinstance(selector, DriverSelector): + driver = selector.select_driver(context, payload) + else: + driver = selector(context, payload) + if driver is None: + return None + registered = self._driver_map.get(driver.name()) + if registered is None: + raise DriverNotFoundError(driver.name()) + return registered + + def _get_driver_by_name(self, name: str) -> Driver: + """Looks up a driver by name, raising :class:`DriverNotFoundError` if not found.""" + driver = self._driver_map.get(name) + if driver is None: + raise DriverNotFoundError(name) + return driver + + async def store_payload(self, payload: Payload) -> Payload: + if self._options is None: + return payload + + size_bytes = payload.ByteSize() + if ( + self._options.payload_size_threshold is not None + and size_bytes < self._options.payload_size_threshold + ): + return payload + + context = DriverContext(serialization_context=self._context) + + driver = self._select_driver(context, payload) + if driver is None: + return payload + + # Optionally encode the payload before externally storing it + encoded_payload = payload + if self._payload_codec: + encoded_payload = (await self._payload_codec.encode([payload]))[0] + + try: + claims = await driver.store(context, [encoded_payload]) + except Exception as err: + raise DriverError("Driver store failed", driver.name()) from err + + self._validate_claim_length(claims, expected=1, driver=driver) + + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claims[0], + ) + reference_payload = 
self._claim_converter.to_payload(reference) + assert reference_payload is not None + reference_payload.external_payloads.add().size_bytes = ( + encoded_payload.ByteSize() + ) + return reference_payload + + async def store_payloads( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + if self._options is None: + return list(payloads) + + if len(payloads) == 1: + return [await self.store_payload(payloads[0])] + + results = list(payloads) + context = DriverContext(serialization_context=self._context) + + # First pass: determine which payloads to store and which driver to use for each. + # Provide unencoded payloads to give maximal context information to the selector. + to_store: list[tuple[int, Payload, Driver]] = [] + for index, payload in enumerate(payloads): + size_bytes = payload.ByteSize() + if ( + self._options.payload_size_threshold is not None + and size_bytes < self._options.payload_size_threshold + ): + continue + driver = self._select_driver(context, payload) + if driver is None: + continue + to_store.append((index, payload, driver)) + + if not to_store: + return results + + # Optionally encode all payloads destined for external storage + payloads_to_encode = [payload for _, payload, _ in to_store] + encoded_payloads = payloads_to_encode + if self._payload_codec: + encoded_payloads = await self._payload_codec.encode(payloads_to_encode) + + # Group encoded payloads by driver for batched store calls + # driver -> [(original_index, encoded_payload)] + driver_groups: dict[Driver, list[tuple[int, Payload]]] = {} + for i, (orig_index, _, driver) in enumerate(to_store): + driver_groups.setdefault(driver, []).append( + (orig_index, encoded_payloads[i]) + ) + + # Store all driver groups concurrently then build reference payloads + driver_group_list = list(driver_groups.items()) + + async def _store_group( + driver: Driver, indexed_payloads: list[tuple[int, Payload]] + ) -> list[DriverClaim]: + store_batch = [p for _, p in indexed_payloads] + try: + return 
await driver.store(context, store_batch) + except Exception as err: + raise DriverError("Driver store failed", driver.name()) from err + + all_claims = await asyncio.gather( + *( + _store_group(driver, indexed_payloads) + for driver, indexed_payloads in driver_group_list + ) + ) + + for (driver, indexed_payloads), claims in zip(driver_group_list, all_claims): + indices = [idx for idx, _ in indexed_payloads] + sizes = [p.ByteSize() for _, p in indexed_payloads] + + self._validate_claim_length(claims, expected=len(indices), driver=driver) + + for i, claim in enumerate(claims): + reference = _StorageReference( + driver_name=driver.name(), + driver_claim=claim, + ) + reference_payload = self._claim_converter.to_payload(reference) + assert reference_payload is not None + reference_payload.external_payloads.add().size_bytes = sizes[i] + results[indices[i]] = reference_payload + + return results + + async def retrieve_payload( + self, + payload: Payload, + ) -> Payload: + if self._options is None or len(self._options.drivers) == 0: + # External storage was not configured (correctly). Warn if there are any external payloads + # since that is likely to cause downstream error when decoding or deserializing. 
+ if len(payload.external_payloads) > 0: + if not self._options: + warnings.warn( + "External storage is not configured, but detected external storage references.", + category=StorageWarning, + ) + elif len(self._options.drivers) == 0: + warnings.warn( + "StorageOptions.drivers is empty, but detected external storage references.", + category=StorageWarning, + ) + return payload + + if len(payload.external_payloads) == 0: + return payload + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + return payload + + driver = self._get_driver_by_name(reference.driver_name) + context = DriverContext(serialization_context=self._context) + + try: + stored_payloads = await driver.retrieve(context, [reference.driver_claim]) + except PayloadNotFoundError: + raise + except Exception as err: + raise DriverError("Driver retrieve failed", driver.name()) from err + + self._validate_payload_length(stored_payloads, expected=1, driver=driver) + + if self._payload_codec: + stored_payloads = await self._payload_codec.decode(stored_payloads) + + return stored_payloads[0] + + async def retrieve_payloads( + self, + payloads: Sequence[Payload], + ) -> list[Payload]: + results = list(payloads) + + if self._options is None or len(self._options.drivers) == 0: + # External storage was not configured, but warn if there are any external payloads + # since that is likely to cause downstream error when decoding or deserializing. 
+ if any(len(p.external_payloads) > 0 for p in payloads): + if not self._options: + warnings.warn( + "External storage is not configured, but detected external storage references.", + category=StorageWarning, + ) + elif len(self._options.drivers) == 0: + warnings.warn( + "StorageOptions.drivers is empty, but detected external storage references.", + category=StorageWarning, + ) + return results + + if len(payloads) == 1: + return [await self.retrieve_payload(payloads[0])] + + # Group claims by driver for batched retrieve calls + # driver -> [(original_index, claim)] + driver_claims: dict[Driver, list[tuple[int, DriverClaim]]] = {} + for index, payload in enumerate(payloads): + if len(payload.external_payloads) == 0: + continue + + reference = self._claim_converter.from_payload(payload, _StorageReference) + if not isinstance(reference, _StorageReference): + continue + + driver = self._get_driver_by_name(reference.driver_name) + driver_claims.setdefault(driver, []).append((index, reference.driver_claim)) + + if not driver_claims: + return results + + context = DriverContext(serialization_context=self._context) + stored_by_index: dict[int, Payload] = {} + + # Retrieve from all drivers concurrently + driver_claim_list = list(driver_claims.items()) + + async def _retrieve_group( + driver: Driver, indexed_claims: list[tuple[int, DriverClaim]] + ) -> list[Payload]: + claims_to_retrieve = [claim for _, claim in indexed_claims] + try: + return await driver.retrieve(context, claims_to_retrieve) + except PayloadNotFoundError: + raise + except Exception as err: + raise DriverError("Driver retrieve failed", driver.name()) from err + + all_stored = await asyncio.gather( + *( + _retrieve_group(driver, indexed_claims) + for driver, indexed_claims in driver_claim_list + ) + ) + + for (driver, indexed_claims), stored_payloads in zip( + driver_claim_list, all_stored + ): + indices = [idx for idx, _ in indexed_claims] + + self._validate_payload_length( + stored_payloads, + 
expected=len(indexed_claims), + driver=driver, + ) + + for idx, stored_payload in zip(indices, stored_payloads): + stored_by_index[idx] = stored_payload + + # Decode all retrieved payloads together if a codec is configured + retrieve_indices = sorted(stored_by_index.keys()) + stored_list = [stored_by_index[idx] for idx in retrieve_indices] + + decoded_payloads = stored_list + if self._payload_codec: + decoded_payloads = await self._payload_codec.decode(stored_list) + + for i, retrieved_payload in enumerate(decoded_payloads): + results[retrieve_indices[i]] = retrieved_payload + + return results + + def _validate_claim_length( + self, claims: Sequence[DriverClaim], expected: int, driver: Driver + ) -> None: + if len(claims) != expected: + raise DriverError( + f"Driver returned {len(claims)} claims, expected {expected}", + driver.name(), + ) + + def _validate_payload_length( + self, payloads: Sequence[Payload], expected: int, driver: Driver + ) -> None: + if len(payloads) != expected: + raise DriverError( + f"Driver returned {len(payloads)} payloads, expected {expected}", + driver.name(), + ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 4e6e06282..a895f54d2 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -625,7 +625,7 @@ async def _execute_activity( else None, ) - if self._encode_headers and data_converter._decode_payload_has_effect: + if self._encode_headers: for payload in start.header_fields.values(): payload.CopyFrom(await data_converter._decode_payload(payload)) diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index 18f5599ba..efafcb8e6 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -25,6 +25,7 @@ import temporalio.workflow from temporalio.api.enums.v1 import WorkflowTaskFailedCause from temporalio.bridge.worker import PollShutdownError +from temporalio.extstore import PayloadNotFoundError from . 
import _command_aware_visitor from ._interceptor import ( @@ -340,21 +341,41 @@ async def _handle_activation( "Failed handling activation on workflow with run ID %s", act.run_id ) - completion.failed.failure.SetInParent() - try: - data_converter.failure_converter.to_failure( - err, - data_converter.payload_converter, - completion.failed.failure, - ) - except Exception as inner_err: - logger.exception( - "Failed converting activation exception on workflow with run ID %s", - act.run_id, - ) - completion.failed.failure.message = ( - f"Failed converting activation exception: {inner_err}" - ) + if isinstance(err, PayloadNotFoundError): + # Fail the workflow execution terminally rather than failing the task + command = completion.successful.commands.add() + failure = command.fail_workflow_execution.failure + failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) + else: + completion.failed.failure.SetInParent() + try: + data_converter.failure_converter.to_failure( + err, + data_converter.payload_converter, + completion.failed.failure, + ) + except Exception as inner_err: + logger.exception( + "Failed converting activation exception on workflow with run ID %s", + act.run_id, + ) + completion.failed.failure.message = ( + f"Failed converting activation exception: {inner_err}" + ) completion.run_id = act.run_id diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 9c22c05ce..81ebca38c 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,6 +57,7 @@ import temporalio.common import temporalio.converter import temporalio.exceptions +import temporalio.extstore import 
temporalio.workflow from temporalio.service import __version__ @@ -1781,15 +1782,17 @@ def workflow_set_current_details(self, details: str): self._current_details = details def workflow_is_failure_exception(self, err: BaseException) -> bool: - # An exception is a failure instead of a task fail if it's already a - # failure error or if it is a timeout error or if it is an instance of - # any of the failure types in the worker or workflow-level setting + # An exception causes the workflow to fail (rather than the task) if it + # is already a failure error, a timeout error, a PayloadNotFoundError + # (unrecoverable missing external payload), or an instance of any of the + # failure exception types configured at the worker or workflow level. wf_failure_exception_types = self._defn.failure_exception_types if self._dynamic_failure_exception_types is not None: wf_failure_exception_types = self._dynamic_failure_exception_types return ( isinstance(err, temporalio.exceptions.FailureError) or isinstance(err, asyncio.TimeoutError) + or isinstance(err, temporalio.extstore.PayloadNotFoundError) or any(isinstance(err, typ) for typ in wf_failure_exception_types) or any( isinstance(err, typ) diff --git a/tests/test_extstore.py b/tests/test_extstore.py new file mode 100644 index 000000000..b80c2aa45 --- /dev/null +++ b/tests/test_extstore.py @@ -0,0 +1,803 @@ +"""Tests for external storage functionality.""" + +import warnings +from collections.abc import Sequence + +import pytest +from typing_extensions import Self + +from temporalio.api.common.v1 import Payload +from temporalio.converter import ( + ActivitySerializationContext, + DataConverter, + JSONPlainPayloadConverter, + PayloadCodec, + SerializationContext, + WithSerializationContext, + WorkflowSerializationContext, +) +from temporalio.exceptions import ApplicationError, FailureError, TemporalError +from temporalio.extstore import ( + Driver, + DriverClaim, + DriverContext, + DriverError, + DriverNotFoundError, + 
DriverSelector, + PayloadNotFoundError, + StorageConverter, + StorageOptions, + StorageWarning, + _StorageReference, +) + + +class InMemoryTestDriver(Driver): + """In-memory storage driver for testing.""" + + def __init__( + self, + driver_name: str = "test-driver", + ): + self._driver_name = driver_name + self._storage: dict[str, bytes] = {} + self._store_calls = 0 + self._retrieve_calls = 0 + + def name(self) -> str: + return self._driver_name + + async def store( + self, + context: DriverContext, + payloads: Sequence[Payload], + ) -> list[DriverClaim]: + self._store_calls += 1 + start_index = len(self._storage) + + entries = [ + (f"payload-{start_index + i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + + return [DriverClaim(data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: DriverContext, + claims: Sequence[DriverClaim], + ) -> list[Payload]: + self._retrieve_calls += 1 + + def parse_claim( + claim: DriverClaim, + ) -> Payload: + key = claim.data["key"] + if key not in self._storage: + raise PayloadNotFoundError(driver_claim=claim, driver_name=self.name()) + payload = Payload() + payload.ParseFromString(self._storage[key]) + return payload + + return [parse_claim(claim) for claim in claims] + + +class WorkflowIdFeatureFlagDriverSelector(DriverSelector, WithSerializationContext): + """Example selector that conditionally stores based on workflow ID feature flag.""" + + def __init__(self, driver: Driver, enabled: bool = False): + self._driver = driver + self._enabled = enabled + + def select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: + return self._driver if self._enabled else None + + def with_context(self, context: SerializationContext) -> Self: + workflow_id = None + if isinstance(context, ActivitySerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + if isinstance(context, WorkflowSerializationContext) and 
context.workflow_id: + workflow_id = context.workflow_id + + # Create new instance with updated enabled flag and propagate context to inner driver + driver = self._driver + if isinstance(driver, WithSerializationContext): + driver = driver.with_context(context) + + return type(self)( + driver, WorkflowIdFeatureFlagDriverSelector.feature_flag_is_on(workflow_id) + ) + + @staticmethod + def feature_flag_is_on(workflow_id: str | None) -> bool: + """Mock implementation of a feature flag based on a workflow ID.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 + + +class TestDataConverterExternalStorage: + """Tests for DataConverter with external storage.""" + + async def test_extstore_encode_decode(self): + """Test that large payloads are stored externally.""" + driver = InMemoryTestDriver() + + # Configure with 100-byte threshold + converter = DataConverter( + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=100, + ) + ) + + # Small value should not be externalized + small_value = "small" + encoded_small = await converter.encode([small_value]) + assert len(encoded_small) == 1 + assert not encoded_small[0].external_payloads # Not externalized + assert driver._store_calls == 0 + + # Large value should be externalized + large_value = "x" * 200 + encoded_large = await converter.encode([large_value]) + assert len(encoded_large) == 1 + assert len(encoded_large[0].external_payloads) > 0 # Externalized + assert driver._store_calls == 1 + + # Decode large value + decoded = await converter.decode(encoded_large, [str]) + assert len(decoded) == 1 + assert decoded[0] == large_value + assert driver._retrieve_calls == 1 + + async def test_extstore_reference_structure(self): + """Test that external storage creates proper reference structure.""" + converter = DataConverter( + external_storage=StorageOptions( + drivers=[InMemoryTestDriver("test-driver")], + payload_size_threshold=50, + ) + ) + + # Create large payload + large_value = "x" * 
100 + encoded = await converter.encode([large_value]) + + # Verify reference structure + reference_payload = encoded[0] + assert len(reference_payload.external_payloads) > 0 + + # The payload should contain a serialized _ExternalStorageReference + # Deserialize it to verify structure using the same encoding + claim_converter = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ) + reference = claim_converter.from_payload(reference_payload, _StorageReference) + + assert isinstance(reference, _StorageReference) + assert "test-driver" == reference.driver_name + assert isinstance(reference.driver_claim, DriverClaim) + assert "key" in reference.driver_claim.data + + async def test_extstore_composite_conditional(self): + """Test using multiple drivers based on size.""" + hot_driver = InMemoryTestDriver("hot-storage") + cold_driver = InMemoryTestDriver("cold-storage") + + options = StorageOptions( + drivers=[hot_driver, cold_driver], + driver_selector=lambda context, payload: hot_driver + if payload.ByteSize() < 500 + else cold_driver, + payload_size_threshold=100, + ) + converter = DataConverter(external_storage=options) + + # Small payload (not externalized) + small = "x" * 50 + encoded_small = await converter.encode([small]) + assert not encoded_small[0].external_payloads + assert hot_driver._store_calls == 0 + assert cold_driver._store_calls == 0 + + # Medium payload (hot storage) + medium = "x" * 200 + encoded_medium = await converter.encode([medium]) + assert len(encoded_medium[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 + assert cold_driver._store_calls == 0 + + # Large payload (cold storage) + large = "x" * 2000 + encoded_large = await converter.encode([large]) + assert len(encoded_large[0].external_payloads) > 0 + assert hot_driver._store_calls == 1 # Unchanged + assert cold_driver._store_calls == 1 + + # Verify retrieval from correct drivers + decoded_medium = await converter.decode(encoded_medium, [str]) + assert 
decoded_medium[0] == medium + assert hot_driver._retrieve_calls == 1 + + decoded_large = await converter.decode(encoded_large, [str]) + assert decoded_large[0] == large + assert cold_driver._retrieve_calls == 1 + + async def test_extstore_serialization_context(self): + driver = InMemoryTestDriver() + + # The payload should not be stored externally when it doesn't have a serialization context + # or if the workflow ID doesn't end with "-extstore". This is an example of feature flagging + # external storage using the workflow ID. This is an advanced secnario and requires the "can_store" + # filter to be a WithSerializationContext. + options = StorageOptions( + drivers=[driver], + driver_selector=WorkflowIdFeatureFlagDriverSelector(driver), + payload_size_threshold=1024, + ) + converter = DataConverter(external_storage=options) + + large_value = "c" * (1024 + 100) + + encoded_payloads = await converter.encode([large_value]) + assert len(encoded_payloads) == 1 + assert driver._store_calls == 0 + assert driver._retrieve_calls == 0 + + decoded_values = await converter.decode(encoded_payloads, [str]) + assert len(decoded_values) == 1 + assert driver._store_calls == 0 + assert driver._store_calls == 0 + + namespace = "my-ns" + + # Now has serialization context, but workflow ID does not end with "-extstore" + converter = converter.with_context( + WorkflowSerializationContext( + namespace=namespace, + workflow_id="odd-length-workflow-id1", + ) + ) + + encoded_payloads = await converter.encode([large_value]) + assert len(encoded_payloads) == 1 + assert driver._store_calls == 0 + assert driver._retrieve_calls == 0 + + decoded_values = await converter.decode(encoded_payloads, [str]) + assert len(decoded_values) == 1 + assert driver._store_calls == 0 + assert driver._store_calls == 0 + + # Now has serialization context with workflow ID ending with "-extstore" + converter = converter.with_context( + WorkflowSerializationContext( + namespace=namespace, + 
workflow_id="even-length-workflow-id1", + ) + ) + + encoded_payloads = await converter.encode([large_value]) + assert len(encoded_payloads) == 1 + assert driver._store_calls == 1 + assert driver._retrieve_calls == 0 + + decoded_values = await converter.decode(encoded_payloads, [str]) + assert len(decoded_values) == 1 + assert driver._store_calls == 1 + assert driver._store_calls == 1 + + +class NotFoundDriver(Driver): + """Driver that stores normally but raises PayloadNotFoundError on retrieve.""" + + def __init__(self, driver_name: str = "not-found-driver"): + self._driver_name = driver_name + self._storage: dict[str, bytes] = {} + + def name(self) -> str: + return self._driver_name + + async def store( + self, + context: DriverContext, + payloads: Sequence[Payload], + ) -> list[DriverClaim]: + entries = [ + (f"payload-{i}", payload.SerializeToString()) + for i, payload in enumerate(payloads) + ] + self._storage.update(entries) + return [DriverClaim(data={"key": key}) for key, _ in entries] + + async def retrieve( + self, + context: DriverContext, + claims: Sequence[DriverClaim], + ) -> list[Payload]: + assert len(claims) > 0, "NotFoundDriver expected claims to be provided" + raise PayloadNotFoundError( + "Payload not found in not-found-driver", + driver_claim=claims[0], + driver_name=self.name(), + ) + + +class TestPayloadNotFoundError: + """Tests for PayloadNotFoundError class and middleware behaviour.""" + + def test_class_hierarchy(self): + """PayloadNotFoundError must be TemporalError but not ApplicationError or FailureError.""" + assert issubclass(PayloadNotFoundError, TemporalError) + assert not issubclass(PayloadNotFoundError, ApplicationError) + assert not issubclass(PayloadNotFoundError, FailureError) + assert not issubclass(PayloadNotFoundError, DriverError) + + def test_default_message(self): + claim = DriverClaim(data={"key": "my-key"}) + err = PayloadNotFoundError(driver_claim=claim, driver_name="my-driver") + assert str(err) == "Payload not found 
for driver 'my-driver'" + + def test_properties(self): + claim = DriverClaim(data={"key": "my-key"}) + err = PayloadNotFoundError("gone", driver_claim=claim, driver_name="my-driver") + assert err.driver_claim is claim + assert err.driver_name == "my-driver" + + async def test_middleware_propagates_not_found(self): + """PayloadNotFoundError from a driver must not be wrapped in DriverError.""" + converter = DataConverter( + external_storage=StorageOptions( + drivers=[NotFoundDriver()], + payload_size_threshold=1, # store everything + ) + ) + + # Store a payload so we have a reference to retrieve + encoded = await converter.encode(["hello world " * 20]) + assert len(encoded[0].external_payloads) > 0 + + # Retrieving should raise PayloadNotFoundError, not DriverError + with pytest.raises(PayloadNotFoundError): + await converter.decode(encoded, [str]) + + +class TestDriverError: + """Tests for DriverError raised when a driver violates its contract.""" + + async def test_encode_wrong_claim_count_raises_driver_error(self): + """store() returning fewer claims than payloads must raise DriverError.""" + + class _NoClaimsDriver(InMemoryTestDriver): + async def store( + self, context: DriverContext, payloads: Sequence[Payload] + ) -> list[DriverClaim]: + return [] + + converter = DataConverter( + external_storage=StorageOptions( + drivers=[_NoClaimsDriver()], + payload_size_threshold=10, + ) + ) + with pytest.raises(DriverError, match="Driver returned 0 claims, expected 1"): + await converter.encode(["x" * 200]) + + async def test_decode_wrong_payload_count_raises_driver_error(self): + """retrieve() returning fewer payloads than claims must raise DriverError.""" + good_converter = DataConverter( + external_storage=StorageOptions( + drivers=[InMemoryTestDriver()], + payload_size_threshold=10, + ) + ) + encoded = await good_converter.encode(["x" * 200]) + + class _NoPayloadsDriver(InMemoryTestDriver): + async def retrieve( + self, context: DriverContext, claims: 
Sequence[DriverClaim] + ) -> list[Payload]: + return [] + + bad_converter = DataConverter( + external_storage=StorageOptions( + drivers=[ + _NoPayloadsDriver() + ], # same default name as InMemoryTestDriver + payload_size_threshold=10, + ) + ) + with pytest.raises(DriverError, match="Driver returned 0 payloads, expected 1"): + await bad_converter.decode(encoded, [str]) + + async def test_encode_driver_exception_wrapped_in_driver_error(self): + """Exception raised by store() must be wrapped in DriverError.""" + + class _StoreError(Exception): + pass + + class _RaisingStoreDriver(InMemoryTestDriver): + async def store( + self, context: DriverContext, payloads: Sequence[Payload] + ) -> list[DriverClaim]: + raise _StoreError("store failed") + + converter = DataConverter( + external_storage=StorageOptions( + drivers=[_RaisingStoreDriver()], + payload_size_threshold=10, + ) + ) + with pytest.raises(DriverError) as exc_info: + await converter.encode(["x" * 200]) + assert isinstance(exc_info.value.__cause__, _StoreError) + + async def test_decode_driver_exception_wrapped_in_driver_error(self): + """Exception raised by retrieve() must be wrapped in DriverError.""" + + class _RetrieveError(Exception): + pass + + class _RaisingRetrieveDriver(InMemoryTestDriver): + async def retrieve( + self, context: DriverContext, claims: Sequence[DriverClaim] + ) -> list[Payload]: + raise _RetrieveError("retrieve failed") + + good_converter = DataConverter( + external_storage=StorageOptions( + drivers=[InMemoryTestDriver()], + payload_size_threshold=10, + ) + ) + encoded = await good_converter.encode(["x" * 200]) + + bad_converter = DataConverter( + external_storage=StorageOptions( + drivers=[ + _RaisingRetrieveDriver() + ], # same default name as InMemoryTestDriver + payload_size_threshold=10, + ) + ) + with pytest.raises(DriverError) as exc_info: + await bad_converter.decode(encoded, [str]) + assert isinstance(exc_info.value.__cause__, _RetrieveError) + + +class 
RecordingPayloadCodec(PayloadCodec): + """Codec that wraps each payload under a recognisable ``encoding`` label. + + Encode sets ``metadata["encoding"]`` to ``encoding_label`` and stores the + serialised inner payload as ``data``. Decode reverses that. The call + counters let tests assert exactly how many payloads each codec processed. + """ + + def __init__(self, encoding_label: str) -> None: + self._encoding_label = encoding_label.encode() + self.encoded_count = 0 + self.decoded_count = 0 + + async def encode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.encoded_count += len(payloads) + results = [] + for p in payloads: + wrapped = Payload() + wrapped.metadata["encoding"] = self._encoding_label + wrapped.data = p.SerializeToString() + results.append(wrapped) + return results + + async def decode(self, payloads: Sequence[Payload]) -> list[Payload]: + self.decoded_count += len(payloads) + results = [] + for p in payloads: + inner = Payload() + inner.ParseFromString(p.data) + results.append(inner) + return results + + +class TestPayloadCodecWithExternalStorage: + """Tests for interaction between DataConverter.payload_codec and external storage.""" + + async def test_dc_payload_codec_encodes_reference_payload(self): + """DataConverter.payload_codec encodes the reference payload in workflow + history but does NOT encode the bytes handed to the driver for storage.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=50, + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # The reference payload written to history must carry the dc_codec label. 
+ assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # The bytes given to the driver must NOT carry the dc_codec label. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") != b"binary/dc-encoded" + assert stored_payload.metadata.get("encoding") == b"json/plain" + + # Round-trip must recover the original value. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_external_converter_without_codec_does_not_encode_stored_bytes(self): + """When DataConverter.payload_codec is set but StorageOptions.external_converter + has no payload_codec, stored bytes are NOT encoded – even though + DataConverter.payload_codec is active for the reference payload in history.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=50, + # Explicitly set external_converter without its own codec. + # DataConverter.payload_codec must NOT bleed through to stored bytes. + external_converter=StorageConverter(payload_codec=None), + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Reference payload in history is still encoded by DataConverter.payload_codec. + assert dc_codec.encoded_count == 1 + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # Stored bytes are NOT encoded. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") != b"binary/dc-encoded" + assert stored_payload.metadata.get("encoding") == b"json/plain" + + # Round-trip. 
+ decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + async def test_external_converter_codec_independent_from_dc_codec(self): + """When both DataConverter.payload_codec and + StorageOptions.external_converter.payload_codec are set, the reference + payload in history uses DataConverter.payload_codec and the bytes stored + by the driver use external_converter.payload_codec – independently.""" + driver = InMemoryTestDriver() + dc_codec = RecordingPayloadCodec("binary/dc-encoded") + ext_codec = RecordingPayloadCodec("binary/ext-encoded") + + converter = DataConverter( + payload_codec=dc_codec, + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=50, + external_converter=StorageConverter(payload_codec=ext_codec), + ), + ) + + large_value = "x" * 200 + encoded = await converter.encode([large_value]) + assert len(encoded) == 1 + assert driver._store_calls == 1 + + # Each codec was applied exactly once during encode. + assert dc_codec.encoded_count == 1 + assert ext_codec.encoded_count == 1 + + # Reference payload carries dc_codec's label. + assert encoded[0].metadata.get("encoding") == b"binary/dc-encoded" + + # Stored bytes carry ext_codec's label – different from the reference. + stored_payload = Payload() + stored_payload.ParseFromString(next(iter(driver._storage.values()))) + assert stored_payload.metadata.get("encoding") == b"binary/ext-encoded" + assert stored_payload.metadata.get("encoding") != encoded[0].metadata.get( + "encoding" + ) + + # Round-trip must recover the original value using both codecs. 
+ decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large_value + assert dc_codec.decoded_count == 1 + assert ext_codec.decoded_count == 1 + assert driver._retrieve_calls == 1 + + +class TestMultiDriver: + """Tests for StorageOptions with multiple drivers.""" + + async def test_no_selector_uses_first_driver_for_store(self): + """Without a driver_selector the first driver in the list handles all + store operations. Additional drivers are never called for store.""" + first = InMemoryTestDriver("driver-first") + second = InMemoryTestDriver("driver-second") + + converter = DataConverter( + external_storage=StorageOptions( + drivers=[first, second], + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert first._store_calls == 1 + assert second._store_calls == 0 + + # The reference in history names the first driver. + ref = JSONPlainPayloadConverter( + encoding="json/external-storage-reference" + ).from_payload(encoded[0], _StorageReference) + assert ref.driver_name == "driver-first" + + # Retrieval also goes to the first driver. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert first._retrieve_calls == 1 + assert second._retrieve_calls == 0 + + async def test_no_selector_second_driver_is_retrieve_only(self): + """A driver that is second in the list acts as a retrieve-only driver. + References are resolved by name, not by position, so a payload stored + by driver-b is retrieved correctly even when driver-a is listed first.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Store with driver-b as the sole driver. + store_converter = DataConverter( + external_storage=StorageOptions( + drivers=[driver_b], + payload_size_threshold=50, + ) + ) + large = "y" * 200 + encoded = await store_converter.encode([large]) + + # Retrieve with driver-a listed first, driver-b second. 
+ # The "driver-b" name in the reference must route to driver-b. + retrieve_converter = DataConverter( + external_storage=StorageOptions( + drivers=[driver_a, driver_b], + payload_size_threshold=50, + ) + ) + decoded = await retrieve_converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver_a._retrieve_calls == 0 # never consulted + assert driver_b._retrieve_calls == 1 + + async def test_selector_routes_payloads_to_different_drivers_in_single_batch(self): + """When a selector routes different payloads to different drivers, a + single encode([v1, v2, ...]) call batches payloads per driver so each + driver receives exactly one store() call regardless of how many + payloads are routed to it.""" + driver_a = InMemoryTestDriver("driver-a") + driver_b = InMemoryTestDriver("driver-b") + + # Route payloads that serialise to < 500 bytes to driver_a, larger ones + # to driver_b. + def selector(_ctx: object, payload: Payload) -> InMemoryTestDriver: + return driver_a if payload.ByteSize() < 500 else driver_b + + converter = DataConverter( + external_storage=StorageOptions( + drivers=[driver_a, driver_b], + driver_selector=selector, + payload_size_threshold=50, + ) + ) + + small_ext = "a" * 100 # above threshold, serialises well below 500 B + large_ext = "b" * 1000 # serialises above 500 B + + # Encode both values in a single call — they should be batched per driver. + encoded = await converter.encode([small_ext, large_ext]) + assert driver_a._store_calls == 1 # one batched call, not two individual ones + assert driver_b._store_calls == 1 + + # Full round-trip. 
+ decoded = await converter.decode(encoded, [str, str]) + assert decoded == [small_ext, large_ext] + assert driver_a._retrieve_calls == 1 + assert driver_b._retrieve_calls == 1 + + async def test_selector_returning_none_keeps_payload_inline(self): + """A selector that returns None for a payload leaves it stored inline + in workflow history rather than offloading it to any driver, even when + the payload exceeds the size threshold.""" + driver = InMemoryTestDriver("driver-a") + + converter = DataConverter( + external_storage=StorageOptions( + drivers=[driver], + driver_selector=lambda _ctx, _payload: None, + payload_size_threshold=50, + ) + ) + + large = "x" * 200 + encoded = await converter.encode([large]) + + assert driver._store_calls == 0 + assert len(encoded[0].external_payloads) == 0 # payload is inline + + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert driver._retrieve_calls == 0 + + async def test_selector_returns_unregistered_driver_raises(self): + """A selector that returns a Driver whose name is not present in + StorageOptions.drivers raises DriverNotFoundError during encode.""" + registered = InMemoryTestDriver("registered") + unregistered = InMemoryTestDriver("not-in-list") + + converter = DataConverter( + external_storage=StorageOptions( + drivers=[registered], + driver_selector=lambda _ctx, _payload: unregistered, + payload_size_threshold=50, + ) + ) + + with pytest.raises(DriverNotFoundError) as exc_info: + await converter.encode(["x" * 200]) + assert exc_info.value.driver_name == "not-in-list" + + async def test_duplicate_driver_names_warns_and_first_wins_for_retrieval(self): + """Registering two drivers with identical names emits StorageWarning. + The first-registered driver with that name is kept in the name→driver + map; subsequent duplicates are ignored. 
+ + Both store (positional via drivers[0]) and retrieval (name-based map) + therefore resolve to the same driver — no data is lost.""" + first = InMemoryTestDriver("dup-name") + duplicate = InMemoryTestDriver("dup-name") + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + converter = DataConverter( + external_storage=StorageOptions( + drivers=[first, duplicate], + payload_size_threshold=50, + ) + ) + + storage_warnings = [w for w in caught if issubclass(w.category, StorageWarning)] + assert len(storage_warnings) == 1 + assert "dup-name" in str(storage_warnings[0].message) + + # Store goes to first (drivers[0], positional). + large = "x" * 200 + encoded = await converter.encode([large]) + assert first._store_calls == 1 + assert duplicate._store_calls == 0 + + # Retrieval resolves "dup-name" to first (first-registered wins). + # first has the data, so the round-trip succeeds. + decoded = await converter.decode(encoded, [str]) + assert decoded[0] == large + assert first._retrieve_calls == 1 + assert duplicate._retrieve_calls == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py new file mode 100644 index 000000000..a53d1334f --- /dev/null +++ b/tests/worker/test_extstore.py @@ -0,0 +1,427 @@ +import dataclasses +import uuid +from collections.abc import Sequence +from dataclasses import dataclass +from datetime import timedelta + +import pytest + +import temporalio +import temporalio.converter +from temporalio import activity, workflow +from temporalio.api.common.v1 import Payload +from temporalio.client import Client, WorkflowFailureError, WorkflowHandle +from temporalio.common import RetryPolicy +from temporalio.exceptions import ActivityError, ApplicationError +from temporalio.extstore import ( + DriverClaim, + DriverContext, + PayloadNotFoundError, + StorageOptions, + StorageWarning, +) +from temporalio.testing._workflow import 
WorkflowEnvironment +from temporalio.worker import Replayer +from tests.helpers import assert_task_fail_eventually, new_worker +from tests.test_extstore import InMemoryTestDriver + + +@dataclass(frozen=True) +class ExtStoreActivityInput: + input_data: str + output_size: int + pass + + +@activity.defn +async def ext_store_activity( + input: ExtStoreActivityInput, +) -> str: + return "ao" * int(input.output_size / 2) + + +@dataclass(frozen=True) +class ExtStoreWorkflowInput: + input_data: str + activity_input_size: int + activity_output_size: int + output_size: int + max_activity_attempts: int | None = None + + +@workflow.defn +class ExtStoreWorkflow: + @workflow.run + async def run(self, input: ExtStoreWorkflowInput) -> str: + retry_policy = ( + RetryPolicy(maximum_attempts=input.max_activity_attempts) + if input.max_activity_attempts is not None + else None + ) + await workflow.execute_activity( + ext_store_activity, + ExtStoreActivityInput( + input_data="ai" * int(input.activity_input_size / 2), + output_size=input.activity_output_size, + ), + schedule_to_close_timeout=timedelta(seconds=3), + retry_policy=retry_policy, + ) + return "wo" * int(input.output_size / 2) + + +class BadTestDriver(InMemoryTestDriver): + def __init__( + self, + driver_name: str = "bad-driver", + no_store: bool = False, + no_retrieve: bool = False, + raise_payload_not_found: bool = False, + ): + super().__init__(driver_name) + self._no_store = no_store + self._no_retrieve = no_retrieve + self._raise_payload_not_found = raise_payload_not_found + + async def store( + self, + context: DriverContext, + payloads: Sequence[Payload], + ) -> list[DriverClaim]: + if self._no_store: + return [] + return await super().store(context, payloads) + + async def retrieve( + self, + context: DriverContext, + claims: Sequence[DriverClaim], + ) -> list[Payload]: + if self._no_retrieve: + return [] + if self._raise_payload_not_found: + raise PayloadNotFoundError( + driver_claim=claims[0], + 
            driver_name=self.name(),
+            )
+        return await super().retrieve(context, claims)
+
+
+async def test_extstore_activity_input_no_retrieve(
+    env: WorkflowEnvironment,
+):
+    """Validate that a driver that returns no payloads when retrieving an
+    externally stored activity input causes the activity to fail."""
+    driver = BadTestDriver(no_retrieve=True)
+
+    client = await Client.connect(
+        env.client.service_client.config.target_host,
+        namespace=env.client.namespace,
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            external_storage=StorageOptions(
+                drivers=[driver],
+                payload_size_threshold=1024,
+            ),
+        ),
+    )
+
+    async with new_worker(
+        client, ExtStoreWorkflow, activities=[ext_store_activity]
+    ) as worker:
+        handle = await client.start_workflow(
+            ExtStoreWorkflow.run,
+            ExtStoreWorkflowInput(
+                input_data="workflow input",
+                activity_input_size=1000,
+                activity_output_size=10,
+                output_size=10,
+                max_activity_attempts=1,
+            ),
+            id=f"workflow-{uuid.uuid4()}",
+            task_queue=worker.task_queue,
+        )
+
+        with pytest.raises(WorkflowFailureError) as err:
+            await handle.result()
+
+        assert isinstance(err.value.cause, ActivityError)
+
+
+async def test_extstore_activity_result_no_store(
+    env: WorkflowEnvironment,
+):
+    """Validate that a driver that returns no claims when storing an
+    externally stored activity result causes the activity to fail."""
+    driver = BadTestDriver(no_store=True)
+
+    client = await Client.connect(
+        env.client.service_client.config.target_host,
+        namespace=env.client.namespace,
+        data_converter=dataclasses.replace(
+            temporalio.converter.default(),
+            external_storage=StorageOptions(
+                drivers=[driver],
+                payload_size_threshold=1024,
+            ),
+        ),
+    )
+
+    async with new_worker(
+        client, ExtStoreWorkflow, activities=[ext_store_activity]
+    ) as worker:
+        handle = await client.start_workflow(
+            ExtStoreWorkflow.run,
+            ExtStoreWorkflowInput(
+                input_data="workflow input",
+                activity_input_size=10,
+                activity_output_size=1000,
+                output_size=10,
+                max_activity_attempts=1,
+ ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + with pytest.raises(WorkflowFailureError) as err: + await handle.result() + + assert isinstance(err.value.cause, ActivityError) + + +async def test_extstore_worker_missing_driver( + env: WorkflowEnvironment, +): + """Validate that when a worker is provided a workflow history with + external storage references and the worker is not configured for external + storage, it will cause a workflow task failure. + """ + driver = InMemoryTestDriver() + + far_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=1024, + ), + ), + ) + + worker_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + ) + + async with new_worker( + worker_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await far_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 1024, + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + await assert_task_fail_eventually(handle) + + +async def test_extstore_payload_not_found_fails_workflow( + env: WorkflowEnvironment, +): + """When a PayloadNotFoundError is raised while retrieving workflow input, + the workflow must fail terminally (not retry as a task failure). 
+ """ + client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageOptions( + drivers=[BadTestDriver(raise_payload_not_found=True)], + payload_size_threshold=1024, + ), + ), + ) + + async with new_worker( + client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data="wi" * 512, # exceeds 1024-byte threshold + activity_input_size=10, + activity_output_size=10, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + + with pytest.raises(WorkflowFailureError) as exc_info: + await handle.result() + + assert isinstance(exc_info.value.cause, ApplicationError) + assert exc_info.value.cause.type == "PayloadNotFoundError" + assert exc_info.value.cause.non_retryable is True + + +async def _run_extstore_workflow_and_fetch_history( + env: WorkflowEnvironment, + driver: InMemoryTestDriver, + *, + input_data: str, + activity_output_size: int = 10, +) -> WorkflowHandle: + """Helper: run ExtStoreWorkflow with the given driver and return its history handle.""" + extstore_client = await Client.connect( + env.client.service_client.config.target_host, + namespace=env.client.namespace, + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ) + async with new_worker( + extstore_client, ExtStoreWorkflow, activities=[ext_store_activity] + ) as worker: + handle = await extstore_client.start_workflow( + ExtStoreWorkflow.run, + ExtStoreWorkflowInput( + input_data=input_data, + activity_input_size=10, + activity_output_size=activity_output_size, + output_size=10, + ), + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + 
await handle.result() + return handle + + +async def test_replay_extstore_history_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input fails to replay when the + Replayer has no external storage configured.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="wi" * 512, # exceeds 512-byte threshold + ) + history = await handle.fetch_history() + + # Replay without external storage — the reference payload cannot be decoded. + # The middleware emits a StorageWarning when it encounters a reference payload + # with no driver configured. + with pytest.warns(StorageWarning, match="External storage is not configured"): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Must be a task-failure RuntimeError, not a NondeterminismError — external + # storage decode failures are distinct from workflow code changes. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. + assert "Failed decoding arguments" in result.replay_failure.args[0] + + +async def test_replay_extstore_history_succeeds_with_correct_extstore( + env: WorkflowEnvironment, +) -> None: + """A history with externalized workflow input replays successfully when the + Replayer is configured with the same storage driver that holds the data.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with the same populated driver — must succeed. 
+ await Replayer( + workflows=[ExtStoreWorkflow], + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageOptions( + drivers=[driver], + payload_size_threshold=512, + ), + ), + ).replay_workflow(history) + + +async def test_replay_extstore_history_fails_with_empty_driver( + env: WorkflowEnvironment, +) -> None: + """A history with external storage references fails to replay when the + Replayer has external storage configured but the driver holds no data + (simulates pointing at the wrong backend or a purged store).""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, driver, input_data="wi" * 512 + ) + history = await handle.fetch_history() + + # Replay with a fresh empty driver — retrieval will fail. + result = await Replayer( + workflows=[ExtStoreWorkflow], + data_converter=dataclasses.replace( + temporalio.converter.default(), + external_storage=StorageOptions( + drivers=[InMemoryTestDriver()], + payload_size_threshold=512, + ), + ), + ).replay_workflow(history, raise_on_replay_failure=False) + # InMemoryTestDriver raises PayloadNotFoundError for absent keys. + # PayloadNotFoundError is re-raised without wrapping, so it propagates + # through decode_activation (before the workflow task runs). The core SDK + # receives an activation failure, issues a FailWorkflow command, but the + # next history event is ActivityTaskScheduled — causing a NondeterminismError. 
+ assert isinstance(result.replay_failure, workflow.NondeterminismError) + + +async def test_replay_extstore_activity_result_fails_without_extstore( + env: WorkflowEnvironment, +) -> None: + """A history where only the activity result was stored externally (the + workflow input is small enough to be inline) also fails to replay without + external storage — verifying that mid-workflow decode failures are caught.""" + driver = InMemoryTestDriver() + handle = await _run_extstore_workflow_and_fetch_history( + env, + driver, + input_data="small", # well under 512 bytes — stays inline + activity_output_size=2048, # 2 KB result — stored externally + ) + history = await handle.fetch_history() + + # Replay without external storage. The workflow input decodes fine, but + # when the ActivityTaskCompleted result is delivered back to the workflow + # coroutine it cannot be decoded. + with pytest.warns(StorageWarning, match="External storage is not configured"): + result = await Replayer(workflows=[ExtStoreWorkflow]).replay_workflow( + history, raise_on_replay_failure=False + ) + # Mid-workflow decode failure is still a task failure (RuntimeError), not + # nondeterminism. + assert isinstance(result.replay_failure, RuntimeError) + assert not isinstance(result.replay_failure, workflow.NondeterminismError) + # The message is the full activation-completion failure string; the + # "Failed decoding arguments" text from _convert_payloads is embedded in it. 
+ assert "Failed decoding arguments" in result.replay_failure.args[0] From 086640b9e882948ea3c3e86f6952934f4bd51cf6 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:01:43 -0800 Subject: [PATCH 02/12] Remove external converter in favor of payload codec --- README.md | 2 +- temporalio/extstore.py | 82 ++++++++++-------------------------------- tests/test_extstore.py | 18 ++++------ 3 files changed, 27 insertions(+), 75 deletions(-) diff --git a/README.md b/README.md index 096c138bc..4daa87ad3 100644 --- a/README.md +++ b/README.md @@ -488,7 +488,7 @@ Some things to note about external payload storage: * Only payloads that meet or exceed `Options.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. * External payload storage applies transparently to workflow inputs/outputs, activity inputs/outputs, signals, updates, queries, and failure details. -* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `Options.external_converter`. +* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `Options.payload_codec`. * Setting `Options.payload_size_threshold` to `None` causes every payload to be considered for external payload storage regardless of size. 
###### Multiple Drivers and Driver Selection diff --git a/temporalio/extstore.py b/temporalio/extstore.py index 5e9c3e1f5..af6fafe6f 100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -8,21 +8,17 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass -from typing import TYPE_CHECKING - from typing_extensions import Self from temporalio.api.common.v1 import Payload from temporalio.converter import ( JSONPlainPayloadConverter, + PayloadCodec, SerializationContext, WithSerializationContext, ) from temporalio.exceptions import TemporalError -if TYPE_CHECKING: - from temporalio.converter import PayloadCodec - @dataclass(frozen=True) class DriverClaim: @@ -123,37 +119,6 @@ def select_driver(self, context: DriverContext, payload: Payload) -> Driver | No pass -@dataclass(frozen=True) -class StorageConverter(WithSerializationContext): - """Converters for converting and encoding external payloads to/from Python values. - - .. warning:: - This API is experimental. - """ - - payload_codec: PayloadCodec | None - """Optional codec applied to payloads before they are handed to a - :class:`Driver` for storage, and after they are retrieved. When ``None``, - payloads are stored as-is by the driver. - """ - - def with_context(self, context: SerializationContext) -> Self: - """Return a copy of this converter with the serialization context applied. - - If :attr:`payload_codec` implements :class:`WithSerializationContext`, - a new instance is created with the context propagated to it. If nothing - changed, ``self`` is returned unchanged. 
- """ - payload_codec = self.payload_codec - if isinstance(payload_codec, WithSerializationContext): - payload_codec = payload_codec.with_context(context) - if payload_codec == self.payload_codec: - return self - cloned = dataclasses.replace(self) - object.__setattr__(cloned, "payload_codec", payload_codec) - return cloned - - @dataclass(frozen=True) class StorageOptions(WithSerializationContext): """Configuration for external storage behavior. @@ -192,20 +157,17 @@ class StorageOptions(WithSerializationContext): external storage regardless of size. """ - external_converter: StorageConverter | None = None - """Converter applied to payload bytes before they are passed to a driver - for storage, and after they are retrieved. When ``None``, payload bytes are - handed to the driver without any additional encoding. Note that the - ``DataConverter``'s ``payload_codec`` is applied to the reference payload - that replaces the original in workflow history, not to the externally stored - bytes themselves. + payload_codec: PayloadCodec | None = None + """Optional codec applied to payloads before they are handed to a + :class:`Driver` for storage, and after they are retrieved. When ``None``, + payloads are stored as-is by the driver. """ def with_context(self, context: SerializationContext) -> Self: """Return a copy of these options with the serialization context applied. Propagates *context* to any drivers, the driver selector, and the - external converter that implement :class:`WithSerializationContext`. + payload codec that implement :class:`WithSerializationContext`. If none of those fields changed, ``self`` is returned unchanged. 
""" drivers = list(self.drivers) @@ -215,22 +177,22 @@ def with_context(self, context: SerializationContext) -> Self: driver_selector = self.driver_selector if isinstance(driver_selector, WithSerializationContext): driver_selector = driver_selector.with_context(context) - external_converter = self.external_converter - if isinstance(external_converter, WithSerializationContext): - external_converter = external_converter.with_context(context) + payload_codec = self.payload_codec + if isinstance(payload_codec, WithSerializationContext): + payload_codec = payload_codec.with_context(context) if all( new is orig for new, orig in [ (drivers, self.drivers), (driver_selector, self.driver_selector), - (external_converter, self.external_converter), + (payload_codec, self.payload_codec), ] ): return self cloned = dataclasses.replace(self) object.__setattr__(cloned, "drivers", drivers) object.__setattr__(cloned, "driver_selector", driver_selector) - object.__setattr__(cloned, "external_converter", external_converter) + object.__setattr__(cloned, "payload_codec", payload_codec) return cloned @@ -332,15 +294,9 @@ def __init__( self, options: StorageOptions | None, context: SerializationContext | None = None, - payload_codec: PayloadCodec | None = None, ): self._options = options self._context = context - self._payload_codec = ( - options.external_converter.payload_codec - if options and options.external_converter - else payload_codec - ) self._driver_map: dict[str, Driver] = {} if options is not None: for driver in options.drivers: @@ -397,8 +353,8 @@ async def store_payload(self, payload: Payload) -> Payload: # Optionally encode the payload before externally storing it encoded_payload = payload - if self._payload_codec: - encoded_payload = (await self._payload_codec.encode([payload]))[0] + if self._options.payload_codec: + encoded_payload = (await self._options.payload_codec.encode([payload]))[0] try: claims = await driver.store(context, [encoded_payload]) @@ -452,8 +408,8 @@ 
async def store_payloads( # Optionally encode all payloads destined for external storage payloads_to_encode = [payload for _, payload, _ in to_store] encoded_payloads = payloads_to_encode - if self._payload_codec: - encoded_payloads = await self._payload_codec.encode(payloads_to_encode) + if self._options.payload_codec: + encoded_payloads = await self._options.payload_codec.encode(payloads_to_encode) # Group encoded payloads by driver for batched store calls # driver -> [(original_index, encoded_payload)] @@ -539,8 +495,8 @@ async def retrieve_payload( self._validate_payload_length(stored_payloads, expected=1, driver=driver) - if self._payload_codec: - stored_payloads = await self._payload_codec.decode(stored_payloads) + if self._options.payload_codec: + stored_payloads = await self._options.payload_codec.decode(stored_payloads) return stored_payloads[0] @@ -629,8 +585,8 @@ async def _retrieve_group( stored_list = [stored_by_index[idx] for idx in retrieve_indices] decoded_payloads = stored_list - if self._payload_codec: - decoded_payloads = await self._payload_codec.decode(stored_list) + if self._options.payload_codec: + decoded_payloads = await self._options.payload_codec.decode(stored_list) for i, retrieved_payload in enumerate(decoded_payloads): results[retrieve_indices[i]] = retrieved_payload diff --git a/tests/test_extstore.py b/tests/test_extstore.py index b80c2aa45..7f8d2df24 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -25,7 +25,6 @@ DriverNotFoundError, DriverSelector, PayloadNotFoundError, - StorageConverter, StorageOptions, StorageWarning, _StorageReference, @@ -537,8 +536,8 @@ async def test_dc_payload_codec_encodes_reference_payload(self): assert driver._retrieve_calls == 1 async def test_external_converter_without_codec_does_not_encode_stored_bytes(self): - """When DataConverter.payload_codec is set but StorageOptions.external_converter - has no payload_codec, stored bytes are NOT encoded – even though + """When 
DataConverter.payload_codec is set but StorageOptions.payload_codec + is None, stored bytes are NOT encoded – even though DataConverter.payload_codec is active for the reference payload in history.""" driver = InMemoryTestDriver() dc_codec = RecordingPayloadCodec("binary/dc-encoded") @@ -548,9 +547,6 @@ async def test_external_converter_without_codec_does_not_encode_stored_bytes(sel external_storage=StorageOptions( drivers=[driver], payload_size_threshold=50, - # Explicitly set external_converter without its own codec. - # DataConverter.payload_codec must NOT bleed through to stored bytes. - external_converter=StorageConverter(payload_codec=None), ), ) @@ -576,10 +572,10 @@ async def test_external_converter_without_codec_does_not_encode_stored_bytes(sel assert driver._retrieve_calls == 1 async def test_external_converter_codec_independent_from_dc_codec(self): - """When both DataConverter.payload_codec and - StorageOptions.external_converter.payload_codec are set, the reference - payload in history uses DataConverter.payload_codec and the bytes stored - by the driver use external_converter.payload_codec – independently.""" + """When both DataConverter.payload_codec and StorageOptions.payload_codec + are set, the reference payload in history uses DataConverter.payload_codec + and the bytes stored by the driver use StorageOptions.payload_codec – + independently.""" driver = InMemoryTestDriver() dc_codec = RecordingPayloadCodec("binary/dc-encoded") ext_codec = RecordingPayloadCodec("binary/ext-encoded") @@ -589,7 +585,7 @@ async def test_external_converter_codec_independent_from_dc_codec(self): external_storage=StorageOptions( drivers=[driver], payload_size_threshold=50, - external_converter=StorageConverter(payload_codec=ext_codec), + payload_codec=ext_codec, ), ) From 8c7b418816f49fe3ef6e785e4bc62a24ea4a5b36 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:21:40 -0800 Subject: [PATCH 03/12] Replace 
DriverError and DriverNotFoundError with RuntimeError --- README.md | 2 +- temporalio/converter.py | 28 +-------------- temporalio/extstore.py | 76 +++++++++++++---------------------------- tests/test_extstore.py | 50 +++++++++++++-------------- 4 files changed, 50 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index 4daa87ad3..82ff53f64 100644 --- a/README.md +++ b/README.md @@ -524,7 +524,7 @@ Some things to note about driver selection: * When no `driver_selector` is set, the first driver in `Options.drivers` is always used for storing. * Returning `None` from a selector leaves the payload stored inline in workflow history rather than offloading it. -* The driver returned by the selector must be registered in `Options.drivers`. If it is not, a `DriverNotFoundError` is raised. +* The driver returned by the selector must be registered in `Options.drivers`. If it is not, a `RuntimeError` is raised. ###### Custom Drivers diff --git a/temporalio/converter.py b/temporalio/converter.py index b67b61873..6ac5501fd 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -929,7 +929,6 @@ def to_failure( ) -> None: """See base class.""" from temporalio.extstore import ( - DriverError, PayloadNotFoundError, ) @@ -938,32 +937,6 @@ def to_failure( self._error_to_failure(exception, payload_converter, failure) elif isinstance(exception, nexusrpc.HandlerError): self._nexus_handler_error_to_failure(exception, payload_converter, failure) - elif isinstance(exception, PayloadNotFoundError): - # Convert to failure error - failure_error = temporalio.exceptions.ApplicationError( - str(exception), - { - "driver_name": exception.driver_name, - "driver_claim": exception.driver_claim, - }, - type=exception.__class__.__name__, - non_retryable=True, - ) - failure_error.__traceback__ = exception.__traceback__ - failure_error.__cause__ = exception.__cause__ - self._error_to_failure(failure_error, payload_converter, failure) - elif isinstance(exception, 
DriverError): - # Convert to failure error - failure_error = temporalio.exceptions.ApplicationError( - str(exception), - { - "driver_name": exception.driver_name, - }, - type=exception.__class__.__name__, - ) - failure_error.__traceback__ = exception.__traceback__ - failure_error.__cause__ = exception.__cause__ - self._error_to_failure(failure_error, payload_converter, failure) else: # Convert to failure error failure_error = temporalio.exceptions.ApplicationError( @@ -971,6 +944,7 @@ def to_failure( type="PayloadSizeError" if isinstance(exception, _PayloadSizeError) else exception.__class__.__name__, + non_retryable=isinstance(exception, PayloadNotFoundError), ) failure_error.__traceback__ = exception.__traceback__ failure_error.__cause__ = exception.__cause__ diff --git a/temporalio/extstore.py b/temporalio/extstore.py index af6fafe6f..ccf18c5ce 100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -8,6 +8,7 @@ from abc import ABC, abstractmethod from collections.abc import Callable, Sequence from dataclasses import dataclass + from typing_extensions import Self from temporalio.api.common.v1 import Payload @@ -196,42 +197,6 @@ def with_context(self, context: SerializationContext) -> Self: return cloned -class DriverError(TemporalError): - """Raised when an error occurs related to a specific driver. - - .. warning:: - This API is experimental. - """ - - def __init__(self, message: str, driver_name: str) -> None: - """Initialize with an error message and the name of the driver that failed.""" - super().__init__(message) - self._driver_name = driver_name - - @property - def driver_name(self) -> str: - """Name of the driver that caused this error.""" - return self._driver_name - - -class DriverNotFoundError(DriverError): - """Raised when a driver name cannot be resolved to a driver in - :attr:`StorageOptions.drivers`. 
This can occur during retrieval when a - :class:`DriverClaim` references a driver name that is not present, or - during storage when the :attr:`StorageOptions.driver_selector` returns a - :class:`Driver` whose :meth:`Driver.name` is not registered. - - .. warning:: - This API is experimental. - """ - - def __init__(self, driver_name: str) -> None: - """Initialize with the name of the driver that could not be resolved.""" - super().__init__( - f"No driver found with name '{driver_name}'", driver_name=driver_name - ) - - class PayloadNotFoundError(TemporalError): """Raised when a payload cannot be retrieved because it does not exist at the location indicated by its :class:`DriverClaim`. @@ -240,9 +205,6 @@ class PayloadNotFoundError(TemporalError): rather than the workflow task. Drivers should raise this when a retrieval attempt confirms the payload is absent. - This error is intentionally not a subclass of :class:`DriverError` to - avoid accidentally handling it and treating as a workflow task failure. - .. warning:: This API is experimental. 
""" @@ -324,14 +286,14 @@ def _select_driver(self, context: DriverContext, payload: Payload) -> Driver | N return None registered = self._driver_map.get(driver.name()) if registered is None: - raise DriverNotFoundError(driver.name()) + raise RuntimeError(f"No driver found with name '{driver.name()}'") return registered def _get_driver_by_name(self, name: str) -> Driver: - """Looks up a driver by name, raising :class:`DriverNotFoundError` if not found.""" + """Looks up a driver by name, raising :class:`RuntimeError` if not found.""" driver = self._driver_map.get(name) if driver is None: - raise DriverNotFoundError(name) + raise RuntimeError(f"No driver found with name '{name}'") return driver async def store_payload(self, payload: Payload) -> Payload: @@ -359,7 +321,9 @@ async def store_payload(self, payload: Payload) -> Payload: try: claims = await driver.store(context, [encoded_payload]) except Exception as err: - raise DriverError("Driver store failed", driver.name()) from err + raise RuntimeError( + f"Driver store failed for driver '{driver.name()}'" + ) from err self._validate_claim_length(claims, expected=1, driver=driver) @@ -409,7 +373,9 @@ async def store_payloads( payloads_to_encode = [payload for _, payload, _ in to_store] encoded_payloads = payloads_to_encode if self._options.payload_codec: - encoded_payloads = await self._options.payload_codec.encode(payloads_to_encode) + encoded_payloads = await self._options.payload_codec.encode( + payloads_to_encode + ) # Group encoded payloads by driver for batched store calls # driver -> [(original_index, encoded_payload)] @@ -429,7 +395,9 @@ async def _store_group( try: return await driver.store(context, store_batch) except Exception as err: - raise DriverError("Driver store failed", driver.name()) from err + raise RuntimeError( + f"Driver store failed for driver '{driver.name()}'" + ) from err all_claims = await asyncio.gather( *( @@ -491,7 +459,9 @@ async def retrieve_payload( except PayloadNotFoundError: raise 
except Exception as err: - raise DriverError("Driver retrieve failed", driver.name()) from err + raise RuntimeError( + f"Driver retrieve failed for driver '{driver.name()}'" + ) from err self._validate_payload_length(stored_payloads, expected=1, driver=driver) @@ -557,7 +527,9 @@ async def _retrieve_group( except PayloadNotFoundError: raise except Exception as err: - raise DriverError("Driver retrieve failed", driver.name()) from err + raise RuntimeError( + f"Driver retrieve failed for driver '{driver.name()}'" + ) from err all_stored = await asyncio.gather( *( @@ -597,16 +569,14 @@ def _validate_claim_length( self, claims: Sequence[DriverClaim], expected: int, driver: Driver ) -> None: if len(claims) != expected: - raise DriverError( - f"Driver returned {len(claims)} claims, expected {expected}", - driver.name(), + raise RuntimeError( + f"Driver '{driver.name()}' returned {len(claims)} claims, expected {expected}", ) def _validate_payload_length( self, payloads: Sequence[Payload], expected: int, driver: Driver ) -> None: if len(payloads) != expected: - raise DriverError( - f"Driver returned {len(payloads)} payloads, expected {expected}", - driver.name(), + raise RuntimeError( + f"Driver '{driver.name()}' returned {len(payloads)} payloads, expected {expected}", ) diff --git a/tests/test_extstore.py b/tests/test_extstore.py index 7f8d2df24..31e125c21 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -21,8 +21,6 @@ Driver, DriverClaim, DriverContext, - DriverError, - DriverNotFoundError, DriverSelector, PayloadNotFoundError, StorageOptions, @@ -330,7 +328,6 @@ def test_class_hierarchy(self): assert issubclass(PayloadNotFoundError, TemporalError) assert not issubclass(PayloadNotFoundError, ApplicationError) assert not issubclass(PayloadNotFoundError, FailureError) - assert not issubclass(PayloadNotFoundError, DriverError) def test_default_message(self): claim = DriverClaim(data={"key": "my-key"}) @@ -344,7 +341,6 @@ def test_properties(self): assert 
err.driver_name == "my-driver" async def test_middleware_propagates_not_found(self): - """PayloadNotFoundError from a driver must not be wrapped in DriverError.""" converter = DataConverter( external_storage=StorageOptions( drivers=[NotFoundDriver()], @@ -356,16 +352,15 @@ async def test_middleware_propagates_not_found(self): encoded = await converter.encode(["hello world " * 20]) assert len(encoded[0].external_payloads) > 0 - # Retrieving should raise PayloadNotFoundError, not DriverError with pytest.raises(PayloadNotFoundError): await converter.decode(encoded, [str]) class TestDriverError: - """Tests for DriverError raised when a driver violates its contract.""" + """Tests for RuntimeError raised when a driver violates its contract.""" - async def test_encode_wrong_claim_count_raises_driver_error(self): - """store() returning fewer claims than payloads must raise DriverError.""" + async def test_encode_wrong_claim_count_raises_runtime_error(self): + """store() returning fewer claims than payloads must raise RuntimeError.""" class _NoClaimsDriver(InMemoryTestDriver): async def store( @@ -373,17 +368,21 @@ async def store( ) -> list[DriverClaim]: return [] + driver = _NoClaimsDriver() converter = DataConverter( external_storage=StorageOptions( - drivers=[_NoClaimsDriver()], + drivers=[driver], payload_size_threshold=10, ) ) - with pytest.raises(DriverError, match="Driver returned 0 claims, expected 1"): + with pytest.raises( + RuntimeError, + match=f"Driver '{driver.name()}' returned 0 claims, expected 1", + ): await converter.encode(["x" * 200]) - async def test_decode_wrong_payload_count_raises_driver_error(self): - """retrieve() returning fewer payloads than claims must raise DriverError.""" + async def test_decode_wrong_payload_count_raises_runtime_error(self): + """retrieve() returning fewer payloads than claims must raise RuntimeError.""" good_converter = DataConverter( external_storage=StorageOptions( drivers=[InMemoryTestDriver()], @@ -398,19 +397,21 @@ 
async def retrieve( ) -> list[Payload]: return [] + driver = _NoPayloadsDriver() bad_converter = DataConverter( external_storage=StorageOptions( - drivers=[ - _NoPayloadsDriver() - ], # same default name as InMemoryTestDriver + drivers=[driver], payload_size_threshold=10, ) ) - with pytest.raises(DriverError, match="Driver returned 0 payloads, expected 1"): + with pytest.raises( + RuntimeError, + match=f"Driver '{driver.name()}' returned 0 payloads, expected 1", + ): await bad_converter.decode(encoded, [str]) - async def test_encode_driver_exception_wrapped_in_driver_error(self): - """Exception raised by store() must be wrapped in DriverError.""" + async def test_encode_driver_exception_wrapped_in_runtime_error(self): + """Exception raised by store() must be wrapped in RuntimeError.""" class _StoreError(Exception): pass @@ -427,12 +428,12 @@ async def store( payload_size_threshold=10, ) ) - with pytest.raises(DriverError) as exc_info: + with pytest.raises(RuntimeError) as exc_info: await converter.encode(["x" * 200]) assert isinstance(exc_info.value.__cause__, _StoreError) - async def test_decode_driver_exception_wrapped_in_driver_error(self): - """Exception raised by retrieve() must be wrapped in DriverError.""" + async def test_decode_driver_exception_wrapped_in_runtime_error(self): + """Exception raised by retrieve() must be wrapped in RuntimeError.""" class _RetrieveError(Exception): pass @@ -459,7 +460,7 @@ async def retrieve( payload_size_threshold=10, ) ) - with pytest.raises(DriverError) as exc_info: + with pytest.raises(RuntimeError) as exc_info: await bad_converter.decode(encoded, [str]) assert isinstance(exc_info.value.__cause__, _RetrieveError) @@ -742,7 +743,7 @@ async def test_selector_returning_none_keeps_payload_inline(self): async def test_selector_returns_unregistered_driver_raises(self): """A selector that returns a Driver whose name is not present in - StorageOptions.drivers raises DriverNotFoundError during encode.""" + StorageOptions.drivers 
raises RuntimeError during encode.""" registered = InMemoryTestDriver("registered") unregistered = InMemoryTestDriver("not-in-list") @@ -754,9 +755,8 @@ async def test_selector_returns_unregistered_driver_raises(self): ) ) - with pytest.raises(DriverNotFoundError) as exc_info: + with pytest.raises(RuntimeError): await converter.encode(["x" * 200]) - assert exc_info.value.driver_name == "not-in-list" async def test_duplicate_driver_names_warns_and_first_wins_for_retrieval(self): """Registering two drivers with identical names emits StorageWarning. From e4883d1a22ab6db35a23e0a388ac0e8453f6d7be Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 11:33:03 -0800 Subject: [PATCH 04/12] Import exstore module instead of types --- temporalio/converter.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 6ac5501fd..293eaa724 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -46,7 +46,7 @@ import temporalio.types if TYPE_CHECKING: - from temporalio.extstore import StorageOptions, _ExternalStorageMiddleware + import temporalio.extstore # avoid circular import at runtime if sys.version_info < (3, 11): # Python's datetime.fromisoformat doesn't support certain formats pre-3.11 @@ -928,9 +928,6 @@ def to_failure( failure: temporalio.api.failure.v1.Failure, ) -> None: """See base class.""" - from temporalio.extstore import ( - PayloadNotFoundError, - ) # If already a failure error, use that if isinstance(exception, temporalio.exceptions.FailureError): @@ -944,7 +941,9 @@ def to_failure( type="PayloadSizeError" if isinstance(exception, _PayloadSizeError) else exception.__class__.__name__, - non_retryable=isinstance(exception, PayloadNotFoundError), + non_retryable=isinstance( + exception, temporalio.extstore.PayloadNotFoundError + ), ) failure_error.__traceback__ = exception.__traceback__ 
failure_error.__cause__ = exception.__cause__ @@ -1368,7 +1367,7 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" - external_storage: StorageOptions | None = None + external_storage: temporalio.extstore.StorageOptions | None = None """Options for external storage. If None, external storage is disabled. .. warning:: @@ -1378,8 +1377,8 @@ class DataConverter(WithSerializationContext): default: ClassVar[DataConverter] """Singleton default data converter.""" - _external_storage_middleware: "_ExternalStorageMiddleware" = dataclasses.field( - init=False + _external_storage_middleware: temporalio.extstore._ExternalStorageMiddleware = ( + dataclasses.field(init=False) ) _payload_error_limits: _ServerPayloadErrorLimits | None = None @@ -1496,13 +1495,14 @@ def with_context(self, context: SerializationContext) -> Self: def _reset_external_storage_middleware( self, context: SerializationContext | None = None ) -> None: - # Lazy import to avoid circular dependency - from temporalio.extstore import _ExternalStorageMiddleware + import temporalio.extstore # lazy import to avoid circular dependency object.__setattr__( self, "_external_storage_middleware", - _ExternalStorageMiddleware(self.external_storage, context), + temporalio.extstore._ExternalStorageMiddleware( + self.external_storage, context + ), ) def _with_payload_error_limits( From c39be13671950a3e8eff9602316d7cb6a53a35ac Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:12:13 -0800 Subject: [PATCH 05/12] Remove DriverSelector --- temporalio/extstore.py | 25 +------------------------ tests/test_extstore.py | 16 ++++++++++++---- 2 files changed, 13 insertions(+), 28 deletions(-) diff --git a/temporalio/extstore.py b/temporalio/extstore.py index ccf18c5ce..26f900748 100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -101,25 +101,6 @@ async def 
retrieve( raise NotImplementedError -class DriverSelector(ABC): - """Determines which :class:`Driver` stores a given payload. - - Implement this class and set it as :attr:`StorageOptions.driver_selector` when you - need stateful or class-based selection logic. For simple cases a plain - callable ``(DriverContext, Payload) -> Driver | None`` can be used instead. - - .. warning:: - This API is experimental. - """ - - @abstractmethod - def select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: - """Returns the driver to use to externally store the payload, or None to decline to - externally store the payload. - """ - pass - - @dataclass(frozen=True) class StorageOptions(WithSerializationContext): """Configuration for external storage behavior. @@ -140,9 +121,7 @@ class StorageOptions(WithSerializationContext): retrieval, so each driver must have a unique name. """ - driver_selector: ( - DriverSelector | Callable[[DriverContext, Payload], Driver | None] | None - ) = None + driver_selector: Callable[[DriverContext, Payload], Driver | None] | None = None """Controls which driver stores a given payload. Accepts either a :class:`DriverSelector` instance or a callable of the form ``(DriverContext, Payload) -> Driver | None``. 
@@ -278,8 +257,6 @@ def _select_driver(self, context: DriverContext, payload: Payload) -> Driver | N selector = self._options.driver_selector if selector is None: return self._options.drivers[0] if self._options.drivers else None - elif isinstance(selector, DriverSelector): - driver = selector.select_driver(context, payload) else: driver = selector(context, payload) if driver is None: diff --git a/tests/test_extstore.py b/tests/test_extstore.py index 31e125c21..d4e09dc7f 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -21,7 +21,6 @@ Driver, DriverClaim, DriverContext, - DriverSelector, PayloadNotFoundError, StorageOptions, StorageWarning, @@ -80,14 +79,23 @@ def parse_claim( return [parse_claim(claim) for claim in claims] -class WorkflowIdFeatureFlagDriverSelector(DriverSelector, WithSerializationContext): - """Example selector that conditionally stores based on workflow ID feature flag.""" +class WorkflowIdFeatureFlagDriverSelector(WithSerializationContext): + """Example selector that conditionally stores based on workflow ID feature flag. + + This example shows how a callable can implement WithSerializationContext if it + needs to precompute data from the serialization context instead of doing it on + every payload selection call. + + The feature flag in this example is a simple check on the workflow ID length, but in + a real implementation this could be a call to a feature flag service or a lookup in a + configuration store. 
+ """ def __init__(self, driver: Driver, enabled: bool = False): self._driver = driver self._enabled = enabled - def select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: + def __call__(self, _context: DriverContext, _payload: Payload) -> Driver | None: return self._driver if self._enabled else None def with_context(self, context: SerializationContext) -> Self: From 354888e87005c10a500a84c9e1de75d4f1bf4345 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:37:07 -0800 Subject: [PATCH 06/12] Remove PayloadNotFoundError in favor of non-retryable ApplicationError --- temporalio/converter.py | 3 - temporalio/extstore.py | 80 +----------------- temporalio/worker/_workflow.py | 6 +- temporalio/worker/_workflow_instance.py | 4 +- tests/test_extstore.py | 105 ++---------------------- tests/worker/test_extstore.py | 14 ++-- 6 files changed, 22 insertions(+), 190 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index 293eaa724..05673af4b 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -941,9 +941,6 @@ def to_failure( type="PayloadSizeError" if isinstance(exception, _PayloadSizeError) else exception.__class__.__name__, - non_retryable=isinstance( - exception, temporalio.extstore.PayloadNotFoundError - ), ) failure_error.__traceback__ = exception.__traceback__ failure_error.__cause__ = exception.__cause__ diff --git a/temporalio/extstore.py b/temporalio/extstore.py index 26f900748..c67cbb786 100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -18,7 +18,6 @@ SerializationContext, WithSerializationContext, ) -from temporalio.exceptions import TemporalError @dataclass(frozen=True) @@ -176,41 +175,6 @@ def with_context(self, context: SerializationContext) -> Self: return cloned -class PayloadNotFoundError(TemporalError): - """Raised when a payload cannot be retrieved because it does not exist - at the location indicated by its 
:class:`DriverClaim`. - - When raised during workflow execution this error fails the **workflow** - rather than the workflow task. Drivers should raise this when a retrieval - attempt confirms the payload is absent. - - .. warning:: - This API is experimental. - """ - - def __init__( - self, - message: str | None = None, - *, - driver_claim: DriverClaim, - driver_name: str, - ) -> None: - """Initialize a payload not found error.""" - super().__init__(message or f"Payload not found for driver '{driver_name}'") - self._driver_claim = driver_claim - self._driver_name = driver_name - - @property - def driver_claim(self) -> DriverClaim: - """The :class:`DriverClaim` for the payload that could not be found.""" - return self._driver_claim - - @property - def driver_name(self) -> str: - """Name of the driver that reported the payload as not found.""" - return self._driver_name - - class StorageWarning(RuntimeWarning): """Warning for external storage issues. @@ -295,12 +259,7 @@ async def store_payload(self, payload: Payload) -> Payload: if self._options.payload_codec: encoded_payload = (await self._options.payload_codec.encode([payload]))[0] - try: - claims = await driver.store(context, [encoded_payload]) - except Exception as err: - raise RuntimeError( - f"Driver store failed for driver '{driver.name()}'" - ) from err + claims = await driver.store(context, [encoded_payload]) self._validate_claim_length(claims, expected=1, driver=driver) @@ -365,20 +324,9 @@ async def store_payloads( # Store all driver groups concurrently then build reference payloads driver_group_list = list(driver_groups.items()) - async def _store_group( - driver: Driver, indexed_payloads: list[tuple[int, Payload]] - ) -> list[DriverClaim]: - store_batch = [p for _, p in indexed_payloads] - try: - return await driver.store(context, store_batch) - except Exception as err: - raise RuntimeError( - f"Driver store failed for driver '{driver.name()}'" - ) from err - all_claims = await asyncio.gather( *( - 
_store_group(driver, indexed_payloads) + driver.store(context, [p for _, p in indexed_payloads]) for driver, indexed_payloads in driver_group_list ) ) @@ -431,14 +379,7 @@ async def retrieve_payload( driver = self._get_driver_by_name(reference.driver_name) context = DriverContext(serialization_context=self._context) - try: - stored_payloads = await driver.retrieve(context, [reference.driver_claim]) - except PayloadNotFoundError: - raise - except Exception as err: - raise RuntimeError( - f"Driver retrieve failed for driver '{driver.name()}'" - ) from err + stored_payloads = await driver.retrieve(context, [reference.driver_claim]) self._validate_payload_length(stored_payloads, expected=1, driver=driver) @@ -495,22 +436,9 @@ async def retrieve_payloads( # Retrieve from all drivers concurrently driver_claim_list = list(driver_claims.items()) - async def _retrieve_group( - driver: Driver, indexed_claims: list[tuple[int, DriverClaim]] - ) -> list[Payload]: - claims_to_retrieve = [claim for _, claim in indexed_claims] - try: - return await driver.retrieve(context, claims_to_retrieve) - except PayloadNotFoundError: - raise - except Exception as err: - raise RuntimeError( - f"Driver retrieve failed for driver '{driver.name()}'" - ) from err - all_stored = await asyncio.gather( *( - _retrieve_group(driver, indexed_claims) + driver.retrieve(context, [claim for _, claim in indexed_claims]) for driver, indexed_claims in driver_claim_list ) ) diff --git a/temporalio/worker/_workflow.py b/temporalio/worker/_workflow.py index efafcb8e6..6fff2e8b5 100644 --- a/temporalio/worker/_workflow.py +++ b/temporalio/worker/_workflow.py @@ -25,7 +25,6 @@ import temporalio.workflow from temporalio.api.enums.v1 import WorkflowTaskFailedCause from temporalio.bridge.worker import PollShutdownError -from temporalio.extstore import PayloadNotFoundError from . 
import _command_aware_visitor from ._interceptor import ( @@ -341,7 +340,10 @@ async def _handle_activation( "Failed handling activation on workflow with run ID %s", act.run_id ) - if isinstance(err, PayloadNotFoundError): + if ( + isinstance(err, temporalio.exceptions.ApplicationError) + and err.non_retryable + ): # Fail the workflow execution terminally rather than failing the task command = completion.successful.commands.add() failure = command.fail_workflow_execution.failure diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 81ebca38c..9b02825ba 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -1783,8 +1783,7 @@ def workflow_set_current_details(self, details: str): def workflow_is_failure_exception(self, err: BaseException) -> bool: # An exception causes the workflow to fail (rather than the task) if it - # is already a failure error, a timeout error, a PayloadNotFoundError - # (unrecoverable missing external payload), or an instance of any of the + # is already a failure error, a timeout error, or an instance of any of the # failure exception types configured at the worker or workflow level. 
wf_failure_exception_types = self._defn.failure_exception_types if self._dynamic_failure_exception_types is not None: @@ -1792,7 +1791,6 @@ def workflow_is_failure_exception(self, err: BaseException) -> bool: return ( isinstance(err, temporalio.exceptions.FailureError) or isinstance(err, asyncio.TimeoutError) - or isinstance(err, temporalio.extstore.PayloadNotFoundError) or any(isinstance(err, typ) for typ in wf_failure_exception_types) or any( isinstance(err, typ) diff --git a/tests/test_extstore.py b/tests/test_extstore.py index d4e09dc7f..e47c8acab 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -16,12 +16,11 @@ WithSerializationContext, WorkflowSerializationContext, ) -from temporalio.exceptions import ApplicationError, FailureError, TemporalError +from temporalio.exceptions import ApplicationError from temporalio.extstore import ( Driver, DriverClaim, DriverContext, - PayloadNotFoundError, StorageOptions, StorageWarning, _StorageReference, @@ -71,7 +70,9 @@ def parse_claim( ) -> Payload: key = claim.data["key"] if key not in self._storage: - raise PayloadNotFoundError(driver_claim=claim, driver_name=self.name()) + raise ApplicationError( + f"Payload not found for key '{key}'", non_retryable=True + ) payload = Payload() payload.ParseFromString(self._storage[key]) return payload @@ -294,7 +295,7 @@ async def test_extstore_serialization_context(self): class NotFoundDriver(Driver): - """Driver that stores normally but raises PayloadNotFoundError on retrieve.""" + """Driver that stores normally but raises non-retryable ApplicationError on retrieve.""" def __init__(self, driver_name: str = "not-found-driver"): self._driver_name = driver_name @@ -321,47 +322,7 @@ async def retrieve( claims: Sequence[DriverClaim], ) -> list[Payload]: assert len(claims) > 0, "NotFoundDriver expected claims to be provided" - raise PayloadNotFoundError( - "Payload not found in not-found-driver", - driver_claim=claims[0], - driver_name=self.name(), - ) - - -class 
TestPayloadNotFoundError: - """Tests for PayloadNotFoundError class and middleware behaviour.""" - - def test_class_hierarchy(self): - """PayloadNotFoundError must be TemporalError but not ApplicationError or FailureError.""" - assert issubclass(PayloadNotFoundError, TemporalError) - assert not issubclass(PayloadNotFoundError, ApplicationError) - assert not issubclass(PayloadNotFoundError, FailureError) - - def test_default_message(self): - claim = DriverClaim(data={"key": "my-key"}) - err = PayloadNotFoundError(driver_claim=claim, driver_name="my-driver") - assert str(err) == "Payload not found for driver 'my-driver'" - - def test_properties(self): - claim = DriverClaim(data={"key": "my-key"}) - err = PayloadNotFoundError("gone", driver_claim=claim, driver_name="my-driver") - assert err.driver_claim is claim - assert err.driver_name == "my-driver" - - async def test_middleware_propagates_not_found(self): - converter = DataConverter( - external_storage=StorageOptions( - drivers=[NotFoundDriver()], - payload_size_threshold=1, # store everything - ) - ) - - # Store a payload so we have a reference to retrieve - encoded = await converter.encode(["hello world " * 20]) - assert len(encoded[0].external_payloads) > 0 - - with pytest.raises(PayloadNotFoundError): - await converter.decode(encoded, [str]) + raise ApplicationError("Payload not found.", non_retryable=True) class TestDriverError: @@ -418,60 +379,6 @@ async def retrieve( ): await bad_converter.decode(encoded, [str]) - async def test_encode_driver_exception_wrapped_in_runtime_error(self): - """Exception raised by store() must be wrapped in RuntimeError.""" - - class _StoreError(Exception): - pass - - class _RaisingStoreDriver(InMemoryTestDriver): - async def store( - self, context: DriverContext, payloads: Sequence[Payload] - ) -> list[DriverClaim]: - raise _StoreError("store failed") - - converter = DataConverter( - external_storage=StorageOptions( - drivers=[_RaisingStoreDriver()], - payload_size_threshold=10, 
- ) - ) - with pytest.raises(RuntimeError) as exc_info: - await converter.encode(["x" * 200]) - assert isinstance(exc_info.value.__cause__, _StoreError) - - async def test_decode_driver_exception_wrapped_in_runtime_error(self): - """Exception raised by retrieve() must be wrapped in RuntimeError.""" - - class _RetrieveError(Exception): - pass - - class _RaisingRetrieveDriver(InMemoryTestDriver): - async def retrieve( - self, context: DriverContext, claims: Sequence[DriverClaim] - ) -> list[Payload]: - raise _RetrieveError("retrieve failed") - - good_converter = DataConverter( - external_storage=StorageOptions( - drivers=[InMemoryTestDriver()], - payload_size_threshold=10, - ) - ) - encoded = await good_converter.encode(["x" * 200]) - - bad_converter = DataConverter( - external_storage=StorageOptions( - drivers=[ - _RaisingRetrieveDriver() - ], # same default name as InMemoryTestDriver - payload_size_threshold=10, - ) - ) - with pytest.raises(RuntimeError) as exc_info: - await bad_converter.decode(encoded, [str]) - assert isinstance(exc_info.value.__cause__, _RetrieveError) - class RecordingPayloadCodec(PayloadCodec): """Codec that wraps each payload under a recognisable ``encoding`` label. 
diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index a53d1334f..4da9b73cb 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -16,7 +16,6 @@ from temporalio.extstore import ( DriverClaim, DriverContext, - PayloadNotFoundError, StorageOptions, StorageWarning, ) @@ -100,9 +99,10 @@ async def retrieve( if self._no_retrieve: return [] if self._raise_payload_not_found: - raise PayloadNotFoundError( - driver_claim=claims[0], - driver_name=self.name(), + raise ApplicationError( + "Payload not found", + type="PayloadNotFoundError", + non_retryable=True, ) return await super().retrieve(context, claims) @@ -236,7 +236,7 @@ async def test_extstore_worker_missing_driver( async def test_extstore_payload_not_found_fails_workflow( env: WorkflowEnvironment, ): - """When a PayloadNotFoundError is raised while retrieving workflow input, + """When a non-retryable ApplicationError is raised while retrieving workflow input, the workflow must fail terminally (not retry as a task failure). """ client = await Client.connect( @@ -388,8 +388,8 @@ async def test_replay_extstore_history_fails_with_empty_driver( ), ), ).replay_workflow(history, raise_on_replay_failure=False) - # InMemoryTestDriver raises PayloadNotFoundError for absent keys. - # PayloadNotFoundError is re-raised without wrapping, so it propagates + # InMemoryTestDriver raises ApplicationError for absent keys. + # ApplicationError is re-raised without wrapping, so it propagates # through decode_activation (before the workflow task runs). The core SDK # receives an activation failure, issues a FailWorkflow command, but the # next history event is ActivityTaskScheduled — causing a NondeterminismError. 
From 9da8504794ee563ce3f3faeedee44f5411386e88 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 12:57:16 -0800 Subject: [PATCH 07/12] Renames and doc updates --- README.md | 77 +++++++++++++++++----------- temporalio/converter.py | 10 ++-- temporalio/extstore.py | 70 ++++++++++++++------------ tests/test_extstore.py | 94 ++++++++++++++++++----------------- tests/worker/test_extstore.py | 28 +++++------ 5 files changed, 153 insertions(+), 126 deletions(-) diff --git a/README.md b/README.md index 82ff53f64..cfd302cf3 100644 --- a/README.md +++ b/README.md @@ -463,7 +463,7 @@ Now `IPv4Address` can be used in type hints including collections, optionals, et External payload storage allows large payloads to be offloaded to an external storage service (such as Amazon S3) rather than stored inline in workflow history. This is useful when workflows or activities work with data that would otherwise exceed Temporal's payload size limits. -External payload storage is configured via the `external_storage` parameter on `DataConverter`, which accepts a `temporalio.extstore.Options` instance. Any driver used to store payloads must also be configured on the component that retrieves them — for example, if the client stores workflow inputs using a driver, the worker must include that driver in its `Options.drivers` list to retrieve them. +External payload storage is configured via the `external_storage` parameter on `DataConverter`, which accepts a `temporalio.extstore.StorageConfig` instance. Any driver used to store payloads must also be configured on the component that retrieves them — for example, if the client stores workflow inputs using a driver, the worker must include that driver in its `StorageConfig.drivers` list to retrieve them. 
The simplest setup uses a single storage driver: @@ -471,7 +471,7 @@ The simplest setup uses a single storage driver: import dataclasses from temporalio.client import Client from temporalio.converter import DataConverter -from temporalio.extstore import Options +from temporalio.extstore import StorageConfig driver = MyDriver() @@ -479,26 +479,26 @@ client = await Client.connect( "localhost:7233", data_converter=dataclasses.replace( DataConverter.default, - external_storage=Options(drivers=[driver]), + external_storage=StorageConfig(drivers=[driver]), ), ) ``` Some things to note about external payload storage: -* Only payloads that meet or exceed `Options.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. +* Only payloads that meet or exceed `StorageConfig.payload_size_threshold` (default 256 KiB) are offloaded. Smaller payloads are stored inline as normal. * External payload storage applies transparently to workflow inputs/outputs, activity inputs/outputs, signals, updates, queries, and failure details. -* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `Options.payload_codec`. -* Setting `Options.payload_size_threshold` to `None` causes every payload to be considered for external payload storage regardless of size. +* The `DataConverter`'s `payload_codec` (if configured) is applied to the *reference* payload stored in workflow history, not to the externally stored bytes. To encrypt or compress the bytes handed to a driver, use `StorageConfig.payload_codec`. +* Setting `StorageConfig.payload_size_threshold` to `None` causes every payload to be considered for external payload storage regardless of size. 
###### Multiple Drivers and Driver Selection -When multiple storage backends are needed, list all drivers in `Options.drivers` and provide a `driver_selector` to control which driver stores new payloads. Any driver in the list not chosen for storing is still available for retrieval, which is useful when migrating between storage backends. +When multiple storage backends are needed, list all drivers in `StorageConfig.drivers` and provide a `driver_selector` to control which driver stores new payloads. Any driver in the list not chosen for storing is still available for retrieval, which is useful when migrating between storage backends. ```python -from temporalio.extstore import Options +from temporalio.extstore import StorageConfig -options = Options( +options = StorageConfig( drivers=[hot_driver, cold_driver], driver_selector=lambda context, payload: ( hot_driver if payload.ByteSize() < 5 * 1024 * 1024 else cold_driver @@ -506,36 +506,57 @@ options = Options( ) ``` -For stateful or class-based selection logic, implement `temporalio.extstore.DriverSelector`: +For stateful or class-based selection logic, implement a callable class. If it also implements `temporalio.converter.WithSerializationContext`, it will receive workflow or activity context (namespace, workflow ID, etc.) 
at serialization time, just like a driver or payload codec: ```python -from temporalio.extstore import Driver, DriverContext, DriverSelector +from typing_extensions import Self +import temporalio.converter +import temporalio.extstore from temporalio.api.common.v1 import Payload -class MyDriverSelector(DriverSelector): - def select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: - # Return None to store the payload inline rather than externally - if payload.ByteSize() < 256 * 1024: - return None - return hot_driver +class FeatureFlaggedDriverSelector(temporalio.converter.WithSerializationContext): + def __init__(self, driver: temporalio.extstore.StorageDriver, enabled: bool = False): + self._driver = driver + self._enabled = enabled + + def __call__( + self, _context: temporalio.extstore.StorageDriverContext, _payload: Payload + ) -> temporalio.extstore.StorageDriver | None: + return self._driver if self._enabled else None + + def with_context(self, context: temporalio.converter.SerializationContext) -> Self: + workflow_id = None + if isinstance(context, temporalio.converter.WorkflowSerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + elif isinstance(context, temporalio.converter.ActivitySerializationContext) and context.workflow_id: + workflow_id = context.workflow_id + + return FeatureFlaggedDriverSelector( + self._driver, FeatureFlaggedDriverSelector.feature_flag_is_on(workflow_id) + ) + + @staticmethod + def feature_flag_is_on(workflow_id: str | None) -> bool: + """Mock implementation of a feature flag based on a workflow ID.""" + return workflow_id is not None and len(workflow_id) % 2 == 0 ``` Some things to note about driver selection: -* When no `driver_selector` is set, the first driver in `Options.drivers` is always used for storing. +* When no `driver_selector` is set, the first driver in `StorageConfig.drivers` is always used for storing. 
* Returning `None` from a selector leaves the payload stored inline in workflow history rather than offloading it. -* The driver returned by the selector must be registered in `Options.drivers`. If it is not, a `RuntimeError` is raised. +* The driver returned by the selector must be registered in `StorageConfig.drivers`. If it is not, a `RuntimeError` is raised. ###### Custom Drivers -Implement `temporalio.extstore.Driver` to integrate with any external payload storage system: +Implement `temporalio.extstore.StorageDriver` to integrate with an external storage system: ```python from collections.abc import Sequence -from temporalio.extstore import Driver, DriverClaim, DriverContext +from temporalio.extstore import StorageDriver, StorageDriverClaim, StorageDriverContext from temporalio.api.common.v1 import Payload -class MyDriver(Driver): +class MyDriver(StorageDriver): def __init__(self, driver_name: str | None = None): self._driver_name = driver_name or "my-org:driver:my-driver" @@ -543,16 +564,16 @@ class MyDriver(Driver): return self._driver_name async def store( - self, context: DriverContext, payloads: Sequence[Payload] - ) -> list[DriverClaim]: + self, context: StorageDriverContext, payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: claims = [] for payload in payloads: key = await my_storage.put(payload.SerializeToString()) - claims.append(DriverClaim(data={"key": key})) + claims.append(StorageDriverClaim(data={"key": key})) return claims async def retrieve( - self, context: DriverContext, claims: Sequence[DriverClaim] + self, context: StorageDriverContext, claims: Sequence[StorageDriverClaim] ) -> list[Payload]: payloads = [] for claim in claims: @@ -566,8 +587,8 @@ class MyDriver(Driver): Some things to note about implementing a custom driver: * `store` and `retrieve` must return lists of the same length as their respective input sequences. -* `Driver.name()` must return a string that is unique among all drivers in `Options.drivers`. 
This name is embedded in the reference payload stored in workflow history and used to look up the correct driver during retrieval — changing it after payloads have been stored will break retrieval. -* `Driver.type()` is automatically implemented to return the name of the class. This can be overriden in subclasses but must remain consistent across all instances of the subclass. +* `StorageDriver.name()` must return a string that is unique among all drivers in `StorageConfig.drivers`. This name is embedded in the reference payload stored in workflow history and used to look up the correct driver during retrieval — changing it after payloads have been stored will break retrieval. +* `StorageDriver.type()` is automatically implemented to return the name of the class. This can be overridden in subclasses but must remain consistent across all instances of the subclass. * Implement `temporalio.converter.WithSerializationContext` on your driver to receive workflow or activity context (namespace, workflow ID, activity ID, etc.) at serialization time. ### Workers diff --git a/temporalio/converter.py b/temporalio/converter.py index 05673af4b..f88e4205d 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1364,7 +1364,7 @@ class DataConverter(WithSerializationContext): payload_limits: PayloadLimitsConfig = PayloadLimitsConfig() """Settings for payload size limits.""" - external_storage: temporalio.extstore.StorageOptions | None = None + external_storage: temporalio.extstore.StorageConfig | None = None """Options for external storage. If None, external storage is disabled. .. 
warning:: @@ -1374,8 +1374,8 @@ class DataConverter(WithSerializationContext): default: ClassVar[DataConverter] """Singleton default data converter.""" - _external_storage_middleware: temporalio.extstore._ExternalStorageMiddleware = ( - dataclasses.field(init=False) + _external_storage_middleware: temporalio.extstore._StorageImpl = dataclasses.field( + init=False ) _payload_error_limits: _ServerPayloadErrorLimits | None = None @@ -1497,9 +1497,7 @@ def _reset_external_storage_middleware( object.__setattr__( self, "_external_storage_middleware", - temporalio.extstore._ExternalStorageMiddleware( - self.external_storage, context - ), + temporalio.extstore._StorageImpl(self.external_storage, context), ) def _with_payload_error_limits( diff --git a/temporalio/extstore.py b/temporalio/extstore.py index c67cbb786..5198f234e 100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -21,7 +21,7 @@ @dataclass(frozen=True) -class DriverClaim: +class StorageDriverClaim: """Claim for an externally stored payload. .. warning:: @@ -33,8 +33,8 @@ class DriverClaim: @dataclass(frozen=True) -class DriverContext: - """Context passed to :class:`Driver` and :class:`DriverSelector` calls. +class StorageDriverContext: + """Context passed to :class:`StorageDriver` and :class:`DriverSelector` calls. .. warning:: This API is experimental. @@ -46,7 +46,7 @@ class DriverContext: """ -class Driver(ABC): +class StorageDriver(ABC): """Base driver for storing and retrieve payloads from external storage systems. .. warning:: @@ -57,7 +57,7 @@ class Driver(ABC): def name(self) -> str: """Returns the name of this driver instance. A driver may allow its name to be parameterized at construction time so that multiple instances of - the same driver class can coexist in :attr:`StorageOptions.drivers` with + the same driver class can coexist in :attr:`StorageConfig.drivers` with distinct names. 
""" raise NotImplementedError @@ -75,10 +75,10 @@ def type(self) -> str: @abstractmethod async def store( self, - context: DriverContext, + context: StorageDriverContext, payloads: Sequence[Payload], - ) -> list[DriverClaim]: - """Stores payloads in external storage and returns a :class:`DriverClaim` + ) -> list[StorageDriverClaim]: + """Stores payloads in external storage and returns a :class:`StorageDriverClaim` for each one. The returned list must be the same length as ``payloads``. """ raise NotImplementedError @@ -86,10 +86,10 @@ async def store( @abstractmethod async def retrieve( self, - context: DriverContext, - claims: Sequence[DriverClaim], + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], ) -> list[Payload]: - """Retrieves payloads from external storage for the given :class:`DriverClaim` + """Retrieves payloads from external storage for the given :class:`StorageDriverClaim` list. The returned list must be the same length as ``claims``. Raise :class:`PayloadNotFoundError` when a retrieval attempt confirms @@ -101,14 +101,14 @@ async def retrieve( @dataclass(frozen=True) -class StorageOptions(WithSerializationContext): +class StorageConfig(WithSerializationContext): """Configuration for external storage behavior. .. warning:: This API is experimental. """ - drivers: Sequence[Driver] + drivers: Sequence[StorageDriver] """Drivers available for storing and retrieving payloads. At least one driver must be provided. @@ -120,7 +120,9 @@ class StorageOptions(WithSerializationContext): retrieval, so each driver must have a unique name. """ - driver_selector: Callable[[DriverContext, Payload], Driver | None] | None = None + driver_selector: ( + Callable[[StorageDriverContext, Payload], StorageDriver | None] | None + ) = None """Controls which driver stores a given payload. Accepts either a :class:`DriverSelector` instance or a callable of the form ``(DriverContext, Payload) -> Driver | None``. 
@@ -138,7 +140,7 @@ class StorageOptions(WithSerializationContext): payload_codec: PayloadCodec | None = None """Optional codec applied to payloads before they are handed to a - :class:`Driver` for storage, and after they are retrieved. When ``None``, + :class:`StorageDriver` for storage, and after they are retrieved. When ``None``, payloads are stored as-is by the driver. """ @@ -186,10 +188,10 @@ class StorageWarning(RuntimeWarning): @dataclass(frozen=True) class _StorageReference: driver_name: str - driver_claim: DriverClaim + driver_claim: StorageDriverClaim -class _ExternalStorageMiddleware: # type:ignore[reportUnusedClass] +class _StorageImpl: # type:ignore[reportUnusedClass] # Claim payload encoding is fixed and independent of any user configuration. _claim_converter: JSONPlainPayloadConverter = JSONPlainPayloadConverter( encoding="json/external-storage-reference" @@ -197,25 +199,27 @@ class _ExternalStorageMiddleware: # type:ignore[reportUnusedClass] def __init__( self, - options: StorageOptions | None, + options: StorageConfig | None, context: SerializationContext | None = None, ): self._options = options self._context = context - self._driver_map: dict[str, Driver] = {} + self._driver_map: dict[str, StorageDriver] = {} if options is not None: for driver in options.drivers: name = driver.name() if name in self._driver_map: warnings.warn( - f"StorageOptions.drivers contains multiple drivers with name '{name}'. " + f"StorageConfig.drivers contains multiple drivers with name '{name}'. 
" "The first one will be used.", category=StorageWarning, ) else: self._driver_map[name] = driver - def _select_driver(self, context: DriverContext, payload: Payload) -> Driver | None: + def _select_driver( + self, context: StorageDriverContext, payload: Payload + ) -> StorageDriver | None: """Returns the driver to use for this payload, or None to pass through.""" assert self._options is not None selector = self._options.driver_selector @@ -230,7 +234,7 @@ def _select_driver(self, context: DriverContext, payload: Payload) -> Driver | N raise RuntimeError(f"No driver found with name '{driver.name()}'") return registered - def _get_driver_by_name(self, name: str) -> Driver: + def _get_driver_by_name(self, name: str) -> StorageDriver: """Looks up a driver by name, raising :class:`RuntimeError` if not found.""" driver = self._driver_map.get(name) if driver is None: @@ -248,7 +252,7 @@ async def store_payload(self, payload: Payload) -> Payload: ): return payload - context = DriverContext(serialization_context=self._context) + context = StorageDriverContext(serialization_context=self._context) driver = self._select_driver(context, payload) if driver is None: @@ -285,11 +289,11 @@ async def store_payloads( return [await self.store_payload(payloads[0])] results = list(payloads) - context = DriverContext(serialization_context=self._context) + context = StorageDriverContext(serialization_context=self._context) # First pass: determine which payloads to store and which driver to use for each. # Provide unencoded payloads to give maximal context information to the selector. 
- to_store: list[tuple[int, Payload, Driver]] = [] + to_store: list[tuple[int, Payload, StorageDriver]] = [] for index, payload in enumerate(payloads): size_bytes = payload.ByteSize() if ( @@ -315,7 +319,7 @@ async def store_payloads( # Group encoded payloads by driver for batched store calls # driver -> [(original_index, encoded_payload)] - driver_groups: dict[Driver, list[tuple[int, Payload]]] = {} + driver_groups: dict[StorageDriver, list[tuple[int, Payload]]] = {} for i, (orig_index, _, driver) in enumerate(to_store): driver_groups.setdefault(driver, []).append( (orig_index, encoded_payloads[i]) @@ -364,7 +368,7 @@ async def retrieve_payload( ) elif len(self._options.drivers) == 0: warnings.warn( - "StorageOptions.drivers is empty, but detected external storage references.", + "StorageConfig.drivers is empty, but detected external storage references.", category=StorageWarning, ) return payload @@ -377,7 +381,7 @@ async def retrieve_payload( return payload driver = self._get_driver_by_name(reference.driver_name) - context = DriverContext(serialization_context=self._context) + context = StorageDriverContext(serialization_context=self._context) stored_payloads = await driver.retrieve(context, [reference.driver_claim]) @@ -405,7 +409,7 @@ async def retrieve_payloads( ) elif len(self._options.drivers) == 0: warnings.warn( - "StorageOptions.drivers is empty, but detected external storage references.", + "StorageConfig.drivers is empty, but detected external storage references.", category=StorageWarning, ) return results @@ -415,7 +419,7 @@ async def retrieve_payloads( # Group claims by driver for batched retrieve calls # driver -> [(original_index, claim)] - driver_claims: dict[Driver, list[tuple[int, DriverClaim]]] = {} + driver_claims: dict[StorageDriver, list[tuple[int, StorageDriverClaim]]] = {} for index, payload in enumerate(payloads): if len(payload.external_payloads) == 0: continue @@ -430,7 +434,7 @@ async def retrieve_payloads( if not driver_claims: return 
results - context = DriverContext(serialization_context=self._context) + context = StorageDriverContext(serialization_context=self._context) stored_by_index: dict[int, Payload] = {} # Retrieve from all drivers concurrently @@ -471,7 +475,7 @@ async def retrieve_payloads( return results def _validate_claim_length( - self, claims: Sequence[DriverClaim], expected: int, driver: Driver + self, claims: Sequence[StorageDriverClaim], expected: int, driver: StorageDriver ) -> None: if len(claims) != expected: raise RuntimeError( @@ -479,7 +483,7 @@ def _validate_claim_length( ) def _validate_payload_length( - self, payloads: Sequence[Payload], expected: int, driver: Driver + self, payloads: Sequence[Payload], expected: int, driver: StorageDriver ) -> None: if len(payloads) != expected: raise RuntimeError( diff --git a/tests/test_extstore.py b/tests/test_extstore.py index e47c8acab..c584d8b3b 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -18,16 +18,16 @@ ) from temporalio.exceptions import ApplicationError from temporalio.extstore import ( - Driver, - DriverClaim, - DriverContext, - StorageOptions, + StorageConfig, + StorageDriver, + StorageDriverClaim, + StorageDriverContext, StorageWarning, _StorageReference, ) -class InMemoryTestDriver(Driver): +class InMemoryTestDriver(StorageDriver): """In-memory storage driver for testing.""" def __init__( @@ -44,9 +44,9 @@ def name(self) -> str: async def store( self, - context: DriverContext, + context: StorageDriverContext, payloads: Sequence[Payload], - ) -> list[DriverClaim]: + ) -> list[StorageDriverClaim]: self._store_calls += 1 start_index = len(self._storage) @@ -56,17 +56,17 @@ async def store( ] self._storage.update(entries) - return [DriverClaim(data={"key": key}) for key, _ in entries] + return [StorageDriverClaim(data={"key": key}) for key, _ in entries] async def retrieve( self, - context: DriverContext, - claims: Sequence[DriverClaim], + context: StorageDriverContext, + claims: 
Sequence[StorageDriverClaim], ) -> list[Payload]: self._retrieve_calls += 1 def parse_claim( - claim: DriverClaim, + claim: StorageDriverClaim, ) -> Payload: key = claim.data["key"] if key not in self._storage: @@ -92,11 +92,13 @@ class WorkflowIdFeatureFlagDriverSelector(WithSerializationContext): configuration store. """ - def __init__(self, driver: Driver, enabled: bool = False): + def __init__(self, driver: StorageDriver, enabled: bool = False): self._driver = driver self._enabled = enabled - def __call__(self, _context: DriverContext, _payload: Payload) -> Driver | None: + def __call__( + self, _context: StorageDriverContext, _payload: Payload + ) -> StorageDriver | None: return self._driver if self._enabled else None def with_context(self, context: SerializationContext) -> Self: @@ -130,7 +132,7 @@ async def test_extstore_encode_decode(self): # Configure with 100-byte threshold converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=100, ) @@ -159,7 +161,7 @@ async def test_extstore_encode_decode(self): async def test_extstore_reference_structure(self): """Test that external storage creates proper reference structure.""" converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[InMemoryTestDriver("test-driver")], payload_size_threshold=50, ) @@ -182,7 +184,7 @@ async def test_extstore_reference_structure(self): assert isinstance(reference, _StorageReference) assert "test-driver" == reference.driver_name - assert isinstance(reference.driver_claim, DriverClaim) + assert isinstance(reference.driver_claim, StorageDriverClaim) assert "key" in reference.driver_claim.data async def test_extstore_composite_conditional(self): @@ -190,7 +192,7 @@ async def test_extstore_composite_conditional(self): hot_driver = InMemoryTestDriver("hot-storage") cold_driver = InMemoryTestDriver("cold-storage") - options = StorageOptions( + options = 
StorageConfig( drivers=[hot_driver, cold_driver], driver_selector=lambda context, payload: hot_driver if payload.ByteSize() < 500 @@ -236,7 +238,7 @@ async def test_extstore_serialization_context(self): # or if the workflow ID doesn't end with "-extstore". This is an example of feature flagging # external storage using the workflow ID. This is an advanced secnario and requires the "can_store" # filter to be a WithSerializationContext. - options = StorageOptions( + options = StorageConfig( drivers=[driver], driver_selector=WorkflowIdFeatureFlagDriverSelector(driver), payload_size_threshold=1024, @@ -294,7 +296,7 @@ async def test_extstore_serialization_context(self): assert driver._store_calls == 1 -class NotFoundDriver(Driver): +class NotFoundDriver(StorageDriver): """Driver that stores normally but raises non-retryable ApplicationError on retrieve.""" def __init__(self, driver_name: str = "not-found-driver"): @@ -306,20 +308,20 @@ def name(self) -> str: async def store( self, - context: DriverContext, + context: StorageDriverContext, payloads: Sequence[Payload], - ) -> list[DriverClaim]: + ) -> list[StorageDriverClaim]: entries = [ (f"payload-{i}", payload.SerializeToString()) for i, payload in enumerate(payloads) ] self._storage.update(entries) - return [DriverClaim(data={"key": key}) for key, _ in entries] + return [StorageDriverClaim(data={"key": key}) for key, _ in entries] async def retrieve( self, - context: DriverContext, - claims: Sequence[DriverClaim], + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], ) -> list[Payload]: assert len(claims) > 0, "NotFoundDriver expected claims to be provided" raise ApplicationError("Payload not found.", non_retryable=True) @@ -333,13 +335,13 @@ async def test_encode_wrong_claim_count_raises_runtime_error(self): class _NoClaimsDriver(InMemoryTestDriver): async def store( - self, context: DriverContext, payloads: Sequence[Payload] - ) -> list[DriverClaim]: + self, context: StorageDriverContext, 
payloads: Sequence[Payload] + ) -> list[StorageDriverClaim]: return [] driver = _NoClaimsDriver() converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=10, ) @@ -353,7 +355,7 @@ async def store( async def test_decode_wrong_payload_count_raises_runtime_error(self): """retrieve() returning fewer payloads than claims must raise RuntimeError.""" good_converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[InMemoryTestDriver()], payload_size_threshold=10, ) @@ -362,13 +364,15 @@ async def test_decode_wrong_payload_count_raises_runtime_error(self): class _NoPayloadsDriver(InMemoryTestDriver): async def retrieve( - self, context: DriverContext, claims: Sequence[DriverClaim] + self, + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], ) -> list[Payload]: return [] driver = _NoPayloadsDriver() bad_converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=10, ) @@ -424,7 +428,7 @@ async def test_dc_payload_codec_encodes_reference_payload(self): converter = DataConverter( payload_codec=dc_codec, - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=50, ), @@ -452,7 +456,7 @@ async def test_dc_payload_codec_encodes_reference_payload(self): assert driver._retrieve_calls == 1 async def test_external_converter_without_codec_does_not_encode_stored_bytes(self): - """When DataConverter.payload_codec is set but StorageOptions.payload_codec + """When DataConverter.payload_codec is set but StorageConfig.payload_codec is None, stored bytes are NOT encoded – even though DataConverter.payload_codec is active for the reference payload in history.""" driver = InMemoryTestDriver() @@ -460,7 +464,7 @@ async def test_external_converter_without_codec_does_not_encode_stored_bytes(sel converter = DataConverter( 
payload_codec=dc_codec, - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=50, ), @@ -488,9 +492,9 @@ async def test_external_converter_without_codec_does_not_encode_stored_bytes(sel assert driver._retrieve_calls == 1 async def test_external_converter_codec_independent_from_dc_codec(self): - """When both DataConverter.payload_codec and StorageOptions.payload_codec + """When both DataConverter.payload_codec and StorageConfig.payload_codec are set, the reference payload in history uses DataConverter.payload_codec - and the bytes stored by the driver use StorageOptions.payload_codec – + and the bytes stored by the driver use StorageConfig.payload_codec – independently.""" driver = InMemoryTestDriver() dc_codec = RecordingPayloadCodec("binary/dc-encoded") @@ -498,7 +502,7 @@ async def test_external_converter_codec_independent_from_dc_codec(self): converter = DataConverter( payload_codec=dc_codec, - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=50, payload_codec=ext_codec, @@ -534,7 +538,7 @@ async def test_external_converter_codec_independent_from_dc_codec(self): class TestMultiDriver: - """Tests for StorageOptions with multiple drivers.""" + """Tests for StorageConfig with multiple drivers.""" async def test_no_selector_uses_first_driver_for_store(self): """Without a driver_selector the first driver in the list handles all @@ -543,7 +547,7 @@ async def test_no_selector_uses_first_driver_for_store(self): second = InMemoryTestDriver("driver-second") converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[first, second], payload_size_threshold=50, ) @@ -576,7 +580,7 @@ async def test_no_selector_second_driver_is_retrieve_only(self): # Store with driver-b as the sole driver. 
store_converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver_b], payload_size_threshold=50, ) @@ -587,7 +591,7 @@ async def test_no_selector_second_driver_is_retrieve_only(self): # Retrieve with driver-a listed first, driver-b second. # The "driver-b" name in the reference must route to driver-b. retrieve_converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver_a, driver_b], payload_size_threshold=50, ) @@ -611,7 +615,7 @@ def selector(_ctx: object, payload: Payload) -> InMemoryTestDriver: return driver_a if payload.ByteSize() < 500 else driver_b converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver_a, driver_b], driver_selector=selector, payload_size_threshold=50, @@ -639,7 +643,7 @@ async def test_selector_returning_none_keeps_payload_inline(self): driver = InMemoryTestDriver("driver-a") converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], driver_selector=lambda _ctx, _payload: None, payload_size_threshold=50, @@ -658,12 +662,12 @@ async def test_selector_returning_none_keeps_payload_inline(self): async def test_selector_returns_unregistered_driver_raises(self): """A selector that returns a Driver whose name is not present in - StorageOptions.drivers raises RuntimeError during encode.""" + StorageConfig.drivers raises RuntimeError during encode.""" registered = InMemoryTestDriver("registered") unregistered = InMemoryTestDriver("not-in-list") converter = DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[registered], driver_selector=lambda _ctx, _payload: unregistered, payload_size_threshold=50, @@ -686,7 +690,7 @@ async def test_duplicate_driver_names_warns_and_first_wins_for_retrieval(self): with warnings.catch_warnings(record=True) as caught: warnings.simplefilter("always") converter = 
DataConverter( - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[first, duplicate], payload_size_threshold=50, ) diff --git a/tests/worker/test_extstore.py b/tests/worker/test_extstore.py index 4da9b73cb..84492af6d 100644 --- a/tests/worker/test_extstore.py +++ b/tests/worker/test_extstore.py @@ -14,9 +14,9 @@ from temporalio.common import RetryPolicy from temporalio.exceptions import ActivityError, ApplicationError from temporalio.extstore import ( - DriverClaim, - DriverContext, - StorageOptions, + StorageConfig, + StorageDriverClaim, + StorageDriverContext, StorageWarning, ) from temporalio.testing._workflow import WorkflowEnvironment @@ -84,17 +84,17 @@ def __init__( async def store( self, - context: DriverContext, + context: StorageDriverContext, payloads: Sequence[Payload], - ) -> list[DriverClaim]: + ) -> list[StorageDriverClaim]: if self._no_store: return [] return await super().store(context, payloads) async def retrieve( self, - context: DriverContext, - claims: Sequence[DriverClaim], + context: StorageDriverContext, + claims: Sequence[StorageDriverClaim], ) -> list[Payload]: if self._no_retrieve: return [] @@ -119,7 +119,7 @@ async def test_extstore_activity_input_no_retrieve( namespace=env.client.namespace, data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=1024, ), @@ -160,7 +160,7 @@ async def test_extstore_activity_result_no_store( namespace=env.client.namespace, data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=1024, ), @@ -203,7 +203,7 @@ async def test_extstore_worker_missing_driver( namespace=env.client.namespace, data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], 
payload_size_threshold=1024, ), @@ -244,7 +244,7 @@ async def test_extstore_payload_not_found_fails_workflow( namespace=env.client.namespace, data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[BadTestDriver(raise_payload_not_found=True)], payload_size_threshold=1024, ), @@ -288,7 +288,7 @@ async def _run_extstore_workflow_and_fetch_history( namespace=env.client.namespace, data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=512, ), @@ -357,7 +357,7 @@ async def test_replay_extstore_history_succeeds_with_correct_extstore( workflows=[ExtStoreWorkflow], data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[driver], payload_size_threshold=512, ), @@ -382,7 +382,7 @@ async def test_replay_extstore_history_fails_with_empty_driver( workflows=[ExtStoreWorkflow], data_converter=dataclasses.replace( temporalio.converter.default(), - external_storage=StorageOptions( + external_storage=StorageConfig( drivers=[InMemoryTestDriver()], payload_size_threshold=512, ), From cab33e2d9922db23725ff8eaa0d7b0cd023f7350 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:12:36 -0800 Subject: [PATCH 08/12] Remove unused import --- temporalio/worker/_workflow_instance.py | 1 - 1 file changed, 1 deletion(-) diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 9b02825ba..d7b6d27b9 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -57,7 +57,6 @@ import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.extstore import temporalio.workflow from temporalio.service import __version__ From 
3fffe91a981851f41493df9751f657d446564ca5 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 13:13:10 -0800 Subject: [PATCH 09/12] Formatting --- temporalio/converter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index f88e4205d..fb78d6ac2 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -928,7 +928,6 @@ def to_failure( failure: temporalio.api.failure.v1.Failure, ) -> None: """See base class.""" - # If already a failure error, use that if isinstance(exception, temporalio.exceptions.FailureError): self._error_to_failure(exception, payload_converter, failure) From 734088aea4733cb575c3ea1cb4cc1e6bed086083 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 15:21:50 -0800 Subject: [PATCH 10/12] Undo breaking changes --- temporalio/converter.py | 21 +++++++++++++-------- temporalio/extstore.py | 16 +++++++++++++--- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index fb78d6ac2..f103823cc 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1563,14 +1563,17 @@ async def _encode_payload( return payload async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): - encoded_payloads = await self._encode_payload_sequence(payloads.payloads) - del payloads.payloads[:] - payloads.payloads.extend(encoded_payloads) + await self._external_storage_middleware.store_payloads(payloads) + if self.payload_codec: + await self.payload_codec.encode_wrapper(payloads) + self._validate_payload_limits(payloads.payloads) async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - result = await self._external_storage_middleware.store_payloads(payloads) + result = await 
self._external_storage_middleware.store_payload_sequence( + payloads + ) if self.payload_codec: result = await self.payload_codec.encode(result) self._validate_payload_limits(result) @@ -1585,9 +1588,9 @@ async def _decode_payload( return payload async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): - decoded_payloads = await self._decode_payload_sequence(payloads.payloads) - del payloads.payloads[:] - payloads.payloads.extend(decoded_payloads) + if self.payload_codec: + await self.payload_codec.decode_wrapper(payloads) + await self._external_storage_middleware.retrieve_payloads(payloads) async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] @@ -1595,7 +1598,9 @@ async def _decode_payload_sequence( result = list(payloads) if self.payload_codec: result = await self.payload_codec.decode(result) - result = await self._external_storage_middleware.retrieve_payloads(result) + result = await self._external_storage_middleware.retrieve_payload_sequence( + result + ) return result @staticmethod diff --git a/temporalio/extstore.py b/temporalio/extstore.py index 5198f234e..e469e3fd5 100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -11,7 +11,7 @@ from typing_extensions import Self -from temporalio.api.common.v1 import Payload +from temporalio.api.common.v1 import Payload, Payloads from temporalio.converter import ( JSONPlainPayloadConverter, PayloadCodec, @@ -278,7 +278,12 @@ async def store_payload(self, payload: Payload) -> Payload: ) return reference_payload - async def store_payloads( + async def store_payloads(self, payloads: Payloads): + stored_payloads = await self.store_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def store_payload_sequence( self, payloads: Sequence[Payload], ) -> list[Payload]: @@ -392,7 +397,12 @@ async def retrieve_payload( return stored_payloads[0] - async def 
retrieve_payloads( + async def retrieve_payloads(self, payloads: Payloads): + stored_payloads = await self.retrieve_payload_sequence(payloads.payloads) + for i, payload in enumerate(stored_payloads): + payloads.payloads[i].CopyFrom(payload) + + async def retrieve_payload_sequence( self, payloads: Sequence[Payload], ) -> list[Payload]: From a0ddf5269484c98a8209f968f641c9572d628c82 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 15:22:17 -0800 Subject: [PATCH 11/12] Rename field --- temporalio/converter.py | 20 +++++++------------- 1 file changed, 7 insertions(+), 13 deletions(-) diff --git a/temporalio/converter.py b/temporalio/converter.py index f103823cc..18b2e67fd 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1373,9 +1373,7 @@ class DataConverter(WithSerializationContext): default: ClassVar[DataConverter] """Singleton default data converter.""" - _external_storage_middleware: temporalio.extstore._StorageImpl = dataclasses.field( - init=False - ) + _storage_impl: temporalio.extstore._StorageImpl = dataclasses.field(init=False) _payload_error_limits: _ServerPayloadErrorLimits | None = None """Server-reported limits for payloads.""" @@ -1556,14 +1554,14 @@ async def _encode_memo_existing( async def _encode_payload( self, payload: temporalio.api.common.v1.Payload ) -> temporalio.api.common.v1.Payload: - payload = await self._external_storage_middleware.store_payload(payload) + payload = await self._storage_impl.store_payload(payload) if self.payload_codec: payload = (await self.payload_codec.encode([payload]))[0] self._validate_payload_limits([payload]) return payload async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): - await self._external_storage_middleware.store_payloads(payloads) + await self._storage_impl.store_payloads(payloads) if self.payload_codec: await self.payload_codec.encode_wrapper(payloads) 
self._validate_payload_limits(payloads.payloads) @@ -1571,9 +1569,7 @@ async def _encode_payloads(self, payloads: temporalio.api.common.v1.Payloads): async def _encode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] ) -> list[temporalio.api.common.v1.Payload]: - result = await self._external_storage_middleware.store_payload_sequence( - payloads - ) + result = await self._storage_impl.store_payload_sequence(payloads) if self.payload_codec: result = await self.payload_codec.encode(result) self._validate_payload_limits(result) @@ -1584,13 +1580,13 @@ async def _decode_payload( ) -> temporalio.api.common.v1.Payload: if self.payload_codec: payload = (await self.payload_codec.decode([payload]))[0] - payload = await self._external_storage_middleware.retrieve_payload(payload) + payload = await self._storage_impl.retrieve_payload(payload) return payload async def _decode_payloads(self, payloads: temporalio.api.common.v1.Payloads): if self.payload_codec: await self.payload_codec.decode_wrapper(payloads) - await self._external_storage_middleware.retrieve_payloads(payloads) + await self._storage_impl.retrieve_payloads(payloads) async def _decode_payload_sequence( self, payloads: Sequence[temporalio.api.common.v1.Payload] @@ -1598,9 +1594,7 @@ async def _decode_payload_sequence( result = list(payloads) if self.payload_codec: result = await self.payload_codec.decode(result) - result = await self._external_storage_middleware.retrieve_payload_sequence( - result - ) + result = await self._storage_impl.retrieve_payload_sequence(result) return result @staticmethod From 89ac04f7e53b80361631bec7eb1fa792d6b09f43 Mon Sep 17 00:00:00 2001 From: jmaeagle99 <44687433+jmaeagle99@users.noreply.github.com> Date: Fri, 27 Feb 2026 15:24:46 -0800 Subject: [PATCH 12/12] Use mapping instead of dict --- temporalio/extstore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/extstore.py b/temporalio/extstore.py index e469e3fd5..54b446bec 
100644 --- a/temporalio/extstore.py +++ b/temporalio/extstore.py @@ -6,7 +6,7 @@ import dataclasses import warnings from abc import ABC, abstractmethod -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from dataclasses import dataclass from typing_extensions import Self @@ -28,7 +28,7 @@ class StorageDriverClaim: This API is experimental. """ - data: dict[str, str] + data: Mapping[str, str] """Driver-defined data for identifying and retrieving an externally stored payload."""