@@ -21,14 +21,21 @@ def _get_running_loop():
2121 return asyncio .get_event_loop ()
2222
2323
24- class _DatagramProtocol :
24+ class _BaseDatagramProtocol ( asyncio . DatagramProtocol ) :
2525 def __init__ (self ):
2626 self .transport = None
27- self .recvfrom = None
2827
2928 def connection_made (self , transport ):
3029 self .transport = transport
3130
31+ def close (self ):
32+ if self .transport is not None :
33+ self .transport .close ()
34+
35+ class _SocketDatagramProtocol (_BaseDatagramProtocol ):
36+ def __init__ (self ):
37+ self .recvfrom = None
38+
3239 def datagram_received (self , data , addr ):
3340 if self .recvfrom and not self .recvfrom .done ():
3441 self .recvfrom .set_result ((data , addr ))
@@ -48,9 +55,18 @@ def connection_lost(self, exc):
4855 else :
4956 self .recvfrom .set_exception (exc )
5057
51- def close (self ):
52- if self .transport is not None :
53- self .transport .close ()
58+ class _ServerDatagramProtocol (_BaseDatagramProtocol ):
59+ def __init__ (self , datagram_received_cb , serving_future ):
60+ self .datagram_received_cb = datagram_received_cb
61+ self .serving = serving_future
62+
63+ def datagram_received (self , data , addr ):
64+ if self .datagram_received_cb :
65+ asyncio .ensure_future (self .datagram_received_cb (addr , self .transport , self , data ))
66+
67+ def connection_lost (self , exc ):
68+ if self .serving and not self .serving .done ():
69+ self .serving .set_result (True )
5470
5571
5672async def _maybe_wait_for (awaitable , timeout ):
@@ -63,7 +79,7 @@ async def _maybe_wait_for(awaitable, timeout):
6379 return await awaitable
6480
6581
66- class _DatagramSocket (dns ._asyncbackend .DatagramSocket ):
82+ class _BaseDatagramSocket (dns ._asyncbackend .DatagramSocket ):
6783 def __init__ (self , family , transport , protocol ):
6884 super ().__init__ (family , socket .SOCK_DGRAM )
6985 self .transport = transport
@@ -74,6 +90,20 @@ async def sendto(self, what, destination, timeout): # pragma: no cover
7490 self .transport .sendto (what , destination )
7591 return len (what )
7692
93+ async def close (self ):
94+ self .protocol .close ()
95+
96+ async def getpeername (self ):
97+ return self .transport .get_extra_info ("peername" )
98+
99+ async def getsockname (self ):
100+ return self .transport .get_extra_info ("sockname" )
101+
102+ async def getpeercert (self , timeout ):
103+ raise NotImplementedError
104+
105+
106+ class _SocketDatagramSocket (_BaseDatagramSocket ):
77107 async def recvfrom (self , size , timeout ):
78108 # ignore size as there's no way I know to tell protocol about it
79109 done = _get_running_loop ().create_future ()
@@ -85,17 +115,20 @@ async def recvfrom(self, size, timeout):
85115 finally :
86116 self .protocol .recvfrom = None
87117
88- async def close (self ):
89- self .protocol .close ()
90118
91- async def getpeername (self ):
92- return self .transport .get_extra_info ("peername" )
119+ class _ServerDatagramSocket (_BaseDatagramSocket ):
120+ def __init__ (self , family , transport , protocol , addr , data ):
121+ super ().__init__ (family , transport , protocol )
122+ self .addr = addr
123+ self .data = data
93124
94- async def getsockname (self ):
95- return self .transport .get_extra_info ("sockname" )
96-
97- async def getpeercert (self , timeout ):
98- raise NotImplementedError
125+ async def recvfrom (self , size , timeout ):
126+ if self .data is None :
127+ raise EOFError ("EOF" )
128+ # data always contains exactly one messaage
129+ result = (self .data , self .addr )
130+ self .data = None
131+ return result
99132
100133
101134class _StreamSocket (dns ._asyncbackend .StreamSocket ):
@@ -124,6 +157,30 @@ async def getpeercert(self, timeout):
124157 return self .writer .get_extra_info ("peercert" )
125158
126159
160+ class DatagramServer (dns ._asyncbackend .Server ):
161+ def __init__ (self , transport , protocol ):
162+ self .transport = transport
163+ self .protocol = protocol
164+
165+ async def serve_forever (self ):
166+ await self .protocol .serving
167+
168+ async def close (self ):
169+ if self .transport :
170+ self .transport .close ()
171+
172+
173+ class StreamServer (dns ._asyncbackend .Server ):
174+ def __init__ (self , server ):
175+ self .server = server
176+
177+ async def serve_forever (self ):
178+ await self .server .serve_forever ()
179+
180+ async def close (self ):
181+ await self .server .close ()
182+
183+
127184if dns ._features .have ("doh" ):
128185 import anyio
129186 import httpcore
@@ -234,13 +291,13 @@ async def make_socket(
234291 # proper fix for [#637].
235292 source = (dns .inet .any_for_af (af ), 0 )
236293 transport , protocol = await loop .create_datagram_endpoint (
237- _DatagramProtocol , # type: ignore
294+ _SocketDatagramProtocol , # type: ignore
238295 local_addr = source ,
239296 family = af ,
240297 proto = proto ,
241298 remote_addr = destination ,
242299 )
243- return _DatagramSocket (af , transport , protocol )
300+ return _SocketDatagramSocket (af , transport , protocol )
244301 elif socktype == socket .SOCK_STREAM :
245302 if destination is None :
246303 # This shouldn't happen, but we check to make code analysis software
@@ -270,10 +327,22 @@ async def make_server(
270327 socktype ,
271328 addr ,
272329 ):
330+ loop = _get_running_loop ()
273331 if socktype == socket .SOCK_DGRAM :
274- raise NotImplementedError (
275- "server not necessary for datagram, use make_socket() instead"
276- ) # pragma: no cover
332+ if _is_win32 and addr is None :
333+ # Win32 wants explicit binding before recvfrom(). This is the
334+ # proper fix for [#637].
335+ addr = (dns .inet .any_for_af (af ), 0 )
336+ async def handle_udp (addr , transport , protocol , data ):
337+ sock_udp = _ServerDatagramSocket (af , transport , protocol , addr , data )
338+ await client_connected_cb (sock_udp )
339+ done = _get_running_loop ().create_future ()
340+ transport , protocol = await loop .create_datagram_endpoint (
341+ lambda : _ServerDatagramProtocol (handle_udp , done ),
342+ local_addr = addr ,
343+ family = af ,
344+ )
345+ return DatagramServer (transport , protocol )
277346 elif socktype == socket .SOCK_STREAM :
278347 async def handle_tcp (r , w ):
279348 sock_tcp = _StreamSocket (af , r , w )
0 commit comments