@@ -69,12 +69,8 @@ static void beta_op(float *x, BLASLONG n, FLOAT beta) {
6969 x += 4 ;
7070 }
7171
72- if (rest_n & 3 ) {
73- x [0 ] *= beta ;
74- if ((rest_n & 3 ) > 1 )
75- x [1 ] *= beta ;
76- if ((rest_n & 3 ) > 2 )
77- x [2 ] *= beta ;
72+ for (BLASLONG i = 0 ; i < (rest_n & 3 ); i ++ ) {
73+ x [i ] *= beta ;
7874 }
7975 }
8076 return ;
@@ -88,7 +84,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
8884
8985 bfloat16x8_t a0 , a1 , a2 , a3 , a4 , a5 , a6 , a7 ;
9086 bfloat16x8_t t0 , t1 , t2 , t3 , t4 , t5 , t6 , t7 ;
87+
9188 bfloat16x8_t x_vec ;
89+ bfloat16x4_t x_vecx4 ;
90+
9291 float32x4_t y1_vec , y2_vec ;
9392 float32x4_t fp32_low , fp32_high ;
9493
@@ -106,7 +105,7 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
106105
107106 if (incx == 1 && incy == 1 ) {
108107 if (beta != 1 ) {
109- beta_op (y , n , beta );
108+ beta_op (y , m , beta );
110109 }
111110
112111 for (i = 0 ; i < n / 8 ; i ++ ) {
@@ -290,12 +289,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
290289
291290 a_ptr += 4 * lda ;
292291
293- bfloat16x4_t x_vecx4 = vld1_bf16 (x_ptr );
292+ x_vecx4 = vld1_bf16 (x_ptr );
294293 if (alpha != 1 ) {
295- x_vec = vcombine_bf16 (x_vecx4 , bf16_zero );
296- fp32_low = vreinterpretq_f32_u16 (
297- vzip1q_u16 (vreinterpretq_u16_bf16 (bf16_zero_q ),
298- vreinterpretq_u16_bf16 (x_vec )));
294+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
299295 fp32_low = vmulq_n_f32 (fp32_low , alpha );
300296 x_vecx4 = vcvt_bf16_f32 (fp32_low );
301297 }
@@ -348,15 +344,11 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
348344
349345 y1_vec = vld1q_f32 (y_ptr );
350346
351- a0 = vcombine_bf16 (a0x4 , bf16_zero );
352- a1 = vcombine_bf16 (a1x4 , bf16_zero );
353- a2 = vcombine_bf16 (a2x4 , bf16_zero );
354- a3 = vcombine_bf16 (a3x4 , bf16_zero );
347+ a0 = vcombine_bf16 (a0x4 , a2x4 );
348+ a1 = vcombine_bf16 (a1x4 , a3x4 );
355349
356- t0 = vreinterpretq_bf16_u16 (
357- vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
358- t1 = vreinterpretq_bf16_u16 (
359- vzip1q_u16 (vreinterpretq_u16_bf16 (a2 ), vreinterpretq_u16_bf16 (a3 )));
350+ t0 = vreinterpretq_bf16_u16 (vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
351+ t1 = vreinterpretq_bf16_u16 (vzip2q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
360352
361353 y1_vec = vbfmlalbq_lane_f32 (y1_vec , t0 , x_vecx4 , 0 );
362354 y1_vec = vbfmlaltq_lane_f32 (y1_vec , t0 , x_vecx4 , 1 );
@@ -374,10 +366,12 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
374366 }
375367
376368 if (rest_m ) {
377- x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
378- x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
379- x2 = alpha * vcvtah_f32_bf16 (x_ptr [2 ]);
380- x3 = alpha * vcvtah_f32_bf16 (x_ptr [3 ]);
369+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
370+
371+ x0 = vgetq_lane_f32 (fp32_low , 0 );
372+ x1 = vgetq_lane_f32 (fp32_low , 1 );
373+ x2 = vgetq_lane_f32 (fp32_low , 2 );
374+ x3 = vgetq_lane_f32 (fp32_low , 3 );
381375
382376 for (BLASLONG j = 0 ; j < rest_m ; j ++ ) {
383377 y_ptr [j ] += x0 * vcvtah_f32_bf16 (a_ptr0 [j ]);
@@ -396,18 +390,13 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
396390
397391 a_ptr += 2 * lda ;
398392
399- bfloat16_t tmp_buffer [4 ];
400- memset ((void * )tmp_buffer , 0 , sizeof (bfloat16_t ));
401-
402- tmp_buffer [0 ] = x_ptr [0 ];
403- tmp_buffer [1 ] = x_ptr [1 ];
393+ x_vecx4 = vreinterpret_bf16_u16 (vzip1_u16 (
394+ vreinterpret_u16_bf16 (vdup_n_bf16 (x_ptr [0 ])),
395+ vreinterpret_u16_bf16 (vdup_n_bf16 (x_ptr [1 ]))
396+ ));
404397
405- bfloat16x4_t x_vecx4 = vld1_bf16 (tmp_buffer );
406398 if (alpha != 1 ) {
407- x_vec = vcombine_bf16 (x_vecx4 , bf16_zero );
408- fp32_low = vreinterpretq_f32_u16 (
409- vzip1q_u16 (vreinterpretq_u16_bf16 (bf16_zero_q ),
410- vreinterpretq_u16_bf16 (x_vec )));
399+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
411400 fp32_low = vmulq_n_f32 (fp32_low , alpha );
412401 x_vecx4 = vcvt_bf16_f32 (fp32_low );
413402 }
@@ -422,14 +411,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
422411
423412 t0 = vreinterpretq_bf16_u16 (
424413 vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
425- t4 = vreinterpretq_bf16_u16 (
414+ t1 = vreinterpretq_bf16_u16 (
426415 vzip2q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
427416
428417 y1_vec = vbfmlalbq_lane_f32 (y1_vec , t0 , x_vecx4 , 0 );
429418 y1_vec = vbfmlaltq_lane_f32 (y1_vec , t0 , x_vecx4 , 1 );
430419
431- y2_vec = vbfmlalbq_lane_f32 (y2_vec , t4 , x_vecx4 , 0 );
432- y2_vec = vbfmlaltq_lane_f32 (y2_vec , t4 , x_vecx4 , 1 );
420+ y2_vec = vbfmlalbq_lane_f32 (y2_vec , t1 , x_vecx4 , 0 );
421+ y2_vec = vbfmlaltq_lane_f32 (y2_vec , t1 , x_vecx4 , 1 );
433422
434423 vst1q_f32 (y_ptr , y1_vec );
435424 vst1q_f32 (y_ptr + 4 , y2_vec );
@@ -449,29 +438,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
449438 a0 = vcombine_bf16 (a0x4 , bf16_zero );
450439 a1 = vcombine_bf16 (a1x4 , bf16_zero );
451440
452- t0 = vreinterpretq_bf16_u16 (
453- vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
454- t1 = vreinterpretq_bf16_u16 (
455- vzip1q_u16 (vreinterpretq_u16_bf16 (a2 ), vreinterpretq_u16_bf16 (a3 )));
441+ t0 = vreinterpretq_bf16_u16 (vzip1q_u16 (vreinterpretq_u16_bf16 (a0 ), vreinterpretq_u16_bf16 (a1 )));
456442
457443 y1_vec = vbfmlalbq_lane_f32 (y1_vec , t0 , x_vecx4 , 0 );
458444 y1_vec = vbfmlaltq_lane_f32 (y1_vec , t0 , x_vecx4 , 1 );
459- y1_vec = vbfmlalbq_lane_f32 (y1_vec , t1 , x_vecx4 , 2 );
460- y1_vec = vbfmlaltq_lane_f32 (y1_vec , t1 , x_vecx4 , 3 );
461445
462446 vst1q_f32 (y_ptr , y1_vec );
463447
464448 a_ptr0 += 4 ;
465449 a_ptr1 += 4 ;
466- a_ptr2 += 4 ;
467- a_ptr3 += 4 ;
468450
469451 y_ptr += 4 ;
470452 }
471453
472454 if (m & 2 ) {
473- x0 = alpha * (vcvtah_f32_bf16 (x_ptr [0 ]));
474- x1 = alpha * (vcvtah_f32_bf16 (x_ptr [1 ]));
455+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
456+ x0 = vgetq_lane_f32 (fp32_low , 0 );
457+ x1 = vgetq_lane_f32 (fp32_low , 1 );
458+
475459
476460 y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
477461 y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
@@ -485,8 +469,9 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, bfloat16 *a, BLASLONG lda,
485469 }
486470
487471 if (m & 1 ) {
488- x0 = alpha * vcvtah_f32_bf16 (x_ptr [0 ]);
489- x1 = alpha * vcvtah_f32_bf16 (x_ptr [1 ]);
472+ fp32_low = vcvt_f32_bf16 (x_vecx4 );
473+ x0 = vgetq_lane_f32 (fp32_low , 0 );
474+ x1 = vgetq_lane_f32 (fp32_low , 1 );
490475
491476 y_ptr [0 ] += x0 * vcvtah_f32_bf16 (a_ptr0 [0 ]);
492477 y_ptr [0 ] += x1 * vcvtah_f32_bf16 (a_ptr1 [0 ]);
0 commit comments