@@ -868,3 +868,113 @@ func TestCommand_Info(t *testing.T) {
868868 ret , err = c .ResetStat ("notExits" )
869869 assert .NotNil (t , err )
870870}
871+
872+ func TestCommand_DagRun (t * testing.T ) {
873+ c := createTestClient ()
874+ keyModel1 := "test:DagRun:mymodel:1"
875+ data , err := ioutil .ReadFile ("./../tests/test_data/graph.pb" )
876+ if err != nil {
877+ t .Errorf ("Error preparing for Info(), while issuing ModelSet. error = %v" , err )
878+ return
879+ }
880+ err = c .ModelSet (keyModel1 , BackendTF , DeviceCPU , data , []string {"a" , "b" }, []string {"mul" })
881+ err = c .TensorSet ("persisted_tensor_1" , TypeFloat32 , []int {1 , 2 }, []float32 {5 , 10 })
882+ assert .Nil (t , err )
883+
884+ type args struct {
885+ loadKeys []string
886+ persistKeys []string
887+ dagCommandInterface DagCommandInterface
888+ }
889+ tests := []struct {
890+ name string
891+ args args
892+ wantErr bool
893+ }{
894+ {"t_wrong_number" , args {[]string {"notnumber" }, nil , NewDag ().TensorSet ("tensor1" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 })}, true },
895+ {"t_load" , args {[]string {"persisted_tensor_1" }, []string {"tensor1" }, NewDag ().TensorSet ("tensor1" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 })}, false },
896+ {"t_load_err" , args {[]string {"not_exits_tensor" }, []string {"tensor1" }, NewDag ().TensorSet ("tensor1" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 })}, true },
897+ {"t1" , args {nil , nil , NewDag ().TensorSet ("a" , TypeFloat32 , []int {1 }, []float32 {1.1 })}, false },
898+ {"t_blob" , args {nil , nil , NewDag ().TensorSet ("a" , TypeFloat32 , []int {1 }, []float32 {1.1 }).TensorSet ("b" , TypeFloat32 , []int {1 }, []float32 {4.4 }).ModelRun ("test:DagRun:mymodel:1" , []string {"a" , "b" }, []string {"mul" }).TensorGet ("mul" , TensorContentTypeBlob )}, false },
899+ {"t_values" , args {nil , nil , NewDag ().TensorSet ("mytensor" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 }).TensorGet ("mytensor" , TensorContentTypeValues )}, false },
900+ }
901+ for _ , tt := range tests {
902+ t .Run (tt .name , func (t * testing.T ) {
903+ c := createTestClient ()
904+ results , err := c .DagRun (tt .args .loadKeys , tt .args .persistKeys , tt .args .dagCommandInterface )
905+ if (err != nil ) != tt .wantErr {
906+ t .Errorf ("DagRun() error = %v, wantErr %v" , err , tt .wantErr )
907+ return
908+ }
909+
910+ for _ , result := range results {
911+ ret , ok := result .(string )
912+ if ok {
913+ assert .Equal (t , "OK" , ret )
914+ continue
915+ }
916+ values , ok := result .([]interface {})
917+ if ok {
918+ vs , _ := redis .Strings (values , nil )
919+ assert .True (t , len (vs ) > 0 )
920+ continue
921+ }
922+ blobs , ok := result .([]byte )
923+ if ok {
924+ assert .True (t , len (blobs ) > 0 )
925+ continue
926+ }
927+ t .Errorf ("DagRun() error unsupported result" )
928+ }
929+ })
930+ }
931+ }
932+
933+ func TestCommand_DagRunRO (t * testing.T ) {
934+ c := createTestClient ()
935+ err := c .TensorSet ("persisted_tensor" , TypeFloat32 , []int {1 , 2 }, []float32 {5 , 10 })
936+ assert .Nil (t , err )
937+ type args struct {
938+ loadKeys []string
939+ dagCommandInterface DagCommandInterface
940+ }
941+ tests := []struct {
942+ name string
943+ args args
944+ wantErr bool
945+ }{
946+ {"t_1" , args {[]string {"persisted_tensor" }, NewDag ().TensorGet ("persisted_tensor" , TensorContentTypeValues )}, false },
947+ {"t_2" , args {nil , NewDag ().TensorSet ("tensor1" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 }).TensorSet ("tensor2" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 })}, false },
948+ {"t_err1" , args {[]string {"notnumber" }, NewDag ().TensorSet ("tensor1" , TypeFloat32 , []int {1 , 2 }, []int {5 , 10 })}, true },
949+ }
950+ for _ , tt := range tests {
951+ t .Run (tt .name , func (t * testing.T ) {
952+ c := createTestClient ()
953+ results , err := c .DagRunRO (tt .args .loadKeys , tt .args .dagCommandInterface )
954+ if (err != nil ) != tt .wantErr {
955+ t .Errorf ("DagRunRO() error = %v, wantErr %v" , err , tt .wantErr )
956+ return
957+ }
958+
959+ for _ , result := range results {
960+ ret , ok := result .(string )
961+ if ok {
962+ assert .Equal (t , "OK" , ret )
963+ continue
964+ }
965+ values , ok := result .([]interface {})
966+ if ok {
967+ vs , _ := redis .Strings (values , nil )
968+ assert .True (t , len (vs ) > 0 )
969+ continue
970+ }
971+ blobs , ok := result .([]byte )
972+ if ok {
973+ assert .True (t , len (blobs ) > 0 )
974+ continue
975+ }
976+ t .Errorf ("DagRunRO() error unsupported result" )
977+ }
978+ })
979+ }
980+ }
0 commit comments