11from typing import Union , List , Type , Dict , Tuple , Optional , Any , Callable
22from os import remove , mkdir
33from os .path import exists , join , sep , getsize
4+ from inspect import getmembers
45from playhouse .migrate import SqliteDatabase
56from playhouse .signals import Signal , pre_save , post_save
67from datetime import datetime
@@ -255,6 +256,11 @@ def __new_fields(self,
255256 field_name , field_type = field [0 ], field [1 ]
256257 field_default = '_null_' if len (field ) == 2 else field [2 ]
257258
259+ # As peewee.Model creates a new attribute named field_name, check that this attribute does not exist
260+ if field_name in [m [0 ] for m in getmembers (table )]:
261+ raise ValueError (f"Tried to create a field '{ field_name } ' in the Table '{ table_name } '. "
262+ f"You are not allowed to create a field with this name, please rename it." )
263+
258264 # Extend the Table
259265 if field_name not in table .fields ():
260266 # FK
@@ -353,10 +359,11 @@ def add_batch(self,
353359
354360 table_name = self .make_name (table_name )
355361 # Check that the batch is well-formed
356- batch_values = [batch [key ] for key in set (batch .keys ()) - set (self .__fk [table_name ])]
357- if len (unique (samples := [len (b ) for b in batch_values ])) != 1 :
358- raise ValueError (f"The number of samples per batch must be the same for all fields. Number of samples "
359- f"received per field: { dict (zip (batch .keys (), samples ))} " )
362+ if table_name in self .__fk :
363+ batch_values = [batch [key ] for key in set (batch .keys ()) - set (self .__fk [table_name ])]
364+ if len (unique (samples := [len (b ) for b in batch_values ])) != 1 :
365+ raise ValueError (f"The number of samples per batch must be the same for all fields. Number of samples "
366+ f"received per field: { dict (zip (batch .keys (), samples ))} " )
360367 self .__add_data (table_name = table_name ,
361368 data = batch ,
362369 batched = True )
@@ -371,7 +378,7 @@ def __add_data(self,
371378 fields_values = list (data .values ())
372379 fields_types = []
373380 for name , value in zip (fields_names , fields_values ):
374- if name in self .__fk [table_name ]:
381+ if table_name in self . __fk and name in self .__fk [table_name ]:
375382 fields_types .append (self .__fk [table_name ][name ])
376383 elif batched :
377384 fields_types .append (type (value [0 ]))
@@ -571,8 +578,8 @@ def get_lines(self,
571578 if lines_range is not None and len (lines_range ) != 2 :
572579 raise ValueError ("The range of lines must contains the first and the last line indices." )
573580 nb_line = self .nb_lines (table_name = table_name )
574- first_line_id = 1 if lines_range is None else lines_range [ 0 ]
575- last_line_id = nb_line if lines_range is None else lines_range [ 1 ]
581+ first_line_id = lines_range [ 0 ] if lines_range is not None else 1
582+ last_line_id = lines_range [ 1 ] if lines_range is not None else nb_line
576583 _slice = [first_line_id , last_line_id ]
577584 for i , idx in enumerate (_slice ):
578585 if idx < 0 :
0 commit comments