Skip to content

Commit 1489e2c

Browse files
Create dedicated datagram socket for server
The previous datagram socket for client use only allowed receiving a single message before it has been read (implementation via future). For server use cases we would like to receive messages continuously. This commit implements a dummy socket embedding received data such that dns.asyncquery.receive_udp() can still be used to read the message in order to keep the interface as close as possible to the TCP implementation. Also implementing a proper server class as public interface.
1 parent 659bb6f commit 1489e2c

File tree

3 files changed

+111
-26
lines changed

3 files changed

+111
-26
lines changed

dns/_asyncbackend.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,20 @@ async def recv(self, size, timeout):
6565
raise NotImplementedError
6666

6767

68+
class Server: # pragma: no cover
69+
async def serve_forever(self):
70+
raise NotImplementedError
71+
72+
async def close(self):
73+
raise NotImplementedError
74+
75+
async def __aenter__(self):
76+
return self
77+
78+
async def __aexit__(self, exc_type, exc_value, traceback):
79+
await self.close()
80+
81+
6882
class NullTransport:
6983
async def connect_tcp(self, host, port, timeout, local_address):
7084
raise NotImplementedError

dns/_asyncio_backend.py

Lines changed: 89 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,21 @@ def _get_running_loop():
2121
return asyncio.get_event_loop()
2222

2323

24-
class _DatagramProtocol:
24+
class _BaseDatagramProtocol(asyncio.DatagramProtocol):
2525
def __init__(self):
2626
self.transport = None
27-
self.recvfrom = None
2827

2928
def connection_made(self, transport):
3029
self.transport = transport
3130

31+
def close(self):
32+
if self.transport is not None:
33+
self.transport.close()
34+
35+
class _SocketDatagramProtocol(_BaseDatagramProtocol):
36+
def __init__(self):
37+
self.recvfrom = None
38+
3239
def datagram_received(self, data, addr):
3340
if self.recvfrom and not self.recvfrom.done():
3441
self.recvfrom.set_result((data, addr))
@@ -48,9 +55,18 @@ def connection_lost(self, exc):
4855
else:
4956
self.recvfrom.set_exception(exc)
5057

51-
def close(self):
52-
if self.transport is not None:
53-
self.transport.close()
58+
class _ServerDatagramProtocol(_BaseDatagramProtocol):
59+
def __init__(self, datagram_received_cb, serving_future):
60+
self.datagram_received_cb = datagram_received_cb
61+
self.serving = serving_future
62+
63+
def datagram_received(self, data, addr):
64+
if self.datagram_received_cb:
65+
asyncio.ensure_future(self.datagram_received_cb(addr, self.transport, self, data))
66+
67+
def connection_lost(self, exc):
68+
if self.serving and not self.serving.done():
69+
self.serving.set_result(True)
5470

5571

5672
async def _maybe_wait_for(awaitable, timeout):
@@ -63,7 +79,7 @@ async def _maybe_wait_for(awaitable, timeout):
6379
return await awaitable
6480

6581

66-
class _DatagramSocket(dns._asyncbackend.DatagramSocket):
82+
class _BaseDatagramSocket(dns._asyncbackend.DatagramSocket):
6783
def __init__(self, family, transport, protocol):
6884
super().__init__(family, socket.SOCK_DGRAM)
6985
self.transport = transport
@@ -74,6 +90,20 @@ async def sendto(self, what, destination, timeout): # pragma: no cover
7490
self.transport.sendto(what, destination)
7591
return len(what)
7692

93+
async def close(self):
94+
self.protocol.close()
95+
96+
async def getpeername(self):
97+
return self.transport.get_extra_info("peername")
98+
99+
async def getsockname(self):
100+
return self.transport.get_extra_info("sockname")
101+
102+
async def getpeercert(self, timeout):
103+
raise NotImplementedError
104+
105+
106+
class _SocketDatagramSocket(_BaseDatagramSocket):
77107
async def recvfrom(self, size, timeout):
78108
# ignore size as there's no way I know to tell protocol about it
79109
done = _get_running_loop().create_future()
@@ -85,17 +115,20 @@ async def recvfrom(self, size, timeout):
85115
finally:
86116
self.protocol.recvfrom = None
87117

88-
async def close(self):
89-
self.protocol.close()
90118

91-
async def getpeername(self):
92-
return self.transport.get_extra_info("peername")
119+
class _ServerDatagramSocket(_BaseDatagramSocket):
120+
def __init__(self, family, transport, protocol, addr, data):
121+
super().__init__(family, transport, protocol)
122+
self.addr = addr
123+
self.data = data
93124

94-
async def getsockname(self):
95-
return self.transport.get_extra_info("sockname")
96-
97-
async def getpeercert(self, timeout):
98-
raise NotImplementedError
125+
async def recvfrom(self, size, timeout):
126+
if self.data is None:
127+
raise EOFError("EOF")
128+
# data always contains exactly one messaage
129+
result = (self.data, self.addr)
130+
self.data = None
131+
return result
99132

100133

101134
class _StreamSocket(dns._asyncbackend.StreamSocket):
@@ -124,6 +157,30 @@ async def getpeercert(self, timeout):
124157
return self.writer.get_extra_info("peercert")
125158

126159

160+
class DatagramServer(dns._asyncbackend.Server):
161+
def __init__(self, transport, protocol):
162+
self.transport = transport
163+
self.protocol = protocol
164+
165+
async def serve_forever(self):
166+
await self.protocol.serving
167+
168+
async def close(self):
169+
if self.transport:
170+
self.transport.close()
171+
172+
173+
class StreamServer(dns._asyncbackend.Server):
174+
def __init__(self, server):
175+
self.server = server
176+
177+
async def serve_forever(self):
178+
await self.server.serve_forever()
179+
180+
async def close(self):
181+
await self.server.close()
182+
183+
127184
if dns._features.have("doh"):
128185
import anyio
129186
import httpcore
@@ -234,13 +291,13 @@ async def make_socket(
234291
# proper fix for [#637].
235292
source = (dns.inet.any_for_af(af), 0)
236293
transport, protocol = await loop.create_datagram_endpoint(
237-
_DatagramProtocol, # type: ignore
294+
_SocketDatagramProtocol, # type: ignore
238295
local_addr=source,
239296
family=af,
240297
proto=proto,
241298
remote_addr=destination,
242299
)
243-
return _DatagramSocket(af, transport, protocol)
300+
return _SocketDatagramSocket(af, transport, protocol)
244301
elif socktype == socket.SOCK_STREAM:
245302
if destination is None:
246303
# This shouldn't happen, but we check to make code analysis software
@@ -270,10 +327,22 @@ async def make_server(
270327
socktype,
271328
addr,
272329
):
330+
loop = _get_running_loop()
273331
if socktype == socket.SOCK_DGRAM:
274-
raise NotImplementedError(
275-
"server not necessary for datagram, use make_socket() instead"
276-
) # pragma: no cover
332+
if _is_win32 and addr is None:
333+
# Win32 wants explicit binding before recvfrom(). This is the
334+
# proper fix for [#637].
335+
addr = (dns.inet.any_for_af(af), 0)
336+
async def handle_udp(addr, transport, protocol, data):
337+
sock_udp = _ServerDatagramSocket(af, transport, protocol, addr, data)
338+
await client_connected_cb(sock_udp)
339+
done = _get_running_loop().create_future()
340+
transport, protocol = await loop.create_datagram_endpoint(
341+
lambda: _ServerDatagramProtocol(handle_udp, done),
342+
local_addr=addr,
343+
family=af,
344+
)
345+
return DatagramServer(transport, protocol)
277346
elif socktype == socket.SOCK_STREAM:
278347
async def handle_tcp(r, w):
279348
sock_tcp = _StreamSocket(af, r, w)

dns/asyncserver.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,7 @@ async def udp_serve(
2020
ignore_errors: bool = False,
2121
backend: dns.asyncbackend.Backend | None = None,
2222
) -> None:
23-
if not backend:
24-
backend = dns.asyncbackend.get_default_backend()
25-
af = dns.inet.af_for_address(host)
26-
addr = (host, port)
27-
sock = await backend.make_socket(af, socket.SOCK_DGRAM, 0, addr)
28-
while True:
23+
async def handle_udp(sock: dns.asyncbackend.DatagramSocket):
2924
try:
3025
(m, _, from_address) = await dns.asyncquery.receive_udp(
3126
sock,
@@ -44,6 +39,13 @@ async def udp_serve(
4439
except:
4540
pass
4641

42+
if not backend:
43+
backend = dns.asyncbackend.get_default_backend()
44+
af = dns.inet.af_for_address(host)
45+
addr = (host, port)
46+
server = await backend.make_server(handle_udp, af, socket.SOCK_DGRAM, addr)
47+
await server.serve_forever()
48+
4749

4850
async def tcp_serve(
4951
cb: Callable[[dns.message.Message, str], Awaitable[dns.message.Message]],

0 commit comments

Comments
 (0)