Skip to content

Commit bebd0a2

Browse files
[add] extended negative testing on tensor, model, and client (#10)
1 parent 1092637 commit bebd0a2

File tree

3 files changed

+187
-0
lines changed

3 files changed

+187
-0
lines changed

redisai/client_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,3 +202,48 @@ func TestClient_Receive(t *testing.T) {
202202
})
203203
}
204204
}
205+
206+
func TestClient_DisablePipeline(t *testing.T) {
207+
// Create a client.
208+
client := Connect("redis://localhost:6379", nil)
209+
210+
// Enable pipeline of commands on the client, autoFlushing at 3 commands
211+
client.Pipeline(3)
212+
213+
// Set a tensor
214+
// AI.TENSORSET foo FLOAT 2 2 VALUES 1.1 2.2 3.3 4.4
215+
err := client.TensorSet("foo1", TypeFloat, []int{2, 2}, []float32{1.1, 2.2, 3.3, 4.4})
216+
if err != nil {
217+
t.Errorf("TensorSet() error = %v", err)
218+
}
219+
// AI.TENSORSET foo2 FLOAT 1" 1 VALUES 1.1
220+
err = client.TensorSet("foo2", TypeFloat, []int{1, 1}, []float32{1.1})
221+
if err != nil {
222+
t.Errorf("TensorSet() error = %v", err)
223+
}
224+
// AI.TENSORGET foo2 META
225+
_, err = client.TensorGet("foo2", TensorContentTypeMeta)
226+
if err != nil {
227+
t.Errorf("TensorGet() error = %v", err)
228+
}
229+
// Ignore the AI.TENSORSET Reply
230+
_, err = client.Receive()
231+
if err != nil {
232+
t.Errorf("Receive() error = %v", err)
233+
}
234+
// Ignore the AI.TENSORSET Reply
235+
_, err = client.Receive()
236+
if err != nil {
237+
t.Errorf("Receive() error = %v", err)
238+
}
239+
err, _, _, _ = ProcessTensorGetReply(client.Receive())
240+
if err != nil {
241+
t.Errorf("ProcessTensorGetReply() error = %v", err)
242+
}
243+
244+
err = client.DisablePipeline()
245+
if err != nil {
246+
t.Errorf("DisablePipeline() error = %v", err)
247+
}
248+
249+
}

redisai/model_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package redisai
2+
3+
import (
4+
"reflect"
5+
"testing"
6+
)
7+
8+
func Test_modelGetParseReply(t *testing.T) {
9+
type args struct {
10+
reply interface{}
11+
}
12+
tests := []struct {
13+
name string
14+
args args
15+
wantBackend string
16+
wantDevice string
17+
wantBlob []byte
18+
wantErr bool
19+
}{
20+
{"empty", args{}, "", "", nil, true},
21+
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", nil, true},
22+
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
23+
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
24+
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
25+
}
26+
for _, tt := range tests {
27+
t.Run(tt.name, func(t *testing.T) {
28+
gotErr, gotBackend, gotDevice, gotBlob := modelGetParseReply(tt.args.reply)
29+
if gotErr != nil && !tt.wantErr {
30+
t.Errorf("modelGetParseReply() gotErr = %v, want %v", gotErr, tt.wantErr)
31+
}
32+
if gotBackend != tt.wantBackend {
33+
t.Errorf("modelGetParseReply() gotBackend = %v, want %v", gotBackend, tt.wantBackend)
34+
}
35+
if gotDevice != tt.wantDevice {
36+
t.Errorf("modelGetParseReply() gotDevice = %v, want %v", gotDevice, tt.wantDevice)
37+
}
38+
if !reflect.DeepEqual(gotBlob, tt.wantBlob) {
39+
t.Errorf("modelGetParseReply() gotBlob = %v, want %v", gotBlob, tt.wantBlob)
40+
}
41+
})
42+
}
43+
}
44+
45+
func Test_modelGetParseToInterface(t *testing.T) {
46+
type args struct {
47+
reply interface{}
48+
model ModelInterface
49+
}
50+
tests := []struct {
51+
name string
52+
args args
53+
wantErr bool
54+
}{
55+
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}, nil}, true},
56+
}
57+
for _, tt := range tests {
58+
t.Run(tt.name, func(t *testing.T) {
59+
if err := modelGetParseToInterface(tt.args.reply, tt.args.model); (err != nil) != tt.wantErr {
60+
t.Errorf("modelGetParseToInterface() error = %v, wantErr %v", err, tt.wantErr)
61+
}
62+
})
63+
}
64+
}

redisai/tensor_test.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package redisai
22

33
import (
44
"github.com/RedisAI/redisai-go/redisai/converters"
5+
"github.com/google/go-cmp/cmp"
6+
"reflect"
57
"testing"
68
)
79

@@ -49,3 +51,79 @@ func Test_tensorSetFlatArgs(t *testing.T) {
4951
})
5052
}
5153
}
54+
55+
func TestTensorGetTypeStrFromType(t *testing.T) {
56+
type args struct {
57+
dtype reflect.Type
58+
}
59+
tests := []struct {
60+
name string
61+
args args
62+
wantTypestr string
63+
wantErr bool
64+
}{
65+
{"uint8", args{reflect.TypeOf(([]uint8)(nil))}, TypeUint8, false},
66+
{"uint8", args{reflect.TypeOf(([]byte)(nil))}, TypeUint8, false},
67+
{"uint8", args{reflect.TypeOf(([]uint16)(nil))}, TypeUint16, false},
68+
{"uint8", args{reflect.TypeOf(([]int)(nil))}, TypeInt32, false},
69+
{"uint8", args{reflect.TypeOf(([]int8)(nil))}, TypeInt8, false},
70+
{"uint8", args{reflect.TypeOf(([]int16)(nil))}, TypeInt16, false},
71+
{"uint8", args{reflect.TypeOf(([]int32)(nil))}, TypeInt32, false},
72+
{"uint8", args{reflect.TypeOf(([]int64)(nil))}, TypeInt64, false},
73+
{"uint8", args{reflect.TypeOf(([]uint8)(nil))}, TypeUint8, false},
74+
{"uint8", args{reflect.TypeOf(([]uint16)(nil))}, TypeUint16, false},
75+
{"uint8", args{reflect.TypeOf(([]float32)(nil))}, TypeFloat32, false},
76+
{"uint8", args{reflect.TypeOf(([]float64)(nil))}, TypeFloat64, false},
77+
{"uint8", args{reflect.TypeOf(([]string)(nil))}, "", true},
78+
}
79+
for _, tt := range tests {
80+
t.Run(tt.name, func(t *testing.T) {
81+
gotTypestr, err := TensorGetTypeStrFromType(tt.args.dtype)
82+
if (err != nil) != tt.wantErr {
83+
t.Errorf("TensorGetTypeStrFromType() error = %v, wantErr %v", err, tt.wantErr)
84+
return
85+
}
86+
if gotTypestr != tt.wantTypestr {
87+
t.Errorf("TensorGetTypeStrFromType() gotTypestr = %v, want %v", gotTypestr, tt.wantTypestr)
88+
}
89+
})
90+
}
91+
}
92+
93+
func TestProcessTensorGetReply(t *testing.T) {
94+
type args struct {
95+
reply interface{}
96+
errIn error
97+
}
98+
tests := []struct {
99+
name string
100+
args args
101+
wantDtype string
102+
wantShape []int
103+
wantData interface{}
104+
wantErr bool
105+
}{
106+
{"empty", args{}, "", nil, nil, true},
107+
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}, nil}, "", nil, nil, true},
108+
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}, nil}, "", nil, nil, true},
109+
{"negative-wrong-shape", args{[]interface{}{[]byte("shape"), []byte("string")}, nil}, "", nil, nil, true},
110+
{"negative-wrong-blob", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}, nil}, "", nil, nil, true},
111+
}
112+
for _, tt := range tests {
113+
t.Run(tt.name, func(t *testing.T) {
114+
gotErr, gotDtype, gotShape, gotData := ProcessTensorGetReply(tt.args.reply, tt.args.errIn)
115+
if gotErr != nil && !tt.wantErr {
116+
t.Errorf("ProcessTensorGetReply() gotErr = %v, want %v", gotErr, tt.wantErr)
117+
}
118+
if diff := cmp.Diff(tt.wantDtype, gotDtype); diff != "" {
119+
t.Errorf("ProcessTensorGetReply() gotDtype mismatch (-want +got):\n%s", diff)
120+
}
121+
if diff := cmp.Diff(tt.wantShape, gotShape); diff != "" {
122+
t.Errorf("ProcessTensorGetReply() gotShape mismatch (-want +got):\n%s", diff)
123+
}
124+
if diff := cmp.Diff(tt.wantData, gotData); diff != "" {
125+
t.Errorf("ProcessTensorGetReply() gotData mismatch (-want +got):\n%s", diff)
126+
}
127+
})
128+
}
129+
}

0 commit comments

Comments
 (0)