2424 @test testf (* , transpose (rand (elty, m, n)), rand (elty, m))
2525 @test testf (* , rand (elty, m, n)' , rand (elty, m))
2626 x = rand (elty, m)
27- A = rand (elty, m, m + 1 )
28- y = rand (elty, m )
27+ A = rand (elty, m, m + 1 )
28+ y = rand (elty, n )
2929 dx = CuArray (x)
3030 dA = CuArray (A)
3131 dy = CuArray (y)
@@ -44,6 +44,10 @@ k = 13
4444 dy = CUBLAS. gemv (' N' , dA, dx)
4545 hy = collect (dy)
4646 @test hy ≈ A * x
47+ dy = CuArray (y)
48+ dx = CUBLAS. gemv (elty <: Real ? ' T' : ' C' , alpha, dA, dy)
49+ hx = collect (dx)
50+ @test hx ≈ alpha * A' * y
4751 end
4852
4953 if CUBLAS. version () >= v " 11.9"
@@ -72,6 +76,16 @@ k = 13
7276 y[i] = alpha * A[i] * x[i] + beta * y[i]
7377 @test y[i] ≈ hy
7478 end
79+ dy = CuArray{elty, 1 }[]
80+ for i= 1 : length (A)
81+ push! (dy, CuArray (y[i]))
82+ end
83+ CUBLAS. gemv_batched! (elty <: Real ? ' T' : ' C' , alpha, dA, dy, beta, dx)
84+ for i in 1 : length (A)
85+ hx = collect (dx[i])
86+ x[i] = alpha * A[i]' * y[i] + beta * x[i]
87+ @test x[i] ≈ hx
88+ end
7589 end
7690 end
7791
@@ -92,11 +106,18 @@ k = 13
92106 dbad = CuArray (bad)
93107 @test_throws DimensionMismatch CUBLAS. gemv_strided_batched! (' N' , alpha, dA, dx, beta, dbad)
94108 CUBLAS. gemv_strided_batched! (' N' , alpha, dA, dx, beta, dy)
95- for i= 1 : size (A, 3 )
109+ for i in 1 : size (A, 3 )
96110 hy = collect (dy[:, i])
97111 y[:, i] = alpha * A[:, :, i] * x[:, i] + beta * y[:, i]
98112 @test y[:, i] ≈ hy
99113 end
114+ dy = CuArray (y)
115+ CUBLAS. gemv_strided_batched! (elty <: Real ? ' T' : ' C' , alpha, dA, dy, beta, dx)
116+ for i in 1 : size (A, 3 )
117+ hx = collect (dx[:, i])
118+ x[:, i] = alpha * A[:, :, i]' * y[:, i] + beta * x[:, i]
119+ @test x[:, i] ≈ hx
120+ end
100121 end
101122 end
102123
0 commit comments