Skip to content

Commit 78ea02e

Browse files
authored
support run multiple sql statements (#83)
* support run multiple sql statements * support getting compound return messages
1 parent 7474dc8 commit 78ea02e

File tree

5 files changed

+143
-79
lines changed

5 files changed

+143
-79
lines changed

proto/sqlflow/proto/sqlflow.proto

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ message Response {
4141
Head head = 1;
4242
Row row = 2;
4343
Message message = 3;
44+
EndOfExecution eoe = 4;
4445
}
4546
}
4647

@@ -66,3 +67,11 @@ message Row {
6667
message Message {
6768
string message = 1;
6869
}
70+
71+
// SQLFlow server may execute multiple SQL statements in one RPC call.
72+
// EndOfExecution message tells the client that execution of one SQL is
73+
// finished, the client should go to next loop to parse the result stream.
74+
message EndOfExecution {
75+
string sql = 1;
76+
int64 spent_time_seconds = 2;
77+
}

sqlflow/client.py

Lines changed: 43 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
import sqlflow.proto.sqlflow_pb2_grpc as pb_grpc
1212

1313
from sqlflow.env_expand import EnvExpander, EnvExpanderError
14+
from sqlflow.rows import Rows
15+
from sqlflow.compound_message import CompoundMessage
1416

1517
_LOGGER = logging.getLogger(__name__)
1618
handler = logging.StreamHandler(sys.stdout)
@@ -21,62 +23,6 @@
2123
DEFAULT_TIMEOUT=3600 * 10
2224

2325

24-
class Rows:
25-
def __init__(self, column_names, rows_gen):
26-
"""Query result of sqlflow.client.Client.execute
27-
28-
:param column_names: column names
29-
:type column_names: list[str].
30-
:param rows_gen: rows generator
31-
:type rows_gen: generator
32-
"""
33-
self._column_names = column_names
34-
self._rows_gen = rows_gen
35-
self._rows = None
36-
37-
def column_names(self):
38-
"""Column names
39-
40-
:return: list[str]
41-
"""
42-
return self._column_names
43-
44-
def rows(self):
45-
"""Rows
46-
47-
Example:
48-
49-
>>> [r for r in rows.rows()]
50-
51-
:return: list generator
52-
"""
53-
if self._rows is None:
54-
self._rows = []
55-
for row in self._rows_gen():
56-
self._rows.append(row)
57-
yield row
58-
else:
59-
for row in self._rows:
60-
yield row
61-
62-
def __str__(self):
63-
return self.__repr__()
64-
65-
def __repr__(self):
66-
from prettytable import PrettyTable
67-
table = PrettyTable(self._column_names)
68-
for row in self.rows():
69-
table.add_row(row)
70-
return table.__str__()
71-
72-
def to_dataframe(self):
73-
"""Convert Rows to pandas.Dataframe
74-
75-
:return: pandas.Dataframe
76-
"""
77-
raise NotImplementedError
78-
79-
8026
class Client:
8127
def __init__(self, server_url=None, ca_crt=None):
8228
"""A minimum client that issues queries to and fetch results/logs from sqlflowserver.
@@ -159,27 +105,48 @@ def execute(self, operation):
159105
@classmethod
160106
def display(cls, stream_response):
161107
"""Display stream response like log or table.row"""
162-
first = next(stream_response)
163-
if first.WhichOneof('response') == 'message':
164-
# if the first line is html tag like,
165-
# merge all return strings then render the html on notebook
166-
if re.match(r'<[a-z][\s\S]*>.*', first.message.message):
167-
resp_list = [first.message.message]
168-
for res in stream_response:
169-
resp_list.append(res.message.message)
170-
from IPython.core.display import display, HTML
171-
display(HTML('\n'.join(resp_list)))
108+
compound_message = CompoundMessage()
109+
while True:
110+
try:
111+
first = next(stream_response)
112+
except StopIteration:
113+
break
114+
if first.WhichOneof('response') == 'message':
115+
# if the first line is html tag like,
116+
# merge all return strings then render the html on notebook
117+
if re.match(r'<[a-z][\s\S]*>.*', first.message.message):
118+
resp_list = [first.message.message]
119+
for res in stream_response:
120+
if res.WhichOneof('response') == 'eoe':
121+
_LOGGER.info("end execute %s, spent: %d" % (res.eoe.sql, res.eoe.spent_time_seconds))
122+
compound_message.add_html('\n'.join(resp_list), res)
123+
break
124+
resp_list.append(res.message.message)
125+
from IPython.core.display import display, HTML
126+
display(HTML('\n'.join(resp_list)))
127+
else:
128+
_LOGGER.info(first.message.message)
129+
all_messages = []
130+
all_messages.append(first.message.message)
131+
for res in stream_response:
132+
if res.WhichOneof('response') == 'eoe':
133+
_LOGGER.info("end execute %s, spent: %d" % (res.eoe.sql, res.eoe.spent_time_seconds))
134+
compound_message.add_message('\n'.join(all_messages), res)
135+
break
136+
_LOGGER.info(res.message.message)
137+
all_messages.append(res.message.message)
172138
else:
173-
_LOGGER.info(first.message.message)
174-
for res in stream_response:
175-
_LOGGER.info(res.message.message)
176-
else:
177-
column_names = [column_name for column_name in first.head.column_names]
178-
179-
def rows_gen():
180-
for res in stream_response:
181-
yield [cls._decode_any(a) for a in res.row.data]
182-
return Rows(column_names, rows_gen)
139+
column_names = [column_name for column_name in first.head.column_names]
140+
def rows_gen():
141+
for res in stream_response:
142+
if res.WhichOneof('response') == 'eoe':
143+
_LOGGER.info("end execute %s, spent: %d" % (res.eoe.sql, res.eoe.spent_time_seconds))
144+
break
145+
yield [cls._decode_any(a) for a in res.row.data]
146+
rows = Rows(column_names, rows_gen)
147+
_LOGGER.info(rows)
148+
compound_message.add_rows(rows, None)
149+
return compound_message
183150

184151
@classmethod
185152
def _decode_any(cls, any_message):

sqlflow/compound_message.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from sqlflow.rows import Rows
2+
3+
class CompoundMessage:
4+
def __init__(self):
5+
"""Message containing return result of several SQL statements
6+
CompoundMessage can not display in notebook since we need to
7+
output log messages for long running training sqls.
8+
"""
9+
self._messages = []
10+
self.TypeRows = 1
11+
self.TypeMessage = 2
12+
self.TypeHTML = 3
13+
14+
def add_rows(self, rows, eoe):
15+
assert(isinstance(rows, Rows))
16+
self._messages.append((rows, eoe, self.TypeRows))
17+
18+
def add_message(self, message, eoe):
19+
assert(isinstance(message, str))
20+
self._messages.append((message, eoe, self.TypeMessage))
21+
22+
def add_html(self, message, eoe):
23+
assert(isinstance(message, str))
24+
self._messages.append((message, eoe, self.TypeHTML))
25+
26+
def length(self):
27+
return len(self._messages)
28+
29+
def get(self, idx):
30+
return self._messages[idx][0]
31+

sqlflow/rows.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
2+
3+
class Rows:
4+
def __init__(self, column_names, rows_gen):
5+
"""Query result of sqlflow.client.Client.execute
6+
7+
:param column_names: column names
8+
:type column_names: list[str].
9+
:param rows_gen: rows generator
10+
:type rows_gen: generator
11+
"""
12+
self._column_names = column_names
13+
self._rows_gen = rows_gen
14+
self._rows = None
15+
16+
def column_names(self):
17+
"""Column names
18+
19+
:return: list[str]
20+
"""
21+
return self._column_names
22+
23+
def rows(self):
24+
"""Rows
25+
26+
Example:
27+
28+
>>> [r for r in rows.rows()]
29+
30+
:return: list generator
31+
"""
32+
if self._rows is None:
33+
self._rows = []
34+
for row in self._rows_gen():
35+
self._rows.append(row)
36+
yield row
37+
else:
38+
for row in self._rows:
39+
yield row
40+
41+
def __str__(self):
42+
return self.__repr__()
43+
44+
def __repr__(self):
45+
from prettytable import PrettyTable
46+
table = PrettyTable(self._column_names)
47+
for row in self.rows():
48+
table.add_row(row)
49+
return table.__str__()
50+
51+
def to_dataframe(self):
52+
"""Convert Rows to pandas.Dataframe
53+
54+
:return: pandas.Dataframe
55+
"""
56+
raise NotImplementedError

tests/test_client.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ def test_execute_stream(self):
5151
log_mock.info.assert_called_with("extended sql")
5252

5353
expected_table = MockServicer.get_test_table()
54-
rows = self.client.execute("select * from galaxy")
55-
assert expected_table["column_names"] == rows.column_names()
56-
assert expected_table["rows"] == [r for r in rows.rows()]
54+
compound_msg = self.client.execute("select * from galaxy")
55+
assert compound_msg.length() == 1
56+
assert expected_table["column_names"] == compound_msg.get(0).column_names()
57+
assert expected_table["rows"] == [r for r in compound_msg.get(0).rows()]
5758

5859
def test_decode_time(self):
5960
any_message = Any()

0 commit comments

Comments
 (0)