Skip to content

Commit ef7df71

Browse files
Add support for command line (#88)
* Add support for command line * clean up
1 parent bdeab01 commit ef7df71

File tree

3 files changed

+17
-10
lines changed

3 files changed

+17
-10
lines changed

sqlflow/__main__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,5 @@ def main():
1313
client = Client(server_url=args.url, ca_crt=args.ca_crt)
1414
for sql in args.sql:
1515
print("executing: {}".format(sql))
16-
for res in client.execute(sql):
17-
print(res)
16+
print(client.execute(sql))
17+

sqlflow/client.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def rows_gen():
146146
rows = Rows(column_names, rows_gen)
147147
_LOGGER.info(rows)
148148
compound_message.add_rows(rows, None)
149-
return compound_message
149+
if compound_message.length() == 1:
150+
return compound_message.get(0)
151+
return [compound_message.get(i) for i in range(compound_message.length())]
150152

151153
@classmethod
152154
def _decode_any(cls, any_message):

tests/test_client.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,13 @@ class ClientServerTest(unittest.TestCase):
3232
def setUpClass(cls):
3333
# TODO: free port is better
3434
port = 8765
35+
cls.server_url = "localhost:%d" % port
3536
cls.event = threading.Event()
36-
cls.tmp_ca_dir, ca_crt, ca_key = generateTempCA()
37-
threading.Thread(target=_server, args=[port, cls.event, ca_crt, ca_key]).start()
37+
cls.tmp_ca_dir, cls.ca_crt, ca_key = generateTempCA()
38+
threading.Thread(target=_server, args=[port, cls.event, cls.ca_crt, ca_key]).start()
3839
# wait for start
3940
time.sleep(1)
40-
cls.client = Client("localhost:%d" % port, ca_crt)
41+
cls.client = Client(cls.server_url, cls.ca_crt)
4142

4243
@classmethod
4344
def tearDownClass(cls):
@@ -51,10 +52,14 @@ def test_execute_stream(self):
5152
log_mock.info.assert_called_with("extended sql")
5253

5354
expected_table = MockServicer.get_test_table()
54-
compound_msg = self.client.execute("select * from galaxy")
55-
assert compound_msg.length() == 1
56-
assert expected_table["column_names"] == compound_msg.get(0).column_names()
57-
assert expected_table["rows"] == [r for r in compound_msg.get(0).rows()]
55+
rows = self.client.execute("select * from galaxy")
56+
assert expected_table["column_names"] == rows.column_names()
57+
assert expected_table["rows"] == [r for r in rows.rows()]
58+
59+
def test_cmd(self):
60+
assert subprocess.call(["sqlflow", "--url", self.server_url,
61+
"--ca_crt", self.ca_crt,
62+
"select * from galaxy"]) == 0
5863

5964
def test_decode_time(self):
6065
any_message = Any()

0 commit comments

Comments
 (0)