@@ -32,12 +32,13 @@ class ClientServerTest(unittest.TestCase):
3232 def setUpClass (cls ):
3333 # TODO: free port is better
3434 port = 8765
35+ cls .server_url = "localhost:%d" % port
3536 cls .event = threading .Event ()
36- cls .tmp_ca_dir , ca_crt , ca_key = generateTempCA ()
37- threading .Thread (target = _server , args = [port , cls .event , ca_crt , ca_key ]).start ()
37+ cls .tmp_ca_dir , cls . ca_crt , ca_key = generateTempCA ()
38+ threading .Thread (target = _server , args = [port , cls .event , cls . ca_crt , ca_key ]).start ()
3839 # wait for start
3940 time .sleep (1 )
40- cls .client = Client ("localhost:%d" % port , ca_crt )
41+ cls .client = Client (cls . server_url , cls . ca_crt )
4142
4243 @classmethod
4344 def tearDownClass (cls ):
@@ -51,10 +52,14 @@ def test_execute_stream(self):
5152 log_mock .info .assert_called_with ("extended sql" )
5253
5354 expected_table = MockServicer .get_test_table ()
54- compound_msg = self .client .execute ("select * from galaxy" )
55- assert compound_msg .length () == 1
56- assert expected_table ["column_names" ] == compound_msg .get (0 ).column_names ()
57- assert expected_table ["rows" ] == [r for r in compound_msg .get (0 ).rows ()]
55+ rows = self .client .execute ("select * from galaxy" )
56+ assert expected_table ["column_names" ] == rows .column_names ()
57+ assert expected_table ["rows" ] == [r for r in rows .rows ()]
58+
59+ def test_cmd (self ):
60+ assert subprocess .call (["sqlflow" , "--url" , self .server_url ,
61+ "--ca_crt" , self .ca_crt ,
62+ "select * from galaxy" ]) == 0
5863
5964 def test_decode_time (self ):
6065 any_message = Any ()
0 commit comments