@@ -238,33 +238,51 @@ def __init__(self, websocket, ports):
238238
239239 self .websocket = websocket
240240 self .local_ports = {}
241- for ix , local_remote in enumerate (ports ):
242- self .local_ports [local_remote [0 ]] = self ._Port (ix , local_remote [1 ])
241+ for ix , port_number in enumerate (ports ):
242+ self .local_ports [port_number ] = self ._Port (ix , port_number )
243+ # There is a thread run per PortForward instance which performs the translation between the
244+ # raw socket data sent by the python application and the websocket protocol. This thread
245+ # terminates after either side has closed all ports, and after flushing all pending data.
243246 threading .Thread (
244- name = "Kubernetes port forward proxy" , target = self ._proxy , daemon = True
247+ name = "Kubernetes port forward proxy: %s" % ', ' .join ([str (port ) for port in ports ]),
248+ target = self ._proxy ,
249+ daemon = True
245250 ).start ()
246251
247- def socket (self , local_number ):
248- if local_number not in self .local_ports :
252+ def socket (self , port_number ):
253+ if port_number not in self .local_ports :
249254 raise ValueError ("Invalid port number" )
250- return self .local_ports [local_number ].socket
255+ return self .local_ports [port_number ].socket
251256
252- def error (self , local_number ):
253- if local_number not in self .local_ports :
257+ def error (self , port_number ):
258+ if port_number not in self .local_ports :
254259 raise ValueError ("Invalid port number" )
255- return self .local_ports [local_number ].error
260+ return self .local_ports [port_number ].error
256261
257262 def close (self ):
258263 for port in self .local_ports .values ():
259264 port .socket .close ()
260265
261266 class _Port :
262- def __init__ (self , ix , remote_number ):
263- self .remote_number = remote_number
267+ def __init__ (self , ix , port_number ):
268+ # The remote port number
269+ self .port_number = port_number
270+ # The websocket channel byte number for this port
264271 self .channel = bytes ([ix * 2 ])
272+ # A socket pair is created to provide a means of translating the data flow
273+ # between the python application and the kubernetes websocket. The self.python
274+ # half of the socket pair is used by the _proxy method to receive and send data
275+ # to the running python application.
265276 s , self .python = socket .socketpair (socket .AF_UNIX , socket .SOCK_STREAM )
277+ # The self.socket half of the pair is used by the python application to send
278+ # and receive data to the eventual pod port. It is wrapped in the _Socket class
279+ # because a socket pair is an AF_UNIX socket, not a AF_NET socket. This allows
280+ # intercepting setting AF_INET socket options that would error against an AD_UNIX
281+ # socket.
266282 self .socket = self ._Socket (s )
283+ # Data accumulated from the websocket to be sent to the python application.
267284 self .data = b''
285+ # All data sent from kubernetes on the port error channel.
268286 self .error = None
269287
270288 class _Socket :
@@ -285,42 +303,44 @@ def setsockopt(self, level, optname, value):
285303 def _proxy (self ):
286304 channel_ports = []
287305 channel_initialized = []
288- python_ports = {}
289- rlist = []
306+ local_ports = {}
290307 for port in self .local_ports .values ():
291308 # Setup the data channel for this port number
292309 channel_ports .append (port )
293310 channel_initialized .append (False )
294311 # Setup the error channel for this port number
295312 channel_ports .append (port )
296313 channel_initialized .append (False )
297- python_ports [ port .python ] = port
298- rlist . append ( port .python )
299- rlist . append ( self . websocket . sock )
314+ port .python . setblocking ( True )
315+ local_ports [ port .python ] = port
316+ # The data to send on the websocket socket
300317 kubernetes_data = b''
301318 while True :
302- wlist = []
319+ rlist = [] # List of sockets to read from
320+ wlist = [] # List of sockets to write to
321+ if self .websocket .connected :
322+ rlist .append (self .websocket )
323+ if kubernetes_data :
324+ wlist .append (self .websocket )
325+ all_closed = True
303326 for port in self .local_ports .values ():
304- if port .data :
305- wlist .append (port .python )
306- if kubernetes_data :
307- wlist .append (self .websocket .sock )
327+ if port .python .fileno () != - 1 :
328+ if port .data :
329+ wlist .append (port .python )
330+ all_closed = False
331+ else :
332+ if self .websocket .connected :
333+ rlist .append (port .python )
334+ all_closed = False
335+ else :
336+ port .python .close ()
337+ if all_closed and (not self .websocket .connected or not kubernetes_data ):
338+ self .websocket .close ()
339+ return
308340 r , w , _ = select .select (rlist , wlist , [])
309- for s in w :
310- if s == self .websocket .sock :
311- sent = self .websocket .sock .send (kubernetes_data )
312- kubernetes_data = kubernetes_data [sent :]
313- else :
314- port = python_ports [s ]
315- sent = port .python .send (port .data )
316- port .data = port .data [sent :]
317- for s in r :
318- if s == self .websocket .sock :
341+ for sock in r :
342+ if sock == self .websocket :
319343 opcode , frame = self .websocket .recv_data_frame (True )
320- if opcode == ABNF .OPCODE_CLOSE :
321- for port in self .local_ports .values ():
322- port .python .close ()
323- return
324344 if opcode == ABNF .OPCODE_BINARY :
325345 if not frame .data :
326346 raise RuntimeError ("Unexpected frame data size" )
@@ -341,27 +361,32 @@ def _proxy(self):
341361 "Unexpected initial channel frame data size"
342362 )
343363 port_number = frame .data [1 ] + (frame .data [2 ] * 256 )
344- if port_number != port .remote_number :
364+ if port_number != port .port_number :
345365 raise RuntimeError (
346366 "Unexpected port number in initial channel frame: " + str (port_number )
347367 )
348368 channel_initialized [channel ] = True
349- elif opcode not in (ABNF .OPCODE_PING , ABNF .OPCODE_PONG ):
369+ elif opcode not in (ABNF .OPCODE_PING , ABNF .OPCODE_PONG , ABNF . OPCODE_CLOSE ):
350370 raise RuntimeError ("Unexpected websocket opcode: " + str (opcode ))
351371 else :
352- port = python_ports [ s ]
372+ port = local_ports [ sock ]
353373 data = port .python .recv (1024 * 1024 )
354374 if data :
355375 kubernetes_data += ABNF .create_frame (
356376 port .channel + data ,
357377 ABNF .OPCODE_BINARY ,
358378 ).format ()
359379 else :
360- port .python .close ()
361- rlist .remove (s )
362- if len (rlist ) == 1 :
363- self .websocket .close ()
364- return
380+ if not port .data :
381+ port .python .close ()
382+ for sock in w :
383+ if sock == self .websocket :
384+ sent = self .websocket .sock .send (kubernetes_data )
385+ kubernetes_data = kubernetes_data [sent :]
386+ else :
387+ port = local_ports [sock ]
388+ sent = port .python .send (port .data )
389+ port .data = port .data [sent :]
365390
366391
367392def get_websocket_url (url , query_params = None ):
@@ -451,38 +476,18 @@ def portforward_call(configuration, _method, url, **kwargs):
451476 query_params = kwargs .get ("query_params" )
452477
453478 ports = []
454- for ix in range (len (query_params )):
455- if query_params [ix ][0 ] == 'ports' :
456- remote_ports = []
457- for port in query_params [ix ][1 ].split (',' ):
479+ for param , value in query_params :
480+ if param == 'ports' :
481+ for port in value .split (',' ):
458482 try :
459- local_remote = port .split (':' )
460- if len (local_remote ) > 2 :
461- raise ValueError
462- if len (local_remote ) == 1 :
463- local_remote [0 ] = int (local_remote [0 ])
464- if not (0 < local_remote [0 ] < 65536 ):
465- raise ValueError
466- local_remote .append (local_remote [0 ])
467- elif len (local_remote ) == 2 :
468- if local_remote [0 ]:
469- local_remote [0 ] = int (local_remote [0 ])
470- if not (0 <= local_remote [0 ] < 65536 ):
471- raise ValueError
472- else :
473- local_remote [0 ] = 0
474- local_remote [1 ] = int (local_remote [1 ])
475- if not (0 < local_remote [1 ] < 65536 ):
476- raise ValueError
477- if not local_remote [0 ]:
478- local_remote [0 ] = len (ports ) + 1
479- else :
480- raise ValueError
481- ports .append (local_remote )
482- remote_ports .append (str (local_remote [1 ]))
483+ port_number = int (port )
483484 except ValueError :
484- raise ApiValueError ("Invalid port number `" + port + "`" )
485- query_params [ix ] = ('ports' , ',' .join (remote_ports ))
485+ raise ApiValueError ("Invalid port number: %s" % port )
486+ if not (0 < port_number < 65536 ):
487+ raise ApiValueError ("Port number must be between 0 and 65536: %s" % port )
488+ if port_number in ports :
489+ raise ApiValueError ("Duplicate port numbers: %s" % port )
490+ ports .append (port_number )
486491 if not ports :
487492 raise ApiValueError ("Missing required parameter `ports`" )
488493
0 commit comments