Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions awscrt/mqtt5.py
Original file line number Diff line number Diff line change
Expand Up @@ -1228,8 +1228,10 @@ class PublishReceivedData:

Args:
publish_packet (PublishPacket): Data model of an `MQTT5 PUBLISH <https://docs.oasis-open.org/mqtt/mqtt/v5.0/os/mqtt-v5.0-os.html#_Toc3901100>`_ packet.
acquire_puback_control (Callable): Call this function to prevent automatic PUBACK and take manual control of this PUBLISH message's PUBACK. Returns an opaque handle object that can be passed to Client.invoke_puback().
"""
publish_packet: PublishPacket = None
acquire_puback_control: Callable = None


@dataclass
Expand Down Expand Up @@ -1434,7 +1436,8 @@ def _on_publish(
correlation_data,
subscription_identifiers_tuples,
content_type,
user_properties_tuples):
user_properties_tuples,
acquire_puback_control_fn):
if self._on_publish_cb is None:
return

Expand Down Expand Up @@ -1468,9 +1471,13 @@ def _on_publish(
publish_packet.content_type = content_type
publish_packet.user_properties = _init_user_properties(user_properties_tuples)

self._on_publish_cb(PublishReceivedData(publish_packet=publish_packet))
# Create PublishReceivedData with the manual control callback
publish_data = PublishReceivedData(
publish_packet=publish_packet,
acquire_puback_control=acquire_puback_control_fn
)

return
self._on_publish_cb(publish_data)

def _on_lifecycle_stopped(self):
if self._on_lifecycle_stopped_cb:
Expand Down Expand Up @@ -1957,6 +1964,17 @@ def get_stats(self):
result = _awscrt.mqtt5_client_get_stats(self._binding)
return OperationStatisticsData(result[0], result[1], result[2], result[3])

def invoke_puback(self, puback_control_handle):
"""Sends a PUBACK packet for the given puback control handle.

Args:
puback_control_handle: An opaque handle obtained from acquire_puback_control(). This handle cannot be created manually and must come from the acquire_puback_control() Callable within PublishReceivedData.
"""

_awscrt.mqtt5_client_invoke_puback(
self._binding,
puback_control_handle)

def new_connection(self, on_connection_interrupted=None, on_connection_resumed=None,
on_connection_success=None, on_connection_failure=None, on_connection_closed=None):
from awscrt.mqtt import Connection
Expand Down
1 change: 1 addition & 0 deletions source/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -810,6 +810,7 @@ static PyMethodDef s_module_methods[] = {
AWS_PY_METHOD_DEF(mqtt5_client_subscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_client_unsubscribe, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_client_get_stats, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_client_invoke_puback, METH_VARARGS),
AWS_PY_METHOD_DEF(mqtt5_ws_handshake_transform_complete, METH_VARARGS),

/* MQTT Request Response Client */
Expand Down
141 changes: 139 additions & 2 deletions source/mqtt5_client.c
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,60 @@ static PyObject *s_aws_set_user_properties_to_PyObject(
* Publish Handler
******************************************************************************/

static const char *s_capsule_name_puback_control_handle = "aws_puback_control_handle";

struct puback_control_handle {
uint64_t control_id;
};

static void s_puback_control_handle_destructor(PyObject *capsule) {
struct puback_control_handle *handle = PyCapsule_GetPointer(capsule, s_capsule_name_puback_control_handle);
if (handle) {
aws_mem_release(aws_py_get_allocator(), handle);
}
}

/* Callback context for manual PUBACK control */
struct manual_puback_control_context {
struct aws_mqtt5_client *client;
struct aws_mqtt5_packet_publish_view *publish_packet;
};

static void s_manual_puback_control_context_destructor(PyObject *capsule) {
struct manual_puback_control_context *context = PyCapsule_GetPointer(capsule, "manual_puback_control_context");
if (context) {
aws_mem_release(aws_py_get_allocator(), context);
}
}

/* Function called from Python to set manual PUBACK control and return puback_control_id */
PyObject *aws_py_mqtt5_client_acquire_puback(PyObject *self, PyObject *args) {
(void)args;

struct manual_puback_control_context *context = PyCapsule_GetPointer(self, "manual_puback_control_context");
if (!context || !context->publish_packet) {
PyErr_SetString(PyExc_ValueError, "Invalid manual PUBACK control context");
return NULL;
}

uint64_t puback_control_id = aws_mqtt5_client_acquire_puback(context->client, context->publish_packet);

/* Create handle struct */
struct puback_control_handle *handle =
aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct puback_control_handle));

handle->control_id = puback_control_id;

/* Wrap in capsule */
PyObject *capsule = PyCapsule_New(handle, s_capsule_name_puback_control_handle, s_puback_control_handle_destructor);
if (!capsule) {
aws_mem_release(aws_py_get_allocator(), handle);
return NULL;
}

return capsule;
}

static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *publish_packet, void *user_data) {

if (!user_data) {
Expand All @@ -234,10 +288,46 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
PyObject *result = NULL;
PyObject *subscription_identifier_list = NULL;
PyObject *user_properties_list = NULL;
PyObject *manual_control_callback = NULL;
PyObject *control_context_capsule = NULL;

size_t subscription_identifier_count = publish_packet->subscription_identifier_count;
size_t user_property_count = publish_packet->user_property_count;

/* Create manual PUBACK control context */
struct manual_puback_control_context *control_context =
aws_mem_calloc(aws_py_get_allocator(), 1, sizeof(struct manual_puback_control_context));
if (!control_context) {
PyErr_WriteUnraisable(PyErr_Occurred());
goto cleanup;
}

/* Set up the context with both client and publish packet */
control_context->client = client->native;
control_context->publish_packet = (struct aws_mqtt5_packet_publish_view *)publish_packet;

control_context_capsule =
PyCapsule_New(control_context, "manual_puback_control_context", s_manual_puback_control_context_destructor);
if (!control_context_capsule) {
aws_mem_release(aws_py_get_allocator(), control_context);
PyErr_WriteUnraisable(PyErr_Occurred());
goto cleanup;
}

/* Method definition for the manual control callback */
static PyMethodDef method_def = {
"acquire_puback_control",
aws_py_mqtt5_client_acquire_puback,
METH_NOARGS,
"Take manual control of PUBACK for this message"};

/* Create the manual control callback function */
manual_control_callback = PyCFunction_New(&method_def, control_context_capsule);
if (!manual_control_callback) {
PyErr_WriteUnraisable(PyErr_Occurred());
goto cleanup;
}

/* Create list of uint32_t subscription identifier tuples */
subscription_identifier_list = PyList_New(subscription_identifier_count);
if (!subscription_identifier_list) {
Expand All @@ -261,7 +351,7 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
result = PyObject_CallMethod(
client->client_core,
"_on_publish",
"(y#iOs#OiOIOHs#y#Os#O)",
"(y#iOs#OiOIOHs#y#Os#OO)",
/* y */ publish_packet->payload.ptr,
/* # */ publish_packet->payload.len,
/* i */ (int)publish_packet->qos,
Expand All @@ -284,7 +374,9 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
/* O */ subscription_identifier_count > 0 ? subscription_identifier_list : Py_None,
/* s */ publish_packet->content_type ? publish_packet->content_type->ptr : NULL,
/* # */ publish_packet->content_type ? publish_packet->content_type->len : 0,
/* O */ user_property_count > 0 ? user_properties_list : Py_None);
/* O */ user_property_count > 0 ? user_properties_list : Py_None,
/* O */ manual_control_callback);

if (!result) {
PyErr_WriteUnraisable(PyErr_Occurred());
}
Expand All @@ -293,6 +385,8 @@ static void s_on_publish_received(const struct aws_mqtt5_packet_publish_view *pu
Py_XDECREF(result);
Py_XDECREF(subscription_identifier_list);
Py_XDECREF(user_properties_list);
Py_XDECREF(manual_control_callback);
Py_XDECREF(control_context_capsule);
PyGILState_Release(state);
}

Expand Down Expand Up @@ -1683,6 +1777,49 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args) {
return NULL;
}

/*******************************************************************************
* Invoke Puback
******************************************************************************/

PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args) {
(void)self;
bool success = true;

PyObject *impl_capsule;
PyObject *puback_handle_capsule;

if (!PyArg_ParseTuple(
args,
"OO",
/* O */ &impl_capsule,
/* O */ &puback_handle_capsule)) {
return NULL;
}

struct mqtt5_client_binding *client = PyCapsule_GetPointer(impl_capsule, s_capsule_name_mqtt5_client);
if (!client) {
return NULL;
}

/* Extract handle from capsule */
struct puback_control_handle *handle =
PyCapsule_GetPointer(puback_handle_capsule, s_capsule_name_puback_control_handle);
if (!handle) {
PyErr_SetString(PyExc_TypeError, "Invalid PUBACK control handle");
return NULL;
}

if (aws_mqtt5_client_invoke_puback(client->native, handle->control_id, NULL)) {
PyErr_SetAwsLastError();
success = false;
}

if (success) {
Py_RETURN_NONE;
}
return NULL;
}

/*******************************************************************************
* Subscribe
******************************************************************************/
Expand Down
1 change: 1 addition & 0 deletions source/mqtt5_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ PyObject *aws_py_mqtt5_client_publish(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_subscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_unsubscribe(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_get_stats(PyObject *self, PyObject *args);
PyObject *aws_py_mqtt5_client_invoke_puback(PyObject *self, PyObject *args);

PyObject *aws_py_mqtt5_ws_handshake_transform_complete(PyObject *self, PyObject *args);

Expand Down
Loading