2424import yaml
2525
2626from six .moves .urllib .parse import urlencode , quote_plus , urlparse , urlunparse
27+ from six import StringIO
2728
2829from websocket import WebSocket , ABNF , enableTrace
2930
3334ERROR_CHANNEL = 3
3435RESIZE_CHANNEL = 4
3536
37+ class _IgnoredIO :
38+ def write (self , _x ):
39+ pass
40+
41+ def getvalue (self ):
42+ raise TypeError ("Tried to read_all() from a WSClient configured to not capture. Did you mean `capture_all=True`?" )
43+
3644
3745class WSClient :
38- def __init__ (self , configuration , url , headers ):
46+ def __init__ (self , configuration , url , headers , capture_all ):
3947 """A websocket client with support for channels.
4048
4149 Exec command uses different channels for different streams. for
@@ -47,7 +55,10 @@ def __init__(self, configuration, url, headers):
4755 header = []
4856 self ._connected = False
4957 self ._channels = {}
50- self ._all = ""
58+ if capture_all :
59+ self ._all = StringIO ()
60+ else :
61+ self ._all = _IgnoredIO ()
5162
5263 # We just need to pass the Authorization, ignore all the other
5364 # http headers we get from the generated code
@@ -157,8 +168,8 @@ def read_all(self):
157168 TODO: Maybe we can process this and return a more meaningful map with
158169 channels mapped for each input.
159170 """
160- out = self ._all
161- self ._all = ""
171+ out = self ._all . getvalue ()
172+ self ._all = self . _all . __class__ ()
162173 self ._channels = {}
163174 return out
164175
@@ -195,7 +206,7 @@ def update(self, timeout=0):
195206 if channel in [STDOUT_CHANNEL , STDERR_CHANNEL ]:
196207 # keeping all messages in the order they received
197208 # for non-blocking call.
198- self ._all += data
209+ self ._all . write ( data )
199210 if channel not in self ._channels :
200211 self ._channels [channel ] = data
201212 else :
@@ -257,6 +268,7 @@ def websocket_call(configuration, *args, **kwargs):
257268 url = args [1 ]
258269 _request_timeout = kwargs .get ("_request_timeout" , 60 )
259270 _preload_content = kwargs .get ("_preload_content" , True )
271+ capture_all = kwargs .get ("capture_all" , True )
260272 headers = kwargs .get ("headers" )
261273
262274 # Expand command parameter list to indivitual command params
@@ -272,7 +284,7 @@ def websocket_call(configuration, *args, **kwargs):
272284 url += '?' + urlencode (query_params )
273285
274286 try :
275- client = WSClient (configuration , get_websocket_url (url ), headers )
287+ client = WSClient (configuration , get_websocket_url (url ), headers , capture_all )
276288 if not _preload_content :
277289 return client
278290 client .run_forever (timeout = _request_timeout )
0 commit comments