Skip to content

Commit 4f2e4c0

Browse files
exteded testsuite with test that deterministically crashes redisai (issue #323) (#324)
* [add] added test that deterministically crashes redisai (issue #323) * Handle the case of batch size 1 minimally Co-authored-by: Luca Antiga <luca.antiga@orobix.com>
1 parent 2cbfca8 commit 4f2e4c0

File tree

2 files changed

+75
-15
lines changed

2 files changed

+75
-15
lines changed

src/redisai.c

Lines changed: 33 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,10 @@ void RedisAI_FreeRunStats(RedisModuleCtx *ctx, struct RedisAI_RunStats *rstats)
205205
}
206206

207207
void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) {
208-
if (array_len(batch_rinfo) == 0) {
208+
209+
const long long batch_size = array_len(batch_rinfo);
210+
211+
if (batch_size == 0) {
209212
return NULL;
210213
}
211214

@@ -215,10 +218,15 @@ void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) {
215218
RAI_ModelRunCtx* mctx = NULL;
216219
RAI_ScriptRunCtx* sctx = NULL;
217220
if (batch_rinfo[0]->mctx) {
218-
mctx = RAI_ModelRunCtxCreate(batch_rinfo[0]->mctx->model);
219-
for (long long i=0; i<array_len(batch_rinfo); i++) {
220-
int id = RAI_ModelRunCtxAddBatch(mctx);
221-
RAI_ModelRunCtxCopyBatch(mctx, id, batch_rinfo[i]->mctx, 0);
221+
if (batch_size > 1) {
222+
mctx = RAI_ModelRunCtxCreate(batch_rinfo[0]->mctx->model);
223+
for (long long i=0; i<batch_size; i++) {
224+
int id = RAI_ModelRunCtxAddBatch(mctx);
225+
RAI_ModelRunCtxCopyBatch(mctx, id, batch_rinfo[i]->mctx, 0);
226+
}
227+
}
228+
else {
229+
mctx = batch_rinfo[0]->mctx;
222230
}
223231
}
224232
else if (batch_rinfo[0]->sctx) {
@@ -235,19 +243,24 @@ void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) {
235243
}
236244
rtime = ustime() - start;
237245

238-
for (long long i=0; i<array_len(batch_rinfo); i++) {
246+
for (long long i=0; i<batch_size; i++) {
239247
struct RedisAI_RunInfo *rinfo = batch_rinfo[i];
240248
if (mctx) {
241-
size_t noutputs = RAI_ModelRunCtxNumOutputs(mctx);
242-
for (long long o=0; o<noutputs; o++) {
243-
RAI_Tensor* tensor = mctx->batches[i].outputs[o].tensor;
244-
if (tensor) {
245-
rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(tensor);
246-
}
247-
else {
248-
rinfo->mctx->batches[0].outputs[o].tensor = NULL;
249+
if (batch_size > 1) {
250+
size_t noutputs = RAI_ModelRunCtxNumOutputs(mctx);
251+
for (long long o=0; o<noutputs; o++) {
252+
RAI_Tensor* tensor = mctx->batches[i].outputs[o].tensor;
253+
if (tensor) {
254+
rinfo->mctx->batches[0].outputs[o].tensor = RAI_TensorGetShallowCopy(tensor);
255+
}
256+
else {
257+
rinfo->mctx->batches[0].outputs[o].tensor = NULL;
258+
}
249259
}
250260
}
261+
else {
262+
// Do nothing if no batching
263+
}
251264
}
252265
else if (sctx) {
253266
// No batching for scripts for now
@@ -270,7 +283,12 @@ void *RedisAI_RunSession(struct RedisAI_RunInfo **batch_rinfo) {
270283
}
271284

272285
if (mctx) {
273-
RAI_ModelRunCtxFree(mctx);
286+
if (batch_size > 1) {
287+
RAI_ModelRunCtxFree(mctx);
288+
}
289+
else {
290+
// Do nothing if no batching
291+
}
274292
}
275293
else if (sctx) {
276294
// No batching for scripts for now

test/tests_tensorflow.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,3 +708,45 @@ def test_tensorflow_modelrun_financialNet(env):
708708
'referenceTensor:{}'.format(tensor_number), 'OUTPUTS',
709709
'classificationTensor:{}_{}'.format(tensor_number, repetition))
710710
env.assertEqual(ret, b'OK')
711+
712+
713+
def test_tensorflow_modelrun_financialNet_multiproc(env):
714+
con = env.getConnection()
715+
716+
model_pb, creditcard_transactions, creditcard_referencedata = load_creditcardfraud_data(env)
717+
718+
tensor_number = 1
719+
for transaction_tensor in creditcard_transactions:
720+
ret = con.execute_command('AI.TENSORSET', 'transactionTensor:{0}'.format(tensor_number),
721+
'FLOAT', 1, 30,
722+
'BLOB', transaction_tensor.tobytes())
723+
env.assertEqual(ret, b'OK')
724+
tensor_number = tensor_number + 1
725+
726+
tensor_number = 1
727+
for reference_tensor in creditcard_referencedata:
728+
ret = con.execute_command('AI.TENSORSET', 'referenceTensor:{0}'.format(tensor_number),
729+
'FLOAT', 1, 256,
730+
'BLOB', reference_tensor.tobytes())
731+
env.assertEqual(ret, b'OK')
732+
tensor_number = tensor_number + 1
733+
734+
ret = con.execute_command('AI.MODELSET', 'financialNet', 'TF', "CPU",
735+
'INPUTS', 'transaction', 'reference', 'OUTPUTS', 'output', model_pb)
736+
env.assertEqual(ret, b'OK')
737+
738+
def functor_financialNet(env, key_max, repetitions):
739+
for tensor_number in range(1, key_max):
740+
for repetition in range(1, repetitions):
741+
ret = env.execute_command('AI.MODELRUN', 'financialNet', 'INPUTS',
742+
'transactionTensor:{}'.format(tensor_number),
743+
'referenceTensor:{}'.format(tensor_number), 'OUTPUTS',
744+
'classificationTensor:{}_{}'.format(tensor_number, repetition))
745+
746+
t = time.time()
747+
run_test_multiproc(env, 10,
748+
lambda env: functor_financialNet(env,len(transaction_tensor),100) )
749+
elapsed_time = time.time() - t
750+
total_ops = len(transaction_tensor)*100
751+
avg_ops_sec = total_ops/elapsed_time
752+
env.debugPrint("AI.MODELRUN elapsed time(sec) {:6.2f}\tTotal ops {:10.2f}\tAvg. ops/sec {:10.2f}".format(elapsed_time, total_ops, avg_ops_sec), True)

0 commit comments

Comments
 (0)