@@ -208,22 +208,13 @@ enum RedisAI_DataFmt {
208208 REDISAI_DATA_NONE
209209};
210210
211- // ================================
212-
213- // key type dim1..dimN [BLOB data | VALUES val1..valN]
211+ /*
212+ * AI.TENSORSET key type dim1..dimN [BLOB data | VALUES val1..valN]
213+ */
214214int RedisAI_TensorSet_RedisCommand (RedisModuleCtx * ctx , RedisModuleString * * argv , int argc ) {
215- RedisModule_AutoMemory (ctx );
216-
217215 if (argc < 4 ) return RedisModule_WrongArity (ctx );
218216
219- ArgsCursor ac ;
220- ArgsCursor_InitRString (& ac , argv + 1 , argc - 1 );
221-
222- RedisModuleString * keystr ;
223- AC_GetRString (& ac , & keystr , 0 );
224-
225- RedisModuleKey * key = RedisModule_OpenKey (ctx , keystr ,
226- REDISMODULE_READ |REDISMODULE_WRITE );
217+ RedisModuleKey * key = RedisModule_OpenKey (ctx , argv [1 ], REDISMODULE_READ |REDISMODULE_WRITE );
227218 const int type = RedisModule_KeyType (key );
228219 if (type != REDISMODULE_KEYTYPE_EMPTY &&
229220 !(type == REDISMODULE_KEYTYPE_MODULE &&
@@ -232,122 +223,128 @@ int RedisAI_TensorSet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
232223 return RedisModule_ReplyWithError (ctx , REDISMODULE_ERRORMSG_WRONGTYPE );
233224 }
234225
235- // getting the datatype
236- const char * typestr ;
237- AC_GetString (& ac , & typestr , NULL , 0 );
238-
226+ // get the tensor datatype
227+ const char * typestr = RedisModule_StringPtrLen (argv [2 ], NULL );
239228 size_t datasize = RAI_TensorDataSizeFromString (typestr );
240229 if (!datasize ){
241230 return RedisModule_ReplyWithError (ctx , "ERR invalid data type" );
242231 }
243232
244- int dims_arg = 3 ;
245-
246- ArgsCursor dac ;
247- const char * matches [] = {"BLOB" , "VALUES" };
248- AC_GetSliceUntilMatches (& ac , & dac , 2 , matches );
249-
250- size_t ndims = dac .argc ;
251- size_t len = 1 ;
252- long long * dims = RedisModule_PoolAlloc (ctx , ndims * sizeof (long long ));
253- for (size_t i = 0 ; i < ndims ; i ++ ) {
254- int ret = AC_GetLongLong (& dac , dims + i , 0 );
255- if (ret != AC_OK ) {
256- return RedisModule_ReplyWithError (ctx , "ERR invalid argument found in tensor shape" );
257- }
258- if (dims [i ] < 0 ) {
259- return RedisModule_ReplyWithError (ctx , "ERR negative value found in tensor shape" );
260- }
261- len *= dims [i ];
262- }
263-
264- if (argc != dims_arg + ndims &&
265- argc != dims_arg + ndims + 1 + 1 &&
266- argc != dims_arg + ndims + 1 + len ) {
267- return RedisModule_WrongArity (ctx );
268- }
269-
270- const int hasdata = !AC_IsAtEnd (& ac );
271-
272233 const char * fmtstr ;
273234 int datafmt = REDISAI_DATA_NONE ;
274- if (hasdata ) {
275- AC_GetString (& ac , & fmtstr , NULL , 0 );
276- if (strcasecmp (fmtstr , "BLOB" ) == 0 ) {
235+ int tensorAllocMode = TENSORALLOC_CALLOC ;
236+ size_t ndims = 0 ;
237+ long long len = 1 ;
238+ long long * dims = NULL ;
239+ size_t argpos = 3 ;
240+ long long remaining_args = argc - 1 ;
241+
242+ for (; argpos <= argc - 1 ; argpos ++ ){
243+ const char * opt = RedisModule_StringPtrLen (argv [argpos ], NULL );
244+ remaining_args = argc - 1 - argpos ;
245+ if (!strcasecmp (opt , "BLOB" )){
277246 datafmt = REDISAI_DATA_BLOB ;
247+ tensorAllocMode = TENSORALLOC_NONE ;
248+ // if we've found the dataformat there are no more dimensions
249+ // check right away if the arity is correct
250+ if (remaining_args != 1 ){
251+ RedisModule_Free (dims );
252+ return RedisModule_WrongArity (ctx );
253+ }
254+ argpos ++ ;
255+ break ;
278256 }
279- else if (strcasecmp (fmtstr , "VALUES" ) == 0 ) {
257+ else if (! strcasecmp (opt , "VALUES" )) {
280258 datafmt = REDISAI_DATA_VALUES ;
259+ tensorAllocMode = TENSORALLOC_ALLOC ;
260+ //if we've found the dataformat there are no more dimensions
261+ // check right away if the arity is correct
262+ if (remaining_args != len ){
263+ RedisModule_Free (dims );
264+ return RedisModule_WrongArity (ctx );
265+ }
266+ argpos ++ ;
267+ break ;
268+ } else {
269+ long long dimension = 1 ;
270+ const int retval = RedisModule_StringToLongLong (argv [argpos ],& dimension );
271+ if (retval != REDISMODULE_OK || dimension <= 0 ) {
272+ RedisModule_Free (dims );
273+ return RedisModule_ReplyWithError (ctx ,
274+ "ERR invalid or negative value found in tensor shape" );
275+ }
276+
277+ ndims ++ ;
278+ dims = RedisModule_Realloc (dims ,ndims * sizeof (long long ));
279+ dims [ndims - 1 ]= dimension ;
280+ len *= dimension ;
281281 }
282- else {
283- return RedisModule_ReplyWithError (ctx , "ERR unsupported data format" );
284- }
285282 }
286- const size_t nbytes = len * datasize ;
283+
284+ if (datafmt == REDISAI_DATA_NONE && remaining_args != 0 ){
285+ return RedisModule_ReplyWithError (ctx , "ERR unsupported data format" );
286+ }
287+
288+ const long long nbytes = len * datasize ;
287289 size_t datalen ;
288290 const char * data ;
289- RAI_Tensor * t = RAI_TensorCreate (typestr , dims , ndims , hasdata );
291+ DLDataType datatype = RAI_TensorDataTypeFromString (typestr );
292+ RAI_Tensor * t = RAI_TensorCreateWithDLDataType (datatype , dims , ndims , tensorAllocMode );
290293 if (!t ){
291294 return RedisModule_ReplyWithError (ctx , "ERR could not create tensor" );
292295 }
296+ size_t i = 0 ;
293297 switch (datafmt ){
294- case REDISAI_DATA_BLOB :
295- AC_GetString (& ac , & data , & datalen , 0 );
296- if (datalen != nbytes ){
297- RAI_TensorFree (t );
298- return RedisModule_ReplyWithError (ctx , "ERR data length does not match tensor shape and type" );
299- }
300- RAI_TensorSetData (t , data , datalen );
301- break ;
302- case REDISAI_DATA_VALUES :
303- if (argc != len + 4 + ndims ){
304- RAI_TensorFree (t );
305- return RedisModule_WrongArity (ctx );
306- }
307- DLDataType datatype = RAI_TensorDataType (t );
308-
309- long i ;
310- if (datatype .code == kDLFloat ){
311- double val ;
312- for (i = 0 ; i < len ; i ++ ){
313- int ac_ret = AC_GetDouble (& ac , & val , 0 );
314- if (ac_ret != AC_OK ){
315- RAI_TensorFree (t );
316- return RedisModule_ReplyWithError (ctx , "ERR invalid value" );
317- }
318- int ret = RAI_TensorSetValueFromDouble (t , i , val );
319- if (ret == -1 ){
320- RAI_TensorFree (t );
321- return RedisModule_ReplyWithError (ctx , "ERR cannot specify values for this datatype" );
322- }
298+ case REDISAI_DATA_BLOB :
299+ RedisModule_StringPtrLen (argv [argpos ],& datalen );
300+ if (datalen != nbytes ){
301+ RAI_TensorFree (t );
302+ return RedisModule_ReplyWithError (ctx , "ERR data length does not match tensor shape and type" );
323303 }
324- }
325- else {
326- long long val ;
327- for (i = 0 ; i < len ; i ++ ){
328- int ac_ret = AC_GetLongLong (& ac , & val , 0 );
329- if (ac_ret != AC_OK ){
330- RAI_TensorFree (t );
331- return RedisModule_ReplyWithError (ctx , "ERR invalid value" );
304+ RedisModule_RetainString (NULL ,argv [argpos ]);
305+ RAI_TensorSetDataFromRS (t ,argv [argpos ]);
306+ break ;
307+ case REDISAI_DATA_VALUES :
308+ for (; argpos <= argc - 1 ; argpos ++ ){
309+ if (datatype .code == kDLFloat ){
310+ double val ;
311+ const int retval = RedisModule_StringToDouble (argv [argpos ],& val );
312+ if (retval != REDISMODULE_OK ) {
313+ RAI_TensorFree (t );
314+ return RedisModule_ReplyWithError (ctx , "ERR invalid value" );
315+ }
316+ const int retset = RAI_TensorSetValueFromDouble (t , i , val );
317+ if (retset == -1 ){
318+ RAI_TensorFree (t );
319+ return RedisModule_ReplyWithError (ctx , "ERR cannot specify values for this datatype" );
320+ }
332321 }
333- int ret = RAI_TensorSetValueFromLongLong (t , i , val );
334- if (ret == -1 ){
335- RAI_TensorFree (t );
336- return RedisModule_ReplyWithError (ctx , "ERR cannot specify values for this datatype" );
322+ else {
323+ long long val ;
324+ const int retval = RedisModule_StringToLongLong (argv [argpos ],& val );
325+ if (retval != REDISMODULE_OK ) {
326+ RAI_TensorFree (t );
327+ return RedisModule_ReplyWithError (ctx , "ERR invalid value" );
328+ }
329+ const int retset = RAI_TensorSetValueFromLongLong (t , i , val );
330+ if (retset == -1 ){
331+ RAI_TensorFree (t );
332+ return RedisModule_ReplyWithError (ctx , "ERR cannot specify values for this datatype" );
333+ }
337334 }
335+ i ++ ;
338336 }
339- }
340- break ;
341- default :
342- // default does not require tensor data setting since calloc setted it to 0
343- break ;
337+ break ;
338+ default :
339+ // default does not require tensor data setting since calloc setted it to 0
340+ break ;
341+ }
342+ if ( RedisModule_ModuleTypeSetValue (key , RedisAI_TensorType , t ) != REDISMODULE_OK ){
343+ return RedisModule_ReplyWithError (ctx , "ERR could not save tensor" );
344344 }
345-
346- RedisModule_ModuleTypeSetValue (key , RedisAI_TensorType , t );
347345 RedisModule_CloseKey (key );
348346 RedisModule_ReplyWithSimpleString (ctx , "OK" );
349347 RedisModule_ReplicateVerbatim (ctx );
350-
351348 return REDISMODULE_OK ;
352349}
353350
@@ -1891,6 +1888,7 @@ int RedisAI_Config_LoadBackend(RedisModuleCtx *ctx, RedisModuleString **argv, in
18911888 return RedisModule_ReplyWithError (ctx , "ERR error loading backend" );
18921889}
18931890
1891+
18941892int RedisAI_Config_BackendsPath (RedisModuleCtx * ctx , const char * path ) {
18951893 RedisModule_AutoMemory (ctx );
18961894
0 commit comments