|
11 | 11 | import sqlflow.proto.sqlflow_pb2_grpc as pb_grpc |
12 | 12 |
|
13 | 13 | from sqlflow.env_expand import EnvExpander, EnvExpanderError |
| 14 | +from sqlflow.rows import Rows |
| 15 | +from sqlflow.compound_message import CompoundMessage |
14 | 16 |
|
15 | 17 | _LOGGER = logging.getLogger(__name__) |
16 | 18 | handler = logging.StreamHandler(sys.stdout) |
|
21 | 23 | DEFAULT_TIMEOUT=3600 * 10 |
22 | 24 |
|
23 | 25 |
|
24 | | -class Rows: |
25 | | - def __init__(self, column_names, rows_gen): |
26 | | - """Query result of sqlflow.client.Client.execute |
27 | | -
|
28 | | - :param column_names: column names |
29 | | - :type column_names: list[str]. |
30 | | - :param rows_gen: rows generator |
31 | | - :type rows_gen: generator |
32 | | - """ |
33 | | - self._column_names = column_names |
34 | | - self._rows_gen = rows_gen |
35 | | - self._rows = None |
36 | | - |
37 | | - def column_names(self): |
38 | | - """Column names |
39 | | -
|
40 | | - :return: list[str] |
41 | | - """ |
42 | | - return self._column_names |
43 | | - |
44 | | - def rows(self): |
45 | | - """Rows |
46 | | -
|
47 | | - Example: |
48 | | -
|
49 | | - >>> [r for r in rows.rows()] |
50 | | -
|
51 | | - :return: list generator |
52 | | - """ |
53 | | - if self._rows is None: |
54 | | - self._rows = [] |
55 | | - for row in self._rows_gen(): |
56 | | - self._rows.append(row) |
57 | | - yield row |
58 | | - else: |
59 | | - for row in self._rows: |
60 | | - yield row |
61 | | - |
62 | | - def __str__(self): |
63 | | - return self.__repr__() |
64 | | - |
65 | | - def __repr__(self): |
66 | | - from prettytable import PrettyTable |
67 | | - table = PrettyTable(self._column_names) |
68 | | - for row in self.rows(): |
69 | | - table.add_row(row) |
70 | | - return table.__str__() |
71 | | - |
72 | | - def to_dataframe(self): |
73 | | - """Convert Rows to pandas.Dataframe |
74 | | -
|
75 | | - :return: pandas.Dataframe |
76 | | - """ |
77 | | - raise NotImplementedError |
78 | | - |
79 | | - |
80 | 26 | class Client: |
81 | 27 | def __init__(self, server_url=None, ca_crt=None): |
82 | 28 | """A minimum client that issues queries to and fetch results/logs from sqlflowserver. |
@@ -159,27 +105,48 @@ def execute(self, operation): |
159 | 105 | @classmethod |
160 | 106 | def display(cls, stream_response): |
161 | 107 | """Display stream response like log or table.row""" |
162 | | - first = next(stream_response) |
163 | | - if first.WhichOneof('response') == 'message': |
164 | | - # if the first line is html tag like, |
165 | | - # merge all return strings then render the html on notebook |
166 | | - if re.match(r'<[a-z][\s\S]*>.*', first.message.message): |
167 | | - resp_list = [first.message.message] |
168 | | - for res in stream_response: |
169 | | - resp_list.append(res.message.message) |
170 | | - from IPython.core.display import display, HTML |
171 | | - display(HTML('\n'.join(resp_list))) |
| 108 | + compound_message = CompoundMessage() |
| 109 | + while True: |
| 110 | + try: |
| 111 | + first = next(stream_response) |
| 112 | + except StopIteration: |
| 113 | + break |
| 114 | + if first.WhichOneof('response') == 'message': |
| 115 | + # if the first line is html tag like, |
| 116 | + # merge all return strings then render the html on notebook |
| 117 | + if re.match(r'<[a-z][\s\S]*>.*', first.message.message): |
| 118 | + resp_list = [first.message.message] |
| 119 | + for res in stream_response: |
| 120 | + if res.WhichOneof('response') == 'eoe': |
| 121 | + _LOGGER.info("end execute %s, spent: %d" % (res.eoe.sql, res.eoe.spent_time_seconds)) |
| 122 | + compound_message.add_html('\n'.join(resp_list), res) |
| 123 | + break |
| 124 | + resp_list.append(res.message.message) |
| 125 | + from IPython.core.display import display, HTML |
| 126 | + display(HTML('\n'.join(resp_list))) |
| 127 | + else: |
| 128 | + _LOGGER.info(first.message.message) |
| 129 | + all_messages = [] |
| 130 | + all_messages.append(first.message.message) |
| 131 | + for res in stream_response: |
| 132 | + if res.WhichOneof('response') == 'eoe': |
| 133 | + _LOGGER.info("end execute %s, spent: %d" % (res.eoe.sql, res.eoe.spent_time_seconds)) |
| 134 | + compound_message.add_message('\n'.join(all_messages), res) |
| 135 | + break |
| 136 | + _LOGGER.info(res.message.message) |
| 137 | + all_messages.append(res.message.message) |
172 | 138 | else: |
173 | | - _LOGGER.info(first.message.message) |
174 | | - for res in stream_response: |
175 | | - _LOGGER.info(res.message.message) |
176 | | - else: |
177 | | - column_names = [column_name for column_name in first.head.column_names] |
178 | | - |
179 | | - def rows_gen(): |
180 | | - for res in stream_response: |
181 | | - yield [cls._decode_any(a) for a in res.row.data] |
182 | | - return Rows(column_names, rows_gen) |
| 139 | + column_names = [column_name for column_name in first.head.column_names] |
| 140 | + def rows_gen(): |
| 141 | + for res in stream_response: |
| 142 | + if res.WhichOneof('response') == 'eoe': |
| 143 | + _LOGGER.info("end execute %s, spent: %d" % (res.eoe.sql, res.eoe.spent_time_seconds)) |
| 144 | + break |
| 145 | + yield [cls._decode_any(a) for a in res.row.data] |
| 146 | + rows = Rows(column_names, rows_gen) |
| 147 | + _LOGGER.info(rows) |
| 148 | + compound_message.add_rows(rows, None) |
| 149 | + return compound_message |
183 | 150 |
|
184 | 151 | @classmethod |
185 | 152 | def _decode_any(cls, any_message): |
|
0 commit comments