Skip to content

Commit 1092637

Browse files
authored
Add support for AI.DAGRUN AI.DAGRUN_RO (#8)
* Add support for AI.DAGRUN AI.DAGRUN_RO * Add DagCommandInterface
1 parent abf6b5d commit 1092637

File tree

3 files changed

+208
-0
lines changed

3 files changed

+208
-0
lines changed

redisai/commands.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,42 @@ func (c *Client) Info(key string) (map[string]string, error) {
198198
func (c *Client) ResetStat(key string) (string, error) {
199199
return redis.String(c.DoOrSend("AI.INFO", redis.Args{key, "RESETSTAT"}, nil))
200200
}
201+
202+
// Direct acyclic graph of operations to run within RedisAI
203+
func (c *Client) DagRun(loadKeys []string, persistKeys []string, dagCommandInterface DagCommandInterface) ([]interface{}, error) {
204+
commandArgs, err := dagCommandInterface.FlatArgs()
205+
if err != nil {
206+
return nil, err
207+
}
208+
args := AddDagRunArgs(loadKeys, persistKeys, commandArgs)
209+
reply, err := c.DoOrSend("AI.DAGRUN", args, nil)
210+
return dagCommandInterface.ParseReply(reply, err)
211+
}
212+
213+
// The command is a read-only variant of AI.DAGRUN
214+
func (c *Client) DagRunRO(loadKeys []string, dagCommandInterface DagCommandInterface) ([]interface{}, error) {
215+
commandArgs, err := dagCommandInterface.FlatArgs()
216+
if err != nil {
217+
return nil, err
218+
}
219+
args := AddDagRunArgs(loadKeys, nil, commandArgs)
220+
reply, err := c.DoOrSend("AI.DAGRUN_RO", args, nil)
221+
return dagCommandInterface.ParseReply(reply, err)
222+
}
223+
224+
// AddDagRunArgs for AI.DAGRUN and DAGRUN_RO commands.
225+
func AddDagRunArgs(loadKeys []string, persistKeys []string, commandArgs redis.Args) redis.Args {
226+
args := redis.Args{}
227+
if loadKeys != nil {
228+
args = args.Add("LOAD", len(loadKeys)).AddFlat(loadKeys)
229+
}
230+
231+
if persistKeys != nil {
232+
args = args.Add("PERSIST", len(persistKeys)).AddFlat(persistKeys)
233+
}
234+
235+
if commandArgs != nil {
236+
args = args.AddFlat(commandArgs)
237+
}
238+
return args
239+
}

redisai/commands_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,3 +868,113 @@ func TestCommand_Info(t *testing.T) {
868868
ret, err = c.ResetStat("notExits")
869869
assert.NotNil(t, err)
870870
}
871+
872+
func TestCommand_DagRun(t *testing.T) {
873+
c := createTestClient()
874+
keyModel1 := "test:DagRun:mymodel:1"
875+
data, err := ioutil.ReadFile("./../tests/test_data/graph.pb")
876+
if err != nil {
877+
t.Errorf("Error preparing for Info(), while issuing ModelSet. error = %v", err)
878+
return
879+
}
880+
err = c.ModelSet(keyModel1, BackendTF, DeviceCPU, data, []string{"a", "b"}, []string{"mul"})
881+
err = c.TensorSet("persisted_tensor_1", TypeFloat32, []int{1, 2}, []float32{5, 10})
882+
assert.Nil(t, err)
883+
884+
type args struct {
885+
loadKeys []string
886+
persistKeys []string
887+
dagCommandInterface DagCommandInterface
888+
}
889+
tests := []struct {
890+
name string
891+
args args
892+
wantErr bool
893+
}{
894+
{"t_wrong_number", args{[]string{"notnumber"}, nil, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true},
895+
{"t_load", args{[]string{"persisted_tensor_1"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, false},
896+
{"t_load_err", args{[]string{"not_exits_tensor"}, []string{"tensor1"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true},
897+
{"t1", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int{1}, []float32{1.1})}, false},
898+
{"t_blob", args{nil, nil, NewDag().TensorSet("a", TypeFloat32, []int{1}, []float32{1.1}).TensorSet("b", TypeFloat32, []int{1}, []float32{4.4}).ModelRun("test:DagRun:mymodel:1", []string{"a", "b"}, []string{"mul"}).TensorGet("mul", TensorContentTypeBlob)}, false},
899+
{"t_values", args{nil, nil, NewDag().TensorSet("mytensor", TypeFloat32, []int{1, 2}, []int{5, 10}).TensorGet("mytensor", TensorContentTypeValues)}, false},
900+
}
901+
for _, tt := range tests {
902+
t.Run(tt.name, func(t *testing.T) {
903+
c := createTestClient()
904+
results, err := c.DagRun(tt.args.loadKeys, tt.args.persistKeys, tt.args.dagCommandInterface)
905+
if (err != nil) != tt.wantErr {
906+
t.Errorf("DagRun() error = %v, wantErr %v", err, tt.wantErr)
907+
return
908+
}
909+
910+
for _, result := range results {
911+
ret, ok := result.(string)
912+
if ok {
913+
assert.Equal(t, "OK", ret)
914+
continue
915+
}
916+
values, ok := result.([]interface{})
917+
if ok {
918+
vs, _ := redis.Strings(values, nil)
919+
assert.True(t, len(vs) > 0)
920+
continue
921+
}
922+
blobs, ok := result.([]byte)
923+
if ok {
924+
assert.True(t, len(blobs) > 0)
925+
continue
926+
}
927+
t.Errorf("DagRun() error unsupported result")
928+
}
929+
})
930+
}
931+
}
932+
933+
func TestCommand_DagRunRO(t *testing.T) {
934+
c := createTestClient()
935+
err := c.TensorSet("persisted_tensor", TypeFloat32, []int{1, 2}, []float32{5, 10})
936+
assert.Nil(t, err)
937+
type args struct {
938+
loadKeys []string
939+
dagCommandInterface DagCommandInterface
940+
}
941+
tests := []struct {
942+
name string
943+
args args
944+
wantErr bool
945+
}{
946+
{"t_1", args{[]string{"persisted_tensor"}, NewDag().TensorGet("persisted_tensor", TensorContentTypeValues)}, false},
947+
{"t_2", args{nil, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10}).TensorSet("tensor2", TypeFloat32, []int{1, 2}, []int{5, 10})}, false},
948+
{"t_err1", args{[]string{"notnumber"}, NewDag().TensorSet("tensor1", TypeFloat32, []int{1, 2}, []int{5, 10})}, true},
949+
}
950+
for _, tt := range tests {
951+
t.Run(tt.name, func(t *testing.T) {
952+
c := createTestClient()
953+
results, err := c.DagRunRO(tt.args.loadKeys, tt.args.dagCommandInterface)
954+
if (err != nil) != tt.wantErr {
955+
t.Errorf("DagRunRO() error = %v, wantErr %v", err, tt.wantErr)
956+
return
957+
}
958+
959+
for _, result := range results {
960+
ret, ok := result.(string)
961+
if ok {
962+
assert.Equal(t, "OK", ret)
963+
continue
964+
}
965+
values, ok := result.([]interface{})
966+
if ok {
967+
vs, _ := redis.Strings(values, nil)
968+
assert.True(t, len(vs) > 0)
969+
continue
970+
}
971+
blobs, ok := result.([]byte)
972+
if ok {
973+
assert.True(t, len(blobs) > 0)
974+
continue
975+
}
976+
t.Errorf("DagRunRO() error unsupported result")
977+
}
978+
})
979+
}
980+
}

redisai/dag.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package redisai
2+
3+
import "github.com/gomodule/redigo/redis"
4+
5+
// DagCommandInterface is an interface that represents the skeleton of DAG supported commands
6+
// needed to map it to a RedisAI DAGRUN and DAGURN_RO commands
7+
type DagCommandInterface interface {
8+
TensorSet(keyName, dt string, dims []int, data interface{}) DagCommandInterface
9+
TensorGet(name, format string) DagCommandInterface
10+
ModelRun(name string, inputTensorNames, outputTensorNames []string) DagCommandInterface
11+
FlatArgs() (redis.Args, error)
12+
ParseReply(reply interface{}, err error) ([]interface{}, error)
13+
}
14+
15+
type Dag struct {
16+
commands []redis.Args
17+
}
18+
19+
func NewDag() *Dag {
20+
return &Dag{
21+
commands: make([]redis.Args, 0),
22+
}
23+
}
24+
25+
func (d *Dag) TensorSet(keyName, dt string, dims []int, data interface{}) DagCommandInterface {
26+
args := redis.Args{"AI.TENSORSET"}
27+
setFlatArgs, err := tensorSetFlatArgs(keyName, dt, dims, data)
28+
if err == nil {
29+
args = args.AddFlat(setFlatArgs)
30+
}
31+
d.commands = append(d.commands, args)
32+
return d
33+
}
34+
35+
func (d *Dag) TensorGet(name, format string) DagCommandInterface {
36+
d.commands = append(d.commands, redis.Args{"AI.TENSORGET", name, format})
37+
return d
38+
}
39+
40+
func (d *Dag) ModelRun(name string, inputTensorNames, outputTensorNames []string) DagCommandInterface {
41+
args := redis.Args{"AI.MODELRUN"}
42+
runFlatArgs := modelRunFlatArgs(name, inputTensorNames, outputTensorNames)
43+
args = args.AddFlat(runFlatArgs)
44+
d.commands = append(d.commands, args)
45+
return d
46+
}
47+
48+
func (d *Dag) FlatArgs() (redis.Args, error) {
49+
args := redis.Args{}
50+
for _, command := range d.commands {
51+
args = args.Add("|>")
52+
args = args.AddFlat(command)
53+
}
54+
return args, nil
55+
}
56+
57+
func (d *Dag) ParseReply(reply interface{}, err error) ([]interface{}, error) {
58+
return redis.Values(reply, err)
59+
}

0 commit comments

Comments
 (0)