@@ -238,13 +238,13 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
238238
239239 TF_ImportGraphDefOptions * options = TF_NewImportGraphDefOptions ();
240240
241- TF_Buffer * buffer = TF_NewBuffer ();
242- buffer -> length = modellen ;
243- buffer -> data = modeldef ;
241+ TF_Buffer * tfbuffer = TF_NewBuffer ();
242+ tfbuffer -> length = modellen ;
243+ tfbuffer -> data = modeldef ;
244244
245245 TF_Status * status = TF_NewStatus ();
246246
247- TF_GraphImportGraphDef (model , buffer , options , status );
247+ TF_GraphImportGraphDef (model , tfbuffer , options , status );
248248
249249 if (TF_GetCode (status ) != TF_OK ) {
250250 char * errorMessage = RedisModule_Strdup (TF_Message (status ));
@@ -276,7 +276,7 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
276276 }
277277
278278 TF_DeleteImportGraphDefOptions (options );
279- TF_DeleteBuffer (buffer );
279+ TF_DeleteBuffer (tfbuffer );
280280 TF_DeleteStatus (status );
281281
282282 TF_Status * optionsStatus = TF_NewStatus ();
@@ -394,6 +394,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
394394 array_append (outputs_ , RedisModule_Strdup (outputs [i ]));
395395 }
396396
397+ char * buffer = RedisModule_Calloc (modellen , sizeof (* buffer ));
398+ memcpy (buffer , modeldef , modellen );
399+
397400 RAI_Model * ret = RedisModule_Calloc (1 , sizeof (* ret ));
398401 ret -> model = model ;
399402 ret -> session = session ;
@@ -403,7 +406,9 @@ RAI_Model *RAI_ModelCreateTF(RAI_Backend backend, const char* devicestr, RAI_Mod
403406 ret -> outputs = outputs_ ;
404407 ret -> opts = opts ;
405408 ret -> refCount = 1 ;
406-
409+ ret -> data = buffer ;
410+ ret -> datalen = modellen ;
411+
407412 return ret ;
408413}
409414
@@ -445,6 +450,10 @@ void RAI_ModelFreeTF(RAI_Model* model, RAI_Error* error) {
445450 array_free (model -> outputs );
446451 }
447452
453+ if (model -> data ) {
454+ RedisModule_Free (model -> data );
455+ }
456+
448457 TF_DeleteStatus (status );
449458}
450459
@@ -534,24 +543,32 @@ int RAI_ModelRunTF(RAI_ModelRunCtx** mctxs, RAI_Error *error) {
534543}
535544
536545int RAI_ModelSerializeTF (RAI_Model * model , char * * buffer , size_t * len , RAI_Error * error ) {
537- TF_Buffer * tf_buffer = TF_NewBuffer ();
538- TF_Status * status = TF_NewStatus ();
539546
540- TF_GraphToGraphDef (model -> model , tf_buffer , status );
547+ if (model -> data ) {
548+ * buffer = RedisModule_Calloc (model -> datalen , sizeof (char ));
549+ memcpy (* buffer , model -> data , model -> datalen );
550+ * len = model -> datalen ;
551+ }
552+ else {
553+ TF_Buffer * tf_buffer = TF_NewBuffer ();
554+ TF_Status * status = TF_NewStatus ();
555+
556+ TF_GraphToGraphDef (model -> model , tf_buffer , status );
557+
558+ if (TF_GetCode (status ) != TF_OK ) {
559+ RAI_SetError (error , RAI_EMODELSERIALIZE , "ERR Error serializing TF model" );
560+ TF_DeleteBuffer (tf_buffer );
561+ TF_DeleteStatus (status );
562+ return 1 ;
563+ }
564+
565+ * buffer = RedisModule_Alloc (tf_buffer -> length );
566+ memcpy (* buffer , tf_buffer -> data , tf_buffer -> length );
567+ * len = tf_buffer -> length ;
541568
542- if (TF_GetCode (status ) != TF_OK ) {
543- RAI_SetError (error , RAI_EMODELSERIALIZE , "ERR Error serializing TF model" );
544569 TF_DeleteBuffer (tf_buffer );
545570 TF_DeleteStatus (status );
546- return 1 ;
547571 }
548572
549- * buffer = RedisModule_Alloc (tf_buffer -> length );
550- memcpy (* buffer , tf_buffer -> data , tf_buffer -> length );
551- * len = tf_buffer -> length ;
552-
553- TF_DeleteBuffer (tf_buffer );
554- TF_DeleteStatus (status );
555-
556573 return 0 ;
557574}
0 commit comments