Skip to content

Commit dd710d8

Browse files
authored
Support for [TAG tag] on AI.SCRIPTSET、AI.MODELSET (#20)
* Support for [TAG tag] on AI.SCRIPTSET、AI.MODELSET
1 parent 81ac224 commit dd710d8

File tree

8 files changed

+107
-28
lines changed

8 files changed

+107
-28
lines changed

redisai/client_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func getConnectionDetails() (host string, password string) {
2323
}
2424

2525
func createPool() *redis.Pool {
26-
host,_ := getConnectionDetails()
26+
host, _ := getConnectionDetails()
2727
cpool := &redis.Pool{
2828
MaxIdle: 3,
2929
IdleTimeout: 240 * time.Second,
@@ -57,12 +57,12 @@ func getTLSdetails() (tlsready bool, tls_cert string, tls_key string, tls_cacert
5757
}
5858

5959
func createTestClient() *Client {
60-
host,_ := getConnectionDetails()
60+
host, _ := getConnectionDetails()
6161
return Connect(host, nil)
6262
}
6363

6464
func TestConnect(t *testing.T) {
65-
host,_ := getConnectionDetails()
65+
host, _ := getConnectionDetails()
6666
cpool1 := &redis.Pool{
6767
MaxIdle: 3,
6868
IdleTimeout: 240 * time.Second,

redisai/commands.go

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ func (c *Client) TensorGetBlob(name string) (dt string, shape []int64, data []by
8383

8484
// ModelSet sets a RedisAI model from a blob
8585
func (c *Client) ModelSet(keyName, backend, device string, data []byte, inputs, outputs []string) (err error) {
86-
args := modelSetFlatArgs(keyName, backend, device, inputs, outputs, data)
86+
args := modelSetFlatArgs(keyName, backend, device, "", inputs, outputs, data)
8787
_, err = c.DoOrSend("AI.MODELSET", args, nil)
8888
return
8989
}
@@ -95,15 +95,21 @@ func (c *Client) ModelSetFromModel(keyName string, model ModelInterface) (err er
9595
return
9696
}
9797

98+
// ModelGet gets a RedisAI model from the RedisAI server
99+
// The reply will an array, containing at
100+
// - position 0 the backend used by the model as a String
101+
// - position 1 the device used to execute the model as a String
102+
// - position 2 the model's tag as a String
103+
// - position 3 a blob containing the serialized model (when called with the BLOB argument) as a String
98104
func (c *Client) ModelGet(keyName string) (data []interface{}, err error) {
99105
var reply interface{}
100-
data = make([]interface{}, 3)
106+
data = make([]interface{}, 4)
101107
args := modelGetFlatArgs(keyName)
102108
reply, err = c.DoOrSend("AI.MODELGET", args, nil)
103109
if err != nil || reply == nil {
104110
return
105111
}
106-
err, data[0], data[1], data[2] = modelGetParseReply(reply)
112+
err, data[0], data[1], data[2], data[3] = modelGetParseReply(reply)
107113
return
108114
}
109115

@@ -138,6 +144,13 @@ func (c *Client) ScriptSet(name string, device string, script_source string) (er
138144
return
139145
}
140146

147+
// ScriptSetWithTag sets a RedisAI script from a blob with tag
148+
func (c *Client) ScriptSetWithTag(name string, device string, script_source string, tag string) (err error) {
149+
args := redis.Args{}.Add(name, device, "TAG", tag, "SOURCE", script_source)
150+
_, err = c.DoOrSend("AI.SCRIPTSET", args, nil)
151+
return
152+
}
153+
141154
func (c *Client) ScriptGet(name string) (data map[string]string, err error) {
142155
args := redis.Args{}.Add(name, "META", "SOURCE")
143156
respInitial, err := c.DoOrSend("AI.SCRIPTGET", args, nil)

redisai/commands_test.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,12 @@ func TestCommand_ModelGet(t *testing.T) {
476476
args args
477477
wantBackend string
478478
wantDevice string
479+
wantTag string
479480
wantData []byte
480481
wantErr bool
481482
}{
482-
{keyModelUnexistent1, args{keyModelUnexistent1}, BackendTF, DeviceCPU, data, true},
483-
{keyModel1, args{keyModel1}, BackendTF, DeviceCPU, data, false},
483+
{keyModelUnexistent1, args{keyModelUnexistent1}, BackendTF, DeviceCPU, "", data, true},
484+
{keyModel1, args{keyModel1}, BackendTF, DeviceCPU, "", data, false},
484485
}
485486
for _, tt := range tests {
486487
t.Run(tt.name, func(t *testing.T) {
@@ -501,8 +502,13 @@ func TestCommand_ModelGet(t *testing.T) {
501502
}
502503
}
503504
if !tt.wantErr {
504-
if !reflect.DeepEqual(gotData[2], tt.wantData) {
505-
t.Errorf("ModelGetToModel() gotData = %v, want %v. gotData Type %v, want Type %v.", gotData[2], tt.wantData, reflect.TypeOf(gotData[2]), reflect.TypeOf(tt.wantData))
505+
if !reflect.DeepEqual(gotData[2], tt.wantTag) {
506+
t.Errorf("ModelGetToModel() gotTag = %v, want %v. gotTag Type %v, want Type %v.", gotData[2], tt.wantTag, reflect.TypeOf(gotData[2]), reflect.TypeOf(tt.wantTag))
507+
}
508+
}
509+
if !tt.wantErr {
510+
if !reflect.DeepEqual(gotData[3], tt.wantData) {
511+
t.Errorf("ModelGetToModel() gotData = %v, want %v. gotData Type %v, want Type %v.", gotData[3], tt.wantData, reflect.TypeOf(gotData[3]), reflect.TypeOf(tt.wantData))
506512
}
507513
}
508514

@@ -621,8 +627,12 @@ func TestCommand_FullFromModelFlow(t *testing.T) {
621627
assert.Nil(t, err)
622628
model1.SetInputs([]string{"transaction", "reference"})
623629
model1.SetOutputs([]string{"output"})
630+
model1.SetTag("financialTag")
624631
err = client.ModelSetFromModel("financialNet1", model1)
625632
assert.Nil(t, err)
633+
model2 := implementations.NewEmptyModel()
634+
err = client.ModelGetToModel("financialNet1", model2)
635+
assert.Equal(t, model1.Tag(), model2.Tag())
626636
}
627637

628638
func TestCommand_ScriptDel(t *testing.T) {
@@ -684,6 +694,14 @@ func TestCommand_ScriptGet(t *testing.T) {
684694
return
685695
}
686696

697+
keyScript2 := "test:ScriptGet:2"
698+
keyScriptTag := "keyScriptTag"
699+
err = simpleClient.ScriptSetWithTag(keyScript2, DeviceCPU, scriptBin, keyScriptTag)
700+
if err != nil {
701+
t.Errorf("Error preparing for ScriptGet(), while issuing ScriptSet. error = %v", err)
702+
return
703+
}
704+
687705
type args struct {
688706
name string
689707
}
@@ -692,11 +710,13 @@ func TestCommand_ScriptGet(t *testing.T) {
692710
args args
693711
wantDeviceType string
694712
wantData string
713+
wantTag string
695714
wantErr bool
696715
}{
697-
{keyScript, args{keyScript}, DeviceCPU, "", false},
698-
{keyScriptPipelined, args{keyScript}, DeviceCPU, "", false},
699-
{keyScriptEmpty, args{keyScriptEmpty}, DeviceCPU, "", true},
716+
{keyScript, args{keyScript}, DeviceCPU, "", "", false},
717+
{keyScriptPipelined, args{keyScript}, DeviceCPU, "", "", false},
718+
{keyScriptEmpty, args{keyScriptEmpty}, DeviceCPU, "", "", true},
719+
{keyScriptTag, args{keyScript2}, DeviceCPU, "", keyScriptTag, false},
700720
}
701721
for _, tt := range tests {
702722
t.Run(tt.name, func(t *testing.T) {
@@ -706,13 +726,17 @@ func TestCommand_ScriptGet(t *testing.T) {
706726
t.Errorf("ScriptGet() error = %v, wantErr %v", err, tt.wantErr)
707727
return
708728
}
729+
709730
if tt.wantErr == false {
710731
if !reflect.DeepEqual(gotData["device"], tt.wantDeviceType) {
711732
t.Errorf("ScriptGet() gotData = %v, want %v", gotData["device"], tt.wantDeviceType)
712733
}
713734
if !reflect.DeepEqual(gotData["source"], tt.wantData) {
714735
t.Errorf("ScriptGet() gotData = %v, want %v", gotData["source"], tt.wantData)
715736
}
737+
if !reflect.DeepEqual(gotData["tag"], tt.wantTag) {
738+
t.Errorf("ScriptGet() gotData = %v, want %v", gotData["tag"], tt.wantTag)
739+
}
716740
}
717741

718742
})
@@ -1014,4 +1038,4 @@ func TestClient_ModelRun(t *testing.T) {
10141038
}
10151039
})
10161040
}
1017-
}
1041+
}

redisai/example_client_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ func ExampleConnect() {
3434
// Output: [1.1 2.2 3.3 4.4]
3535
}
3636

37-
3837
//Example of how to establish an connection with a shared pool to the RedisAI Server
3938
func ExampleConnect_pool() {
4039

redisai/example_commands_test.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ func ExampleClient_ModelGet() {
100100
device := reply[1]
101101
// print the error (should be <nil>)
102102
fmt.Println(err)
103-
fmt.Println(backend,device)
103+
fmt.Println(backend, device)
104104

105105
// Output:
106106
// <nil>
@@ -185,7 +185,6 @@ func ExampleClient_ModelRun() {
185185
// <nil>
186186
}
187187

188-
189188
func ExampleClient_Info() {
190189
// Create a client.
191190
client := redisai.Connect("redis://localhost:6379", nil)
@@ -217,4 +216,4 @@ func ExampleClient_Info() {
217216
// <nil>
218217
// <nil>
219218
// Total runs: 1
220-
}
219+
}

redisai/implementations/AIModel.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ type AIModel struct {
88
blob []byte
99
inputs []string
1010
outputs []string
11+
tag string
1112
}
1213

1314
func (m *AIModel) Outputs() []string {
@@ -50,6 +51,14 @@ func (m *AIModel) SetBackend(backend string) {
5051
m.backend = backend
5152
}
5253

54+
func (m *AIModel) Tag() string {
55+
return m.tag
56+
}
57+
58+
func (m *AIModel) SetTag(tag string) {
59+
m.tag = tag
60+
}
61+
5362
func NewModel(backend string, device string) *AIModel {
5463
return &AIModel{backend: backend, device: device}
5564
}

redisai/model.go

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@ type ModelInterface interface {
1717
SetDevice(device string)
1818
Backend() string
1919
SetBackend(backend string)
20+
Tag() string
21+
SetTag(tag string)
2022
}
2123

22-
func modelSetFlatArgs(keyName, backend, device string, inputs, outputs []string, blob []byte) redis.Args {
24+
func modelSetFlatArgs(keyName, backend, device, tag string, inputs, outputs []string, blob []byte) redis.Args {
2325
args := redis.Args{}.Add(keyName, backend, device)
26+
if len(tag) > 0 {
27+
args = args.Add("TAG", tag)
28+
}
2429
if len(inputs) > 0 {
2530
args = args.Add("INPUTS").AddFlat(inputs)
2631
}
@@ -33,7 +38,26 @@ func modelSetFlatArgs(keyName, backend, device string, inputs, outputs []string,
3338
}
3439

3540
func modelSetInterfaceArgs(keyName string, modelInterface ModelInterface) redis.Args {
36-
return modelSetFlatArgs(keyName, modelInterface.Backend(), modelInterface.Device(), modelInterface.Inputs(), modelInterface.Outputs(), modelInterface.Blob())
41+
args := redis.Args{keyName}
42+
if len(modelInterface.Backend()) > 0 {
43+
args = args.Add(modelInterface.Backend())
44+
}
45+
if len(modelInterface.Device()) > 0 {
46+
args = args.Add(modelInterface.Device())
47+
}
48+
if len(modelInterface.Tag()) > 0 {
49+
args = args.Add("TAG", modelInterface.Tag())
50+
}
51+
if len(modelInterface.Inputs()) > 0 {
52+
args = args.Add("INPUTS").AddFlat(modelInterface.Inputs())
53+
}
54+
if len(modelInterface.Outputs()) > 0 {
55+
args = args.Add("OUTPUTS").AddFlat(modelInterface.Outputs())
56+
}
57+
if modelInterface.Blob() != nil {
58+
args = args.Add("BLOB", modelInterface.Blob())
59+
}
60+
return args
3761
}
3862

3963
func modelRunFlatArgs(name string, inputTensorNames, outputTensorNames []string) redis.Args {
@@ -51,18 +75,20 @@ func modelRunFlatArgs(name string, inputTensorNames, outputTensorNames []string)
5175
func modelGetParseToInterface(reply interface{}, model ModelInterface) (err error) {
5276
var backend string
5377
var device string
78+
var tag string
5479
var blob []byte
55-
err, backend, device, blob = modelGetParseReply(reply)
80+
err, backend, device, tag, blob = modelGetParseReply(reply)
5681
if err != nil {
5782
return err
5883
}
5984
model.SetBackend(backend)
6085
model.SetDevice(device)
86+
model.SetTag(tag)
6187
model.SetBlob(blob)
6288
return
6389
}
6490

65-
func modelGetParseReply(reply interface{}) (err error, backend string, device string, blob []byte) {
91+
func modelGetParseReply(reply interface{}) (err error, backend string, device string, tag string, blob []byte) {
6692
var replySlice []interface{}
6793
var key string
6894
replySlice, err = redis.Values(reply, err)
@@ -90,6 +116,11 @@ func modelGetParseReply(reply interface{}) (err error, backend string, device st
90116
if err != nil {
91117
return
92118
}
119+
case "tag":
120+
tag, err = redis.String(replySlice[pos+1], err)
121+
if err != nil {
122+
return
123+
}
93124
}
94125
}
95126
return

redisai/model_test.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@ func Test_modelGetParseReply(t *testing.T) {
1414
args args
1515
wantBackend string
1616
wantDevice string
17+
wantTag string
1718
wantBlob []byte
1819
wantErr bool
1920
}{
20-
{"empty", args{}, "", "", nil, true},
21-
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", nil, true},
22-
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
23-
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
24-
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", nil, true},
21+
{"empty", args{}, "", "", "", nil, true},
22+
{"negative-wrong-reply", args{[]interface{}{[]interface{}{[]byte("serie 1"), []interface{}{}, []interface{}{[]interface{}{[]byte("AA"), []byte("1")}}}}}, "", "", "", nil, true},
23+
{"negative-wrong-reply", args{[]interface{}{[]byte("dtype"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
24+
{"negative-wrong-device", args{[]interface{}{[]byte("device"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
25+
{"negative-wrong-blob", args{[]interface{}{[]byte("blob"), []interface{}{[]byte("dtype"), []byte("1")}}}, "", "", "", nil, true},
2526
}
2627
for _, tt := range tests {
2728
t.Run(tt.name, func(t *testing.T) {
28-
gotErr, gotBackend, gotDevice, gotBlob := modelGetParseReply(tt.args.reply)
29+
gotErr, gotBackend, gotDevice, gotTag, gotBlob := modelGetParseReply(tt.args.reply)
2930
if gotErr != nil && !tt.wantErr {
3031
t.Errorf("modelGetParseReply() gotErr = %v, want %v", gotErr, tt.wantErr)
3132
}
@@ -35,6 +36,9 @@ func Test_modelGetParseReply(t *testing.T) {
3536
if gotDevice != tt.wantDevice {
3637
t.Errorf("modelGetParseReply() gotDevice = %v, want %v", gotDevice, tt.wantDevice)
3738
}
39+
if gotTag != tt.wantTag {
40+
t.Errorf("modelGetParseReply() gotTag = %v, want %v", gotTag, tt.wantTag)
41+
}
3842
if !reflect.DeepEqual(gotBlob, tt.wantBlob) {
3943
t.Errorf("modelGetParseReply() gotBlob = %v, want %v", gotBlob, tt.wantBlob)
4044
}

0 commit comments

Comments
 (0)