@@ -2,6 +2,8 @@ package redisai
22
33import (
44 "github.com/RedisAI/redisai-go/redisai/converters"
5+ "github.com/google/go-cmp/cmp"
6+ "reflect"
57 "testing"
68)
79
@@ -49,3 +51,79 @@ func Test_tensorSetFlatArgs(t *testing.T) {
4951 })
5052 }
5153}
54+
55+ func TestTensorGetTypeStrFromType (t * testing.T ) {
56+ type args struct {
57+ dtype reflect.Type
58+ }
59+ tests := []struct {
60+ name string
61+ args args
62+ wantTypestr string
63+ wantErr bool
64+ }{
65+ {"uint8" , args {reflect .TypeOf (([]uint8 )(nil ))}, TypeUint8 , false },
66+ {"uint8" , args {reflect .TypeOf (([]byte )(nil ))}, TypeUint8 , false },
67+ {"uint8" , args {reflect .TypeOf (([]uint16 )(nil ))}, TypeUint16 , false },
68+ {"uint8" , args {reflect .TypeOf (([]int )(nil ))}, TypeInt32 , false },
69+ {"uint8" , args {reflect .TypeOf (([]int8 )(nil ))}, TypeInt8 , false },
70+ {"uint8" , args {reflect .TypeOf (([]int16 )(nil ))}, TypeInt16 , false },
71+ {"uint8" , args {reflect .TypeOf (([]int32 )(nil ))}, TypeInt32 , false },
72+ {"uint8" , args {reflect .TypeOf (([]int64 )(nil ))}, TypeInt64 , false },
73+ {"uint8" , args {reflect .TypeOf (([]uint8 )(nil ))}, TypeUint8 , false },
74+ {"uint8" , args {reflect .TypeOf (([]uint16 )(nil ))}, TypeUint16 , false },
75+ {"uint8" , args {reflect .TypeOf (([]float32 )(nil ))}, TypeFloat32 , false },
76+ {"uint8" , args {reflect .TypeOf (([]float64 )(nil ))}, TypeFloat64 , false },
77+ {"uint8" , args {reflect .TypeOf (([]string )(nil ))}, "" , true },
78+ }
79+ for _ , tt := range tests {
80+ t .Run (tt .name , func (t * testing.T ) {
81+ gotTypestr , err := TensorGetTypeStrFromType (tt .args .dtype )
82+ if (err != nil ) != tt .wantErr {
83+ t .Errorf ("TensorGetTypeStrFromType() error = %v, wantErr %v" , err , tt .wantErr )
84+ return
85+ }
86+ if gotTypestr != tt .wantTypestr {
87+ t .Errorf ("TensorGetTypeStrFromType() gotTypestr = %v, want %v" , gotTypestr , tt .wantTypestr )
88+ }
89+ })
90+ }
91+ }
92+
93+ func TestProcessTensorGetReply (t * testing.T ) {
94+ type args struct {
95+ reply interface {}
96+ errIn error
97+ }
98+ tests := []struct {
99+ name string
100+ args args
101+ wantDtype string
102+ wantShape []int
103+ wantData interface {}
104+ wantErr bool
105+ }{
106+ {"empty" , args {}, "" , nil , nil , true },
107+ {"negative-wrong-reply" , args {[]interface {}{[]interface {}{[]byte ("serie 1" ), []interface {}{}, []interface {}{[]interface {}{[]byte ("AA" ), []byte ("1" )}}}}, nil }, "" , nil , nil , true },
108+ {"negative-wrong-reply" , args {[]interface {}{[]byte ("dtype" ), []interface {}{[]byte ("dtype" ), []byte ("1" )}}, nil }, "" , nil , nil , true },
109+ {"negative-wrong-shape" , args {[]interface {}{[]byte ("shape" ), []byte ("string" )}, nil }, "" , nil , nil , true },
110+ {"negative-wrong-blob" , args {[]interface {}{[]byte ("dtype" ), []interface {}{[]byte ("dtype" ), []byte ("1" )}}, nil }, "" , nil , nil , true },
111+ }
112+ for _ , tt := range tests {
113+ t .Run (tt .name , func (t * testing.T ) {
114+ gotErr , gotDtype , gotShape , gotData := ProcessTensorGetReply (tt .args .reply , tt .args .errIn )
115+ if gotErr != nil && ! tt .wantErr {
116+ t .Errorf ("ProcessTensorGetReply() gotErr = %v, want %v" , gotErr , tt .wantErr )
117+ }
118+ if diff := cmp .Diff (tt .wantDtype , gotDtype ); diff != "" {
119+ t .Errorf ("ProcessTensorGetReply() gotDtype mismatch (-want +got):\n %s" , diff )
120+ }
121+ if diff := cmp .Diff (tt .wantShape , gotShape ); diff != "" {
122+ t .Errorf ("ProcessTensorGetReply() gotShape mismatch (-want +got):\n %s" , diff )
123+ }
124+ if diff := cmp .Diff (tt .wantData , gotData ); diff != "" {
125+ t .Errorf ("ProcessTensorGetReply() gotData mismatch (-want +got):\n %s" , diff )
126+ }
127+ })
128+ }
129+ }
0 commit comments