Skip to content

Commit 54df036

Browse files
The return of 'add_data' and 'add_batch' is more clear.
1 parent 9dd5783 commit 54df036

File tree

4 files changed

+19
-16
lines changed

4 files changed

+19
-16
lines changed

docs/src/Core/Storage/database.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ Adding data to a Table
117117
----------------------
118118

119119
Adding data to a *Table* can be done either line by line either per batch of lines.
120-
In both cases, data must be passed as a dictionary:
120+
In both cases, data must be passed as a dictionary and the index of the created line(s) are returned:
121121

122122
.. code-block:: python
123123

docs/src/Core/Storage/relationships.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,12 +44,12 @@ Data can still be added either line by line either per batch of lines:
4444
new_array = np.random.uniform(size=(10,))
4545
4646
# Add a single line to Tables independently
47-
stats = db.add_data(table_name='Stats',
48-
data={'mean': new_array.mean(),
49-
'max': new_array.max()})
47+
id_stats_line = db.add_data(table_name='Stats',
48+
data={'mean': new_array.mean(),
49+
'max': new_array.max()})
5050
db.add_data(table_name='Arrays',
5151
data={'array': new_array,
52-
'stats': stats})
52+
'stats': id_stats_line})
5353
5454
# Add a single line to Tables using the main Table
5555
db.add_data(table_name='Arrays',
@@ -62,12 +62,12 @@ Data can still be added either line by line either per batch of lines:
6262
new_arrays = [np.random.uniform(size=(10,)) for _ in range(5)]
6363
6464
# Add a batch to Tables independently
65-
stats = [db.add_data(table_name='Stats',
66-
data={'mean': new_array.mean(),
67-
'max': new_array.max()}) for new_array in new_arrays]
65+
id_stats_lines = db.add_batch(table_name='Stats',
66+
data={'mean': [new_array.mean() for new_array in new_arrays],
67+
'max': [new_array.max() for new_array in new_arrays]})
6868
db.add_batch(table_name='Arrays',
6969
batch={'array': new_arrays,
70-
'stats': stats})
70+
'stats': id_stats_lines})
7171
7272
# Add a batch to Tables using th main Table
7373
db.add_batch(table_name='Arrays',

src/Core/Storage/AdaptiveTable.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def add_data(cls,
149149
for chunk in chunked(batch, 100):
150150
cls.insert_many(chunk, fields=fields).execute()
151151
N = cls.select().count()
152-
return [cls.get_by_id(i + 1) for i in range(n, N)]
152+
return [cls.get_by_id(i + 1).id for i in range(n, N)]
153153

154154

155155
class ExchangeTable(AdaptiveTable):
@@ -165,14 +165,17 @@ def add_data(cls,
165165
cls.delete().execute()
166166
line = cls(**dict(zip(fields_names, fields_values)))
167167
line.save()
168-
return line
168+
return line.id
169169

170170
else:
171171
fields = [getattr(cls, field) for field in fields_names]
172172
batch = [tuple(samples) for samples in zip(*fields_values)]
173173
cls.delete().execute()
174174
pre_save.send(cls, created=False)
175+
n = cls.select().count()
175176
with cls.database().atomic():
176177
for chunk in chunked(batch, 100):
177178
cls.insert_many(chunk, fields=fields).execute()
179+
N = cls.select().count()
178180
post_save.send(cls, created=False)
181+
return [cls.get_by_id(i + 1).id for i in range(n, N)]

src/Core/Storage/Database.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ def add_data(self,
337337
table_name: str,
338338
data: Dict[str, Any]):
339339
"""
340-
Execute a line insert query.
340+
Execute a line insert query. Return the index of the new line in the Table.
341341
342342
:param table_name: Name of the Table.
343343
:param data: New line of the Table.
@@ -351,7 +351,7 @@ def add_batch(self,
351351
table_name: str,
352352
batch: Dict[str, List[Any]]):
353353
"""
354-
Execute a batch insert query.
354+
Execute a batch insert query. Return the indices of the new lines in the Table.
355355
356356
:param table_name: Name of the Table.
357357
:param batch: New lines of the Table.
@@ -364,9 +364,9 @@ def add_batch(self,
364364
if len(unique(samples := [len(b) for b in batch_values])) != 1:
365365
raise ValueError(f"The number of samples per batch must be the same for all fields. Number of samples "
366366
f"received per field: {dict(zip(batch.keys(), samples))}")
367-
self.__add_data(table_name=table_name,
368-
data=batch,
369-
batched=True)
367+
return self.__add_data(table_name=table_name,
368+
data=batch,
369+
batched=True)
370370

371371
def __add_data(self,
372372
table_name: str,

0 commit comments

Comments
 (0)