Skip to content

Commit 46fd06e

Browse files
[add] updated description of TensorInterface. Improved helper methods to make it more general (#6)
1 parent 6cef5bb commit 46fd06e

File tree

3 files changed

+56
-20
lines changed

3 files changed

+56
-20
lines changed

redisai/commands_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ func TestCommand_TensorSet(t *testing.T) {
9191
}
9292

9393
func TestCommand_FullFromTensor(t *testing.T) {
94-
tensor := implementations.NewAiTensorWithTypeShape(TypeFloat32, []int{1})
94+
tensor := implementations.NewAiTensorWithShape([]int{1})
9595
tensor.SetData([]float32{1.0})
9696
client := createTestClient()
9797
err := client.TensorSetFromTensor("tensor1", tensor)

redisai/implementations/AITensor.go

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
package implementations
22

3+
import "reflect"
4+
35
// TensorInterface is an interface that represents the skeleton of a tensor ( n-dimensional array of numerical data )
46
// needed to map it to a RedisAI Model with the proper operations
57
type AiTensor struct {
6-
typestr string
7-
88
// the size - in each dimension - of the tensor.
99
shape []int
1010

1111
data interface{}
1212
}
1313

14+
func (t *AiTensor) Dtype() reflect.Type {
15+
return reflect.TypeOf(t.data)
16+
}
17+
1418
func NewAiTensor() *AiTensor {
1519
return &AiTensor{}
1620
}
@@ -27,14 +31,6 @@ func (t *AiTensor) Len() int {
2731
return result
2832
}
2933

30-
func (t *AiTensor) TypeString() string {
31-
return t.typestr
32-
}
33-
34-
func (t *AiTensor) SetTypeString(typestr string) {
35-
t.typestr = typestr
36-
}
37-
3834
func (m *AiTensor) Shape() []int {
3935
return m.shape
4036
}
@@ -43,12 +39,12 @@ func (m *AiTensor) SetShape(shape []int) {
4339
m.shape = shape
4440
}
4541

46-
func NewAiTensorWithTypeShape(typestr string, shape []int) *AiTensor {
47-
return &AiTensor{typestr: typestr, shape: shape}
42+
func NewAiTensorWithShape(shape []int) *AiTensor {
43+
return &AiTensor{shape: shape}
4844
}
4945

5046
func NewAiTensorWithData(typestr string, shape []int, data interface{}) *AiTensor {
51-
tensor := NewAiTensorWithTypeShape(typestr, shape)
47+
tensor := NewAiTensorWithShape(shape)
5248
tensor.SetData(data)
5349
return tensor
5450
}

redisai/tensor.go

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,50 @@ type TensorInterface interface {
2222
// Len returns the number of elements in the tensor.
2323
Len() int
2424

25-
TypeString() string
26-
SetTypeString(typestr string)
25+
Dtype() reflect.Type
2726

27+
// Data returns the underlying tensor data
2828
Data() interface{}
2929
SetData(interface{})
3030
}
3131

32+
func TensorGetTypeStrFromType(dtype reflect.Type) (typestr string, err error) {
33+
switch dtype {
34+
case reflect.TypeOf(([]uint8)(nil)):
35+
typestr = TypeUint8
36+
case reflect.TypeOf(([]byte)(nil)):
37+
typestr = TypeUint8
38+
case reflect.TypeOf(([]int)(nil)):
39+
typestr = TypeInt32
40+
case reflect.TypeOf(([]int8)(nil)):
41+
typestr = TypeInt8
42+
case reflect.TypeOf(([]int16)(nil)):
43+
typestr = TypeInt16
44+
case reflect.TypeOf(([]int32)(nil)):
45+
typestr = TypeInt32
46+
case reflect.TypeOf(([]int64)(nil)):
47+
typestr = TypeInt64
48+
case reflect.TypeOf(([]uint)(nil)):
49+
typestr = TypeUint8
50+
case reflect.TypeOf(([]uint16)(nil)):
51+
typestr = TypeUint16
52+
case reflect.TypeOf(([]float32)(nil)):
53+
typestr = TypeFloat32
54+
case reflect.TypeOf(([]float64)(nil)):
55+
typestr = TypeFloat64
56+
case reflect.TypeOf(([]uint32)(nil)):
57+
fallthrough
58+
// unsupported data type
59+
case reflect.TypeOf(([]uint64)(nil)):
60+
fallthrough
61+
// unsupported data type
62+
63+
default:
64+
err = fmt.Errorf("redisai Tensor does not support the following type %v", dtype)
65+
}
66+
return
67+
}
68+
3269
func tensorSetFlatArgs(name string, dt string, dims []int, data interface{}) (redis.Args, error) {
3370
args := redis.Args{}
3471
var err error = nil
@@ -74,13 +111,16 @@ func tensorSetFlatArgs(name string, dt string, dims []int, data interface{}) (re
74111
return args, err
75112
}
76113

77-
func tensorSetInterfaceArgs(keyName string, tensorInterface TensorInterface) (redis.Args, error) {
78-
return tensorSetFlatArgs(keyName, tensorInterface.TypeString(), tensorInterface.Shape(), tensorInterface.Data())
114+
func tensorSetInterfaceArgs(keyName string, tensorInterface TensorInterface) (args redis.Args, err error) {
115+
typestr, err := TensorGetTypeStrFromType(tensorInterface.Dtype())
116+
if err != nil {
117+
return
118+
}
119+
return tensorSetFlatArgs(keyName, typestr, tensorInterface.Shape(), tensorInterface.Data())
79120
}
80121

81122
func tensorGetParseToInterface(reply interface{}, tensor TensorInterface) (err error) {
82-
err, dtype, shape, data := ProcessTensorGetReply(reply, err)
83-
tensor.SetTypeString(dtype)
123+
err, _, shape, data := ProcessTensorGetReply(reply, err)
84124
tensor.SetShape(shape)
85125
tensor.SetData(data)
86126
return

0 commit comments

Comments
 (0)