Skip to content

Commit 9dd5783

Browse files
Avoid Fields to have the same name as existing attributes in the Table.
1 parent f4c55dd commit 9dd5783

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

src/Core/Storage/Database.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Union, List, Type, Dict, Tuple, Optional, Any, Callable
22
from os import remove, mkdir
33
from os.path import exists, join, sep, getsize
4+
from inspect import getmembers
45
from playhouse.migrate import SqliteDatabase
56
from playhouse.signals import Signal, pre_save, post_save
67
from datetime import datetime
@@ -255,6 +256,11 @@ def __new_fields(self,
255256
field_name, field_type = field[0], field[1]
256257
field_default = '_null_' if len(field) == 2 else field[2]
257258

259+
# As peewee.Model creates a new attribute named field_name, check that this attribute does not exist
260+
if field_name in [m[0] for m in getmembers(table)]:
261+
raise ValueError(f"Tried to create a field '{field_name}' in the Table '{table_name}'. "
262+
f"You are not allowed to create a field with this name, please rename it.")
263+
258264
# Extend the Table
259265
if field_name not in table.fields():
260266
# FK
@@ -353,10 +359,11 @@ def add_batch(self,
353359

354360
table_name = self.make_name(table_name)
355361
# Check that the batch is well-formed
356-
batch_values = [batch[key] for key in set(batch.keys()) - set(self.__fk[table_name])]
357-
if len(unique(samples := [len(b) for b in batch_values])) != 1:
358-
raise ValueError(f"The number of samples per batch must be the same for all fields. Number of samples "
359-
f"received per field: {dict(zip(batch.keys(), samples))}")
362+
if table_name in self.__fk:
363+
batch_values = [batch[key] for key in set(batch.keys()) - set(self.__fk[table_name])]
364+
if len(unique(samples := [len(b) for b in batch_values])) != 1:
365+
raise ValueError(f"The number of samples per batch must be the same for all fields. Number of samples "
366+
f"received per field: {dict(zip(batch.keys(), samples))}")
360367
self.__add_data(table_name=table_name,
361368
data=batch,
362369
batched=True)
@@ -371,7 +378,7 @@ def __add_data(self,
371378
fields_values = list(data.values())
372379
fields_types = []
373380
for name, value in zip(fields_names, fields_values):
374-
if name in self.__fk[table_name]:
381+
if table_name in self.__fk and name in self.__fk[table_name]:
375382
fields_types.append(self.__fk[table_name][name])
376383
elif batched:
377384
fields_types.append(type(value[0]))
@@ -571,8 +578,8 @@ def get_lines(self,
571578
if lines_range is not None and len(lines_range) != 2:
572579
raise ValueError("The range of lines must contains the first and the last line indices.")
573580
nb_line = self.nb_lines(table_name=table_name)
574-
first_line_id = 1 if lines_range is None else lines_range[0]
575-
last_line_id = nb_line if lines_range is None else lines_range[1]
581+
first_line_id = lines_range[0] if lines_range is not None else 1
582+
last_line_id = lines_range[1] if lines_range is not None else nb_line
576583
_slice = [first_line_id, last_line_id]
577584
for i, idx in enumerate(_slice):
578585
if idx < 0:

0 commit comments

Comments
 (0)