@@ -185,7 +185,8 @@ async def log_labels(
185185 cache_functions ,
186186 logger
187187):
188- valid_data = {}
188+ unbatched_valid_data = pd .Series ()
189+ logged_ids = set ()
189190 labels_table_columns = model .get_sample_labels_columns ()
190191 labels_table_json_schema = {
191192 "type" : "object" ,
@@ -198,29 +199,51 @@ async def log_labels(
198199 }
199200
200201 validator = t .cast (t .Callable [..., t .Any ], fastjsonschema .compile (labels_table_json_schema ))
201-
202+ errors = []
202203 for sample in data :
203204 try :
204205 validator (sample )
205- except fastjsonschema .JsonSchemaValueException :
206- pass
207- # TODO: new table for model ingestion errors?
206+ except fastjsonschema .JsonSchemaValueException as e :
207+ errors .append ({
208+ "sample" : str (sample ),
209+ "sample_id" : sample .get (SAMPLE_ID_COL ),
210+ "error" : f"Exception saving label: { str (e )} , for id: { sample .get (SAMPLE_ID_COL )} " ,
211+ "model_id" : model .id ,
212+ })
208213 else :
209- error = None
210214 # If got same index more than once, log it as error
211- if sample [SAMPLE_ID_COL ] in valid_data :
212- error = f"Got duplicate sample id: { sample [SAMPLE_ID_COL ]} "
215+ if sample [SAMPLE_ID_COL ] in logged_ids :
216+ errors .append ({
217+ "sample" : str (sample ),
218+ "sample_id" : sample .get (SAMPLE_ID_COL ),
219+ "error" : f"Got duplicate label for sample id: { sample [SAMPLE_ID_COL ]} . "
220+ f"{ sample .get (SAMPLE_LABEL_COL )} vs "
221+ f"{ unbatched_valid_data [sample [SAMPLE_ID_COL ]].get (SAMPLE_LABEL_COL )} " ,
222+ "model_id" : model .id ,
223+ })
224+ else :
225+ unbatched_valid_data [sample [SAMPLE_ID_COL ]] = sample
226+ logged_ids .add (sample [SAMPLE_ID_COL ])
213227
214- if not error :
215- valid_data [sample [SAMPLE_ID_COL ]] = sample
228+ await save_failures (session , errors , logger )
216229
217- if valid_data :
230+ if len (unbatched_valid_data ) == 0 :
231+ return
232+ max_messages_per_insert = QUERY_PARAM_LIMIT // 5
233+ ids_to_log = unbatched_valid_data .keys ()
234+ for start_index in range (0 , len (ids_to_log ), max_messages_per_insert ):
235+ valid_data = unbatched_valid_data [ids_to_log [start_index :start_index + max_messages_per_insert ]]
218236 # Query from the ids mapping all the relevant versions per each version. This is needed in order to query
219237 # the timestamps to invalidate the monitors cache
220238 versions_table = model .get_samples_versions_map_table (session )
221- versions_select = (select (versions_table .c ["version_id" ], array_agg (versions_table .c [SAMPLE_ID_COL ]))
222- .where (versions_table .c [SAMPLE_ID_COL ].in_ (list (valid_data .keys ())))
223- .group_by (versions_table .c ["version_id" ]))
239+ versions_select = (
240+ select (
241+ versions_table .c ["version_id" ],
242+ array_agg (versions_table .c [SAMPLE_ID_COL ])
243+ )
244+ .where (versions_table .c [SAMPLE_ID_COL ].in_ (list (valid_data .keys ())))
245+ .group_by (versions_table .c ["version_id" ])
246+ )
224247 results = (await session .execute (versions_select )).all ()
225248
226249 # Validation of classes amount for binary tasks
@@ -245,39 +268,39 @@ async def log_labels(
245268 del valid_data [sample_id ]
246269 await save_failures (session , errors , logger )
247270
248- if valid_data :
249- # update label statistics
250- for row in results :
251- version_id = row [0 ]
252- sample_ids = [sample_id for sample_id in row [1 ] if sample_id in valid_data ]
253- model_version : ModelVersion = \
254- (await session .execute (select (ModelVersion ).where (ModelVersion .id == version_id ))).scalars ().first ()
255- updated_statistics = copy .deepcopy (model_version .statistics )
256- for sample_id in sample_ids :
257- update_statistics_from_sample (updated_statistics , valid_data [sample_id ])
258- if model_version .statistics != updated_statistics :
259- await model_version .update_statistics (updated_statistics , session )
260-
261- # Insert or update all labels
262- labels_table = model .get_sample_labels_table (session )
263- insert_statement = postgresql .insert (labels_table )
264- upsert_statement = insert_statement .on_conflict_do_update (
265- index_elements = [SAMPLE_ID_COL ],
266- set_ = {SAMPLE_LABEL_COL : insert_statement .excluded [SAMPLE_LABEL_COL ]}
267- )
268- await session .execute (upsert_statement , list ( valid_data .values () ))
269-
270- for row in results :
271- version_id = row [0 ]
272- sample_ids = [sample_id for sample_id in row [1 ] if sample_id in valid_data ]
273- monitor_table_name = get_monitor_table_name (model .id , version_id )
274- ts_select = (select (Column (SAMPLE_TS_COL ))
275- .select_from (text (monitor_table_name ))
276- .where (Column (SAMPLE_ID_COL ).in_ (sample_ids )))
277- timestamps_affected = [pdl .instance (x ) for x in (await session .execute (ts_select )).scalars ()]
278- await add_cache_invalidation (org_id , version_id , timestamps_affected , session , cache_functions )
279-
280- model .last_update_time = pdl .now ()
271+ if len ( valid_data ) > 0 :
272+ # update label statistics
273+ for row in results :
274+ version_id = row [0 ]
275+ sample_ids = [sample_id for sample_id in row [1 ] if sample_id in valid_data ]
276+ model_version : ModelVersion = \
277+ (await session .execute (select (ModelVersion ).where (ModelVersion .id == version_id ))).scalars ().first ()
278+ updated_statistics = copy .deepcopy (model_version .statistics )
279+ for sample_id in sample_ids :
280+ update_statistics_from_sample (updated_statistics , valid_data [sample_id ])
281+ if model_version .statistics != updated_statistics :
282+ await model_version .update_statistics (updated_statistics , session )
283+
284+ # Insert or update all labels
285+ labels_table = model .get_sample_labels_table (session )
286+ insert_statement = postgresql .insert (labels_table )
287+ upsert_statement = insert_statement .on_conflict_do_update (
288+ index_elements = [SAMPLE_ID_COL ],
289+ set_ = {SAMPLE_LABEL_COL : insert_statement .excluded [SAMPLE_LABEL_COL ]}
290+ )
291+ await session .execute (upsert_statement , valid_data .tolist ( ))
292+
293+ for row in results :
294+ version_id = row [0 ]
295+ sample_ids = [sample_id for sample_id in row [1 ] if sample_id in valid_data ]
296+ monitor_table_name = get_monitor_table_name (model .id , version_id )
297+ ts_select = (select (Column (SAMPLE_TS_COL ))
298+ .select_from (text (monitor_table_name ))
299+ .where (Column (SAMPLE_ID_COL ).in_ (sample_ids )))
300+ timestamps_affected = [pdl .instance (x ) for x in (await session .execute (ts_select )).scalars ()]
301+ await add_cache_invalidation (org_id , version_id , timestamps_affected , session , cache_functions )
302+
303+ model .last_update_time = pdl .now ()
281304
282305
283306async def add_cache_invalidation (organization_id , model_version_id , timestamps_updated , session , cache_functions ):
0 commit comments