diff --git a/awscrt/mqtt5.py b/awscrt/mqtt5.py index 7c5e4f31f..012a63dc9 100644 --- a/awscrt/mqtt5.py +++ b/awscrt/mqtt5.py @@ -1228,8 +1228,10 @@ class PublishReceivedData: Args: publish_packet (PublishPacket): Data model of an `MQTT5 PUBLISH `_ packet. + acquire_puback_control (Callable): Call this function to prevent automatic PUBACK and take manual control of this PUBLISH message's PUBACK. Returns an opaque handle object that can be passed to Client.invoke_puback(). """ publish_packet: PublishPacket = None + acquire_puback_control: Callable = None @dataclass @@ -1434,7 +1436,8 @@ def _on_publish( correlation_data, subscription_identifiers_tuples, content_type, - user_properties_tuples): + user_properties_tuples, + acquire_puback_control_fn): if self._on_publish_cb is None: return @@ -1468,9 +1471,13 @@ def _on_publish( publish_packet.content_type = content_type publish_packet.user_properties = _init_user_properties(user_properties_tuples) - self._on_publish_cb(PublishReceivedData(publish_packet=publish_packet)) + # Create PublishReceivedData with the manual control callback + publish_data = PublishReceivedData( + publish_packet=publish_packet, + acquire_puback_control=acquire_puback_control_fn + ) - return + self._on_publish_cb(publish_data) def _on_lifecycle_stopped(self): if self._on_lifecycle_stopped_cb: @@ -1957,6 +1964,17 @@ def get_stats(self): result = _awscrt.mqtt5_client_get_stats(self._binding) return OperationStatisticsData(result[0], result[1], result[2], result[3]) + def invoke_puback(self, puback_control_handle): + """Sends a PUBACK packet for the given puback control handle. + + Args: + puback_control_handle: An opaque handle obtained from acquire_puback_control(). This handle cannot be created manually and must come from the acquire_puback_control() Callable within PublishReceivedData. + """ + + _awscrt.mqtt5_client_invoke_puback( + self._binding, + puback_control_handle) + def new_connection(self, on_connection_interrupted=None, on_connection_resumed=None, on_connection_success=None, on_connection_failure=None, on_connection_closed=None): from awscrt.mqtt import Connection diff --git a/crt/aws-c-mqtt b/crt/aws-c-mqtt index 1d512d927..37741c07a 160000 --- a/crt/aws-c-mqtt +++ b/crt/aws-c-mqtt @@ -1 +1 @@ -Subproject commit 1d512d92709f60b74e2cafa018e69a2e647f28e9 +Subproject commit 37741c07a23d35a700100a8fa6628127673ed012 diff --git a/source/module.c b/source/module.c index 0b752e03d..6e2452246 100644 --- a/source/module.c +++ b/source/module.c @@ -810,6 +810,7 @@ static PyMethodDef s_module_methods[] = { AWS_PY_METHOD_DEF(mqtt5_client_subscribe, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt5_client_unsubscribe, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt5_client_get_stats, METH_VARARGS), + AWS_PY_METHOD_DEF(mqtt5_client_invoke_puback, METH_VARARGS), AWS_PY_METHOD_DEF(mqtt5_ws_handshake_transform_complete, METH_VARARGS), /* MQTT Request Response Client */ diff --git a/source/mqtt5_client.c b/source/mqtt5_client.c index 243af6a0e..1700da6ba 100644 --- a/source/mqtt5_client.c +++ b/source/mqtt5_client.c @@ -218,6 +218,60 @@ static PyObject *s_aws_set_user_properties_to_PyObject( * Publish Handler ******************************************************************************/ +static const char *s_capsule_name_puback_control_handle = "aws_puback_control_handle"; + +struct puback_control_handle { + uint64_t control_id; +}; + +static void s_puback_control_handle_destructor(PyObject *capsule) { + struct puback_control_handle *handle = PyCapsule_GetPointer(capsule, s_capsule_name_puback_control_handle); + if (handle) { + aws_mem_release(aws_py_get_allocator(), handle); + } +} + +/* Callback context for manual PUBACK control */ +struct manual_puback_control_context { + struct aws_mqtt5_client *client; + struct aws_mqtt5_packet_publish_view *publish_packet; +}; + +static void s_manual_puback_control_context_destructor(PyObject *capsule) { + struct manual_puback_control_context *context = PyCapsule_GetPointer(capsule, "manual_puback_control_context"); + if (context) { + aws_mem_release(aws_py_get_allocator(), context); + } +} + +/* Function called from Python to set manual PUBACK control and return puback_control_id */ +PyObject *aws_py_mqtt5_client_acquire_puback(PyObject *self, PyObject *args) { + (void)args; + + struct manual_puback_control_context *context = PyCapsule_GetPointer(self, "manual_puback_control_context"); + if (!context || !context->publish_packet) { + PyErr_SetString(PyExc_ValueError, "Invalid manual PUBACK control context"); + return NULL; + } + + uint64_t puback_control_id = aws_mqtt5_client_acquire_puback(context->client, context->publish_packet); + + /* Create handle struct */ + struct puback_control_handle *handle = + aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct puback_control_handle)); + + handle->control_id = puback_control_id; + + /* Wrap in capsule */ + PyObject *capsule = PyCapsule_New(handle, s_capsule_name_puback_control_handle, s_puback_control_handle_destructor); + if (!capsule) { + aws_mem_release(aws_py_get_allocator(), handle); + return NULL; + } + + return capsule; +} + static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *publish_packet, void *user_data) { if (!user_data) { @@ -234,10 +288,46 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu PyObject *result = NULL; PyObject *subscription_identifier_list = NULL; PyObject *user_properties_list = NULL; + PyObject *manual_control_callback = NULL; + PyObject *control_context_capsule = NULL; size_t subscription_identifier_count = publish_packet->subscription_identifier_count; size_t user_property_count = publish_packet->user_property_count; + /* Create manual PUBACK control context */ + struct manual_puback_control_context *control_context = + aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct manual_puback_control_context)); + if (!control_context) { + PyErr_WriteUnraisable(PyErr_Occurred()); + goto cleanup; + } + + /* Set up the context with both client and publish packet */ + control_context->client = client->native; + control_context->publish_packet = (struct aws_mqtt5_packet_publish_view *)publish_packet; + + control_context_capsule = + PyCapsule_New(control_context, "manual_puback_control_context", s_manual_puback_control_context_destructor); + if (!control_context_capsule) { + aws_mem_release(aws_py_get_allocator(), control_context); + PyErr_WriteUnraisable(PyErr_Occurred()); + goto cleanup; + } + + /* Method definition for the manual control callback */ + static PyMethodDef method_def = { + "acquire_puback_control", + aws_py_mqtt5_client_acquire_puback, + METH_NOARGS, + "Take manual control of PUBACK for this message"}; + + /* Create the manual control callback function */ + manual_control_callback = PyCFunction_New(&method_def, control_context_capsule); + if (!manual_control_callback) { + PyErr_WriteUnraisable(PyErr_Occurred()); + goto cleanup; + } + /* Create list of uint32_t subscription identifier tuples */ subscription_identifier_list = PyList_New(subscription_identifier_count); if (!subscription_identifier_list) { @@ -261,7 +351,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu result = PyObject_CallMethod( client->client_core, "_on_publish", - "(y#iOs#OiOIOHs#y#Os#O)", + "(y#iOs#OiOIOHs#y#Os#OO)", /* y */ publish_packet->payload.ptr, /* # */ publish_packet->payload.len, /* i */ (int)publish_packet->qos, @@ -284,7 +374,9 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu /* O */ subscription_identifier_count > 0 ? subscription_identifier_list : Py_None, /* s */ publish_packet->content_type ? publish_packet->content_type->ptr : NULL, /* # */ publish_packet->content_type ? publish_packet->content_type->len : 0, - /* O */ user_property_count > 0 ? user_properties_list : Py_None); + /* O */ user_property_count > 0 ? user_properties_list : Py_None, + /* O */ manual_control_callback); + if (!result) { PyErr_WriteUnraisable(PyErr_Occurred()); } @@ -293,6 +385,8 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu Py_XDECREF(result); Py_XDECREF(subscription_identifier_list); Py_XDECREF(user_properties_list); + Py_XDECREF(manual_control_callback); + Py_XDECREF(control_context_capsule); PyGILState_Release(state); } @@ -1683,6 +1777,49 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args) { return NULL; } +/******************************************************************************* + * Invoke Puback + ******************************************************************************/ + +PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args) { + (void)self; + bool success = true; + + PyObject *impl_capsule; + PyObject *puback_handle_capsule; + + if (!PyArg_ParseTuple( + args, + "OO", + /* O */ &impl_capsule, + /* O */ &puback_handle_capsule)) { + return NULL; + } + + struct mqtt5_client_binding *client = PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt5_client); + if (!client) { + return NULL; + } + + /* Extract handle from capsule */ + struct puback_control_handle *handle = + PyCapsule_GetPointer(puback_handle_capsule, s_capsule_name_puback_control_handle); + if (!handle) { + PyErr_SetString(PyExc_TypeError, "Invalid PUBACK control handle"); + return NULL; + } + + if (aws_mqtt5_client_invoke_puback(client->native, handle->control_id, NULL)) { + PyErr_SetAwsLastError(); + success = false; + } + + if (success) { + Py_RETURN_NONE; + } + return NULL; +} + /******************************************************************************* * Subscribe ******************************************************************************/ diff --git a/source/mqtt5_client.h b/source/mqtt5_client.h index 46c135f82..b9bc54f16 100644 --- a/source/mqtt5_client.h +++ b/source/mqtt5_client.h @@ -14,6 +14,7 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_client_subscribe(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_client_unsubscribe(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_client_get_stats(PyObject *self, PyObject *args); +PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args); PyObject *aws_py_mqtt5_ws_handshake_transform_complete(PyObject *self, PyObject *args);