@@ -22,13 +22,50 @@ type TensorInterface interface {
2222 // Len returns the number of elements in the tensor.
2323 Len () int
2424
25- TypeString () string
26- SetTypeString (typestr string )
25+ Dtype () reflect.Type
2726
27+ // Data returns the underlying tensor data
2828 Data () interface {}
2929 SetData (interface {})
3030}
3131
32+ func TensorGetTypeStrFromType (dtype reflect.Type ) (typestr string , err error ) {
33+ switch dtype {
34+ case reflect .TypeOf (([]uint8 )(nil )):
35+ typestr = TypeUint8
36+ case reflect .TypeOf (([]byte )(nil )):
37+ typestr = TypeUint8
38+ case reflect .TypeOf (([]int )(nil )):
39+ typestr = TypeInt32
40+ case reflect .TypeOf (([]int8 )(nil )):
41+ typestr = TypeInt8
42+ case reflect .TypeOf (([]int16 )(nil )):
43+ typestr = TypeInt16
44+ case reflect .TypeOf (([]int32 )(nil )):
45+ typestr = TypeInt32
46+ case reflect .TypeOf (([]int64 )(nil )):
47+ typestr = TypeInt64
48+ case reflect .TypeOf (([]uint )(nil )):
49+ typestr = TypeUint8
50+ case reflect .TypeOf (([]uint16 )(nil )):
51+ typestr = TypeUint16
52+ case reflect .TypeOf (([]float32 )(nil )):
53+ typestr = TypeFloat32
54+ case reflect .TypeOf (([]float64 )(nil )):
55+ typestr = TypeFloat64
56+ case reflect .TypeOf (([]uint32 )(nil )):
57+ fallthrough
58+ // unsupported data type
59+ case reflect .TypeOf (([]uint64 )(nil )):
60+ fallthrough
61+ // unsupported data type
62+
63+ default :
64+ err = fmt .Errorf ("redisai Tensor does not support the following type %v" , dtype )
65+ }
66+ return
67+ }
68+
3269func tensorSetFlatArgs (name string , dt string , dims []int , data interface {}) (redis.Args , error ) {
3370 args := redis.Args {}
3471 var err error = nil
@@ -74,13 +111,16 @@ func tensorSetFlatArgs(name string, dt string, dims []int, data interface{}) (re
74111 return args , err
75112}
76113
77- func tensorSetInterfaceArgs (keyName string , tensorInterface TensorInterface ) (redis.Args , error ) {
78- return tensorSetFlatArgs (keyName , tensorInterface .TypeString (), tensorInterface .Shape (), tensorInterface .Data ())
114+ func tensorSetInterfaceArgs (keyName string , tensorInterface TensorInterface ) (args redis.Args , err error ) {
115+ typestr , err := TensorGetTypeStrFromType (tensorInterface .Dtype ())
116+ if err != nil {
117+ return
118+ }
119+ return tensorSetFlatArgs (keyName , typestr , tensorInterface .Shape (), tensorInterface .Data ())
79120}
80121
81122func tensorGetParseToInterface (reply interface {}, tensor TensorInterface ) (err error ) {
82- err , dtype , shape , data := ProcessTensorGetReply (reply , err )
83- tensor .SetTypeString (dtype )
123+ err , _ , shape , data := ProcessTensorGetReply (reply , err )
84124 tensor .SetShape (shape )
85125 tensor .SetData (data )
86126 return
0 commit comments