Commit 6cdb3ac

Fix label s3 ingestion (#309)
1 parent 392f1cb commit 6cdb3ac

4 files changed: +88 -64 lines changed

backend/deepchecks_monitoring/bgtasks/tasks_runner.py

Lines changed: 1 addition & 1 deletion

@@ -133,8 +133,8 @@ async def _run_task(self, task: Task, session, queued_timestamp, lock):
             else:
                 self.logger.info({'message': f'Unknown task type: {task.bg_worker_task}'})
         except Exception:  # pylint: disable=broad-except
-            await session.rollback()
             self.logger.exception({'message': 'Exception running task', 'task': task.bg_worker_task})
+            await session.rollback()


 class BaseWorkerSettings():
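
The only functional change in this file is the ordering inside the except block: the exception is now logged before the session rollback, presumably so the failure is still recorded even if the rollback itself raises. A minimal sketch of the resulting pattern, using hypothetical names in place of the real worker objects:

    async def run_task_safely(task, session, logger):
        """Log first, roll back second: a failing rollback can no longer hide the original error."""
        try:
            await task.run(session)
        except Exception:  # pylint: disable=broad-except
            logger.exception({'message': 'Exception running task'})
            await session.rollback()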

backend/deepchecks_monitoring/ee/bgtasks/object_storage_ingestor.py

Lines changed: 16 additions & 15 deletions

@@ -146,9 +146,9 @@ async def run(self, task: 'Task', session: AsyncSession, resources_provider: Res
             version_prefixes = model_prefixes if version.latest_file_time is not None else ['']
             for prefix in version_prefixes:
                 for df, time in self.ingest_prefix(s3, bucket, f'{version_path}/{prefix}', version.latest_file_time,
-                                                   errors, version.model_id, version.id):
-                    # For each file, set lock expiry to 240 seconds from now
-                    await lock.extend(240, replace_ttl=True)
+                                                   errors, version.model_id, version.id, need_ts=True):
+                    # For each file, set lock expiry to 360 seconds from now
+                    await lock.extend(360, replace_ttl=True)
                     await self.ingestion_backend.log_samples(version, df, session, organization_id, new_scan_time)
                     version.latest_file_time = max(version.latest_file_time or
                                                    pdl.datetime(year=1970, month=1, day=1), time)
@@ -158,8 +158,8 @@ async def run(self, task: 'Task', session: AsyncSession, resources_provider: Res
                 labels_path = f'{model_path}/labels/{prefix}'
                 for df, time in self.ingest_prefix(s3, bucket, labels_path, model.latest_labels_file_time,
                                                    errors, model_id):
-                    # For each file, set lock expiry to 240 seconds from now
-                    await lock.extend(240, replace_ttl=True)
+                    # For each file, set lock expiry to 360 seconds from now
+                    await lock.extend(360, replace_ttl=True)
                     await self.ingestion_backend.log_labels(model, df, session, organization_id)
                     model.latest_labels_file_time = max(model.latest_labels_file_time
                                                         or pdl.datetime(year=1970, month=1, day=1), time)
@@ -175,7 +175,8 @@ async def run(self, task: 'Task', session: AsyncSession, resources_provider: Res
         self.logger.info({'message': 'finished job', 'worker name': str(type(self)),
                           'task': task.id, 'model_id': model_id, 'org_id': organization_id})

-    def ingest_prefix(self, s3, bucket, prefix, last_file_time, errors, model_id, version_id=None):
+    def ingest_prefix(self, s3, bucket, prefix, last_file_time, errors,
+                      model_id, version_id=None, need_ts: bool = False):
         """Ingest all files in prefix, return df and file time"""
         last_file_time = last_file_time or pdl.datetime(year=1970, month=1, day=1)
         # First read all file names, then retrieve them sorted by date
@@ -226,15 +227,15 @@ def ingest_prefix(self, s3, bucket, prefix, last_file_time, errors, model_id, ve
                 self._handle_error(errors, f'Invalid file extension: {file["extension"]}, for file: {file["key"]}',
                                    model_id, version_id)
                 continue
-
-            if SAMPLE_TS_COL not in df or not is_integer_dtype(df[SAMPLE_TS_COL]):
-                self._handle_error(errors, f'Invalid timestamp column: {SAMPLE_TS_COL}, in file: {file["key"]}',
-                                   model_id, version_id)
-                continue
-            # The user facing API requires unix timestamps, but for the ingestion we convert it to ISO format
-            df[SAMPLE_TS_COL] = df[SAMPLE_TS_COL].apply(lambda x: pdl.from_timestamp(x).isoformat())
-            # Sort by timestamp
-            df = df.sort_values(by=[SAMPLE_TS_COL])
+            if need_ts:
+                if SAMPLE_TS_COL not in df or not is_integer_dtype(df[SAMPLE_TS_COL]):
+                    self._handle_error(errors, f'Invalid timestamp column: {SAMPLE_TS_COL}, in file: {file["key"]}',
+                                       model_id, version_id)
+                    continue
+                # The user facing API requires unix timestamps, but for the ingestion we convert it to ISO format
+                df[SAMPLE_TS_COL] = df[SAMPLE_TS_COL].apply(lambda x: pdl.from_timestamp(x).isoformat())
+                # Sort by timestamp
+                df = df.sort_values(by=[SAMPLE_TS_COL])
             yield df, file['time']

     def _handle_error(self, errors, error_message, model_id=None, model_version_id=None, set_warning_in_logs=True):
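
These hunks look like the core of the label-ingestion fix: label files presumably carry no timestamp column, so the previously unconditional SAMPLE_TS_COL check rejected them, while prediction files still need it. The new need_ts flag is passed as True only for the prediction-data prefixes and left at its default False for the labels prefix; the per-file lock TTL is also raised from 240 to 360 seconds. A minimal sketch of the gating logic under those assumptions (simplified signature and a hypothetical file iterator, not the real ingestor API):

    import pendulum as pdl
    from pandas.api.types import is_integer_dtype

    SAMPLE_TS_COL = '_dc_time'  # assumed value of the constant used in the diff

    def ingest_files(files, need_ts: bool = False):
        """Yield (df, file_time); only prediction data (need_ts=True) gets timestamp handling."""
        for df, file_time in files:
            if need_ts:
                # prediction files must carry an integer unix-timestamp column
                if SAMPLE_TS_COL not in df or not is_integer_dtype(df[SAMPLE_TS_COL]):
                    continue  # report-and-skip rather than failing the whole scan
                # convert unix timestamps to ISO format and sort, as the real code does
                df[SAMPLE_TS_COL] = df[SAMPLE_TS_COL].apply(lambda x: pdl.from_timestamp(x).isoformat())
                df = df.sort_values(by=[SAMPLE_TS_COL])
            yield df, file_time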

backend/deepchecks_monitoring/logic/data_ingestion.py

Lines changed: 70 additions & 47 deletions

@@ -185,7 +185,8 @@ async def log_labels(
         cache_functions,
         logger
 ):
-    valid_data = {}
+    unbatched_valid_data = pd.Series()
+    logged_ids = set()
     labels_table_columns = model.get_sample_labels_columns()
     labels_table_json_schema = {
         "type": "object",
@@ -198,29 +199,51 @@ async def log_labels(
     }

     validator = t.cast(t.Callable[..., t.Any], fastjsonschema.compile(labels_table_json_schema))
-
+    errors = []
     for sample in data:
         try:
             validator(sample)
-        except fastjsonschema.JsonSchemaValueException:
-            pass
-            # TODO: new table for model ingestion errors?
+        except fastjsonschema.JsonSchemaValueException as e:
+            errors.append({
+                "sample": str(sample),
+                "sample_id": sample.get(SAMPLE_ID_COL),
+                "error": f"Exception saving label: {str(e)}, for id: {sample.get(SAMPLE_ID_COL)}",
+                "model_id": model.id,
+            })
         else:
-            error = None
             # If got same index more than once, log it as error
-            if sample[SAMPLE_ID_COL] in valid_data:
-                error = f"Got duplicate sample id: {sample[SAMPLE_ID_COL]}"
+            if sample[SAMPLE_ID_COL] in logged_ids:
+                errors.append({
+                    "sample": str(sample),
+                    "sample_id": sample.get(SAMPLE_ID_COL),
+                    "error": f"Got duplicate label for sample id: {sample[SAMPLE_ID_COL]}. "
+                             f"{sample.get(SAMPLE_LABEL_COL)} vs "
+                             f"{unbatched_valid_data[sample[SAMPLE_ID_COL]].get(SAMPLE_LABEL_COL)}",
+                    "model_id": model.id,
+                })
+            else:
+                unbatched_valid_data[sample[SAMPLE_ID_COL]] = sample
+                logged_ids.add(sample[SAMPLE_ID_COL])

-            if not error:
-                valid_data[sample[SAMPLE_ID_COL]] = sample
+    await save_failures(session, errors, logger)

-    if valid_data:
+    if len(unbatched_valid_data) == 0:
+        return
+    max_messages_per_insert = QUERY_PARAM_LIMIT // 5
+    ids_to_log = unbatched_valid_data.keys()
+    for start_index in range(0, len(ids_to_log), max_messages_per_insert):
+        valid_data = unbatched_valid_data[ids_to_log[start_index:start_index + max_messages_per_insert]]
         # Query from the ids mapping all the relevant versions per each version. This is needed in order to query
         # the timestamps to invalidate the monitors cache
         versions_table = model.get_samples_versions_map_table(session)
-        versions_select = (select(versions_table.c["version_id"], array_agg(versions_table.c[SAMPLE_ID_COL]))
-                           .where(versions_table.c[SAMPLE_ID_COL].in_(list(valid_data.keys())))
-                           .group_by(versions_table.c["version_id"]))
+        versions_select = (
+            select(
+                versions_table.c["version_id"],
+                array_agg(versions_table.c[SAMPLE_ID_COL])
+            )
+            .where(versions_table.c[SAMPLE_ID_COL].in_(list(valid_data.keys())))
+            .group_by(versions_table.c["version_id"])
+        )
         results = (await session.execute(versions_select)).all()

         # Validation of classes amount for binary tasks
@@ -245,39 +268,39 @@ async def log_labels(
                 del valid_data[sample_id]
         await save_failures(session, errors, logger)

-    if valid_data:
-        # update label statistics
-        for row in results:
-            version_id = row[0]
-            sample_ids = [sample_id for sample_id in row[1] if sample_id in valid_data]
-            model_version: ModelVersion = \
-                (await session.execute(select(ModelVersion).where(ModelVersion.id == version_id))).scalars().first()
-            updated_statistics = copy.deepcopy(model_version.statistics)
-            for sample_id in sample_ids:
-                update_statistics_from_sample(updated_statistics, valid_data[sample_id])
-            if model_version.statistics != updated_statistics:
-                await model_version.update_statistics(updated_statistics, session)
-
-        # Insert or update all labels
-        labels_table = model.get_sample_labels_table(session)
-        insert_statement = postgresql.insert(labels_table)
-        upsert_statement = insert_statement.on_conflict_do_update(
-            index_elements=[SAMPLE_ID_COL],
-            set_={SAMPLE_LABEL_COL: insert_statement.excluded[SAMPLE_LABEL_COL]}
-        )
-        await session.execute(upsert_statement, list(valid_data.values()))
-
-        for row in results:
-            version_id = row[0]
-            sample_ids = [sample_id for sample_id in row[1] if sample_id in valid_data]
-            monitor_table_name = get_monitor_table_name(model.id, version_id)
-            ts_select = (select(Column(SAMPLE_TS_COL))
-                         .select_from(text(monitor_table_name))
-                         .where(Column(SAMPLE_ID_COL).in_(sample_ids)))
-            timestamps_affected = [pdl.instance(x) for x in (await session.execute(ts_select)).scalars()]
-            await add_cache_invalidation(org_id, version_id, timestamps_affected, session, cache_functions)
-
-        model.last_update_time = pdl.now()
+        if len(valid_data) > 0:
+            # update label statistics
+            for row in results:
+                version_id = row[0]
+                sample_ids = [sample_id for sample_id in row[1] if sample_id in valid_data]
+                model_version: ModelVersion = \
+                    (await session.execute(select(ModelVersion).where(ModelVersion.id == version_id))).scalars().first()
+                updated_statistics = copy.deepcopy(model_version.statistics)
+                for sample_id in sample_ids:
+                    update_statistics_from_sample(updated_statistics, valid_data[sample_id])
+                if model_version.statistics != updated_statistics:
+                    await model_version.update_statistics(updated_statistics, session)
+
+            # Insert or update all labels
+            labels_table = model.get_sample_labels_table(session)
+            insert_statement = postgresql.insert(labels_table)
+            upsert_statement = insert_statement.on_conflict_do_update(
+                index_elements=[SAMPLE_ID_COL],
+                set_={SAMPLE_LABEL_COL: insert_statement.excluded[SAMPLE_LABEL_COL]}
+            )
+            await session.execute(upsert_statement, valid_data.tolist())
+
+            for row in results:
+                version_id = row[0]
+                sample_ids = [sample_id for sample_id in row[1] if sample_id in valid_data]
+                monitor_table_name = get_monitor_table_name(model.id, version_id)
+                ts_select = (select(Column(SAMPLE_TS_COL))
+                             .select_from(text(monitor_table_name))
+                             .where(Column(SAMPLE_ID_COL).in_(sample_ids)))
+                timestamps_affected = [pdl.instance(x) for x in (await session.execute(ts_select)).scalars()]
+                await add_cache_invalidation(org_id, version_id, timestamps_affected, session, cache_functions)
+
+            model.last_update_time = pdl.now()


 async def add_cache_invalidation(organization_id, model_version_id, timestamps_updated, session, cache_functions):
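
Two things change in log_labels: schema-validation failures and duplicate sample ids are now recorded via save_failures instead of being silently dropped, and the valid samples are upserted in chunks of QUERY_PARAM_LIMIT // 5 rather than in a single statement, presumably to stay under the database driver's bind-parameter limit when a large label file arrives. A minimal sketch of the chunking pattern (the constant's value here is only an illustrative assumption; the real one is defined elsewhere in the module):

    QUERY_PARAM_LIMIT = 32_000              # illustrative assumption only
    MAX_ROWS_PER_INSERT = QUERY_PARAM_LIMIT // 5

    def batched(rows, batch_size):
        """Yield successive slices of rows, each small enough for a single INSERT/UPSERT."""
        for start in range(0, len(rows), batch_size):
            yield rows[start:start + batch_size]

    # usage sketch:
    # for batch in batched(list(valid_rows), MAX_ROWS_PER_INSERT):
    #     await session.execute(upsert_statement, batch)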

backend/tests/api/test_data_input.py

Lines changed: 1 addition & 1 deletion

@@ -225,7 +225,7 @@ async def test_log_labels_non_existing_samples(
         model_id=classification_model["id"],
         data=[{
             "_dc_sample_id": "not exists",
-            "c": 0
+            "_dc_label": "0"
         }]
     )
     # Assert
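
The test payload is corrected to use the label column rather than what looks like a stray feature key. Based on this test, a labels payload is a list of objects keyed by the sample-id and label columns; for example (ids and values are illustrative):

    labels_payload = [
        {"_dc_sample_id": "sample-1", "_dc_label": "0"},
        {"_dc_sample_id": "sample-2", "_dc_label": "1"},
    ]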
