@@ -59,9 +59,36 @@ struct gemm_emul_tinysq
5959
6060
6161
62+ struct gemm_emul_large_mp_helper
63+ {
64+ template <typename eT>
65+ arma_hot
66+ inline
67+ static
68+ void
69+ copy_row (eT* out_mem, const Mat<eT>& in, const uword row)
70+ {
71+ const uword n_rows = in.n_rows ;
72+ const uword n_cols = in.n_cols ;
73+
74+ const eT* in_mem_row = in.memptr () + row;
75+
76+ for (uword i=0 ; i < n_cols; ++i)
77+ {
78+ out_mem[i] = (*in_mem_row);
79+
80+ in_mem_row += n_rows;
81+ }
82+ }
83+ };
84+
85+
86+
87+ #if defined(ARMA_USE_OPENMP)
6288// ! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes)
89+ // ! parallelised version
// OpenMP variant of the gemm emulation.
// The template flags select, at compile time, whether A and/or B are used transposed
// and whether the result is scaled by alpha and/or accumulated onto beta*C.
// Each element of C is produced by one op_dot::direct_dot() call over two contiguous arrays.
6390template <const bool do_trans_A=false , const bool do_trans_B=false , const bool use_alpha=false , const bool use_beta=false >
64- struct gemm_emul_large
91+ struct gemm_emul_large_mp
6592 {
6693 template <typename eT, typename TA, typename TB>
6794 arma_hot
@@ -78,13 +105,151 @@ struct gemm_emul_large
78105 )
79106 {
80107 arma_debug_sigprint ();
108+
109+ const uword A_n_rows = A.n_rows ;
110+ const uword A_n_cols = A.n_cols ;
111+
112+ const uword B_n_rows = B.n_rows ;
113+ const uword B_n_cols = B.n_cols ;
114+
// The four branches below test template parameters only,
// so a given instantiation always takes the same branch.
// Case: A * B
115+ if ( (do_trans_A == false ) && (do_trans_B == false ) )
116+ {
117+ const uword n_threads = uword (mp_thread_limit::get ());
118+
// scratch space: one slice of A_n_cols elements per thread
119+ podarray<eT> tmp (A_n_cols * n_threads, arma_nozeros_indicator ());
120+
121+ eT* tmp_mem = tmp.memptr ();
122+
// each iteration writes only to row row_A of C
123+ #pragma omp parallel for schedule(static) num_threads(int(n_threads))
124+ for (uword row_A=0 ; row_A < A_n_rows; ++row_A)
125+ {
126+ const uword thread_id = uword (omp_get_thread_num ());
127+
// slice of the scratch space belonging to the executing thread
128+ eT* A_rowdata = tmp_mem + (A_n_cols * thread_id);
129+
// gather the strided row of A into contiguous memory for the dot products below
130+ gemm_emul_large_mp_helper::copy_row (A_rowdata, A, row_A);
131+
132+ for (uword col_B=0 ; col_B < B_n_cols; ++col_B)
133+ {
134+ const eT acc = op_dot::direct_dot (B_n_rows, A_rowdata, B.colptr (col_B));
135+
// store acc, optionally scaled by alpha and/or added to beta times the existing element of C
136+ if ( (use_alpha == false ) && (use_beta == false ) ) { C.at (row_A,col_B) = acc; }
137+ else if ( (use_alpha == true ) && (use_beta == false ) ) { C.at (row_A,col_B) = alpha*acc; }
138+ else if ( (use_alpha == false ) && (use_beta == true ) ) { C.at (row_A,col_B) = acc + beta*C.at (row_A,col_B); }
139+ else if ( (use_alpha == true ) && (use_beta == true ) ) { C.at (row_A,col_B) = alpha*acc + beta*C.at (row_A,col_B); }
140+ }
141+ }
142+ }
143+ else
// Case: trans(A) * B
// columns of A are read directly through colptr(), so no scratch copy is made
144+ if ( (do_trans_A == true ) && (do_trans_B == false ) )
145+ {
146+ const int n_threads = mp_thread_limit::get ();
147+
// each iteration writes only to row col_A of C
148+ #pragma omp parallel for schedule(static) num_threads(n_threads)
149+ for (uword col_A=0 ; col_A < A_n_cols; ++col_A)
150+ {
151+ // col_A is interpreted as row_A when storing the results in matrix C
152+
153+ const eT* A_coldata = A.colptr (col_A);
154+
155+ for (uword col_B=0 ; col_B < B_n_cols; ++col_B)
156+ {
157+ const eT acc = op_dot::direct_dot (B_n_rows, A_coldata, B.colptr (col_B));
158+
159+ if ( (use_alpha == false ) && (use_beta == false ) ) { C.at (col_A,col_B) = acc; }
160+ else if ( (use_alpha == true ) && (use_beta == false ) ) { C.at (col_A,col_B) = alpha*acc; }
161+ else if ( (use_alpha == false ) && (use_beta == true ) ) { C.at (col_A,col_B) = acc + beta*C.at (col_A,col_B); }
162+ else if ( (use_alpha == true ) && (use_beta == true ) ) { C.at (col_A,col_B) = alpha*acc + beta*C.at (col_A,col_B); }
163+ }
164+ }
165+ }
166+ else
// Case: A * trans(B)
// BB is filled by op_strans::apply_mat_noalias() (presumably the simple transpose of B),
// after which the A * B variant is reused with BB in place of B
167+ if ( (do_trans_A == false ) && (do_trans_B == true ) )
168+ {
169+ Mat<eT> BB;
170+ op_strans::apply_mat_noalias (BB, B);
171+
172+ gemm_emul_large_mp<false , false , use_alpha, use_beta>::apply (C, A, BB, alpha, beta);
173+ }
174+ else
// Case: trans(A) * trans(B)
175+ if ( (do_trans_A == true ) && (do_trans_B == true ) )
176+ {
177+ // using trans(A)*trans(B) = trans(B*A) equivalency; assuming no hermitian transpose
178+
179+ const uword n_threads = uword (mp_thread_limit::get ());
180+
// scratch space: one slice of B_n_cols elements per thread
181+ podarray<eT> tmp (B_n_cols * n_threads, arma_nozeros_indicator ());
182+
183+ eT* tmp_mem = tmp.memptr ();
184+
// each iteration writes only to column row_B of C
185+ #pragma omp parallel for schedule(static) num_threads(int(n_threads))
186+ for (uword row_B=0 ; row_B < B_n_rows; ++row_B)
187+ {
188+ const uword thread_id = uword (omp_get_thread_num ());
189+
190+ eT* B_rowdata = tmp_mem + (B_n_cols * thread_id);
191+
// gather the strided row of B into contiguous memory for the dot products below
192+ gemm_emul_large_mp_helper::copy_row (B_rowdata, B, row_B);
193+
194+ for (uword col_A=0 ; col_A < A_n_cols; ++col_A)
195+ {
196+ const eT acc = op_dot::direct_dot (A_n_rows, B_rowdata, A.colptr (col_A));
197+
// the dot product of row row_B of B with column col_A of A is stored transposed, at C(col_A,row_B)
198+ if ( (use_alpha == false ) && (use_beta == false ) ) { C.at (col_A,row_B) = acc; }
199+ else if ( (use_alpha == true ) && (use_beta == false ) ) { C.at (col_A,row_B) = alpha*acc; }
200+ else if ( (use_alpha == false ) && (use_beta == true ) ) { C.at (col_A,row_B) = acc + beta*C.at (col_A,row_B); }
201+ else if ( (use_alpha == true ) && (use_beta == true ) ) { C.at (col_A,row_B) = alpha*acc + beta*C.at (col_A,row_B); }
202+ }
203+ }
204+ }
205+ }
206+
207+ };
208+ #endif
209+
210+
81211
212+ // ! emulation of gemm(), for non-complex matrices only, as it assumes only simple transposes (ie. doesn't do hermitian transposes)
213+ template <const bool do_trans_A=false , const bool do_trans_B=false , const bool use_alpha=false , const bool use_beta=false >
214+ struct gemm_emul_large
215+ {
216+ template <typename eT, typename TA, typename TB>
217+ arma_hot
218+ inline
219+ static
220+ void
221+ apply
222+ (
223+ Mat<eT>& C,
224+ const TA& A,
225+ const TB& B,
226+ const eT alpha = eT(1 ),
227+ const eT beta = eT(0 )
228+ )
229+ {
230+ arma_debug_sigprint ();
231+
82232 const uword A_n_rows = A.n_rows ;
83233 const uword A_n_cols = A.n_cols ;
84234
85235 const uword B_n_rows = B.n_rows ;
86236 const uword B_n_cols = B.n_cols ;
87237
238+ #if defined(ARMA_USE_OPENMP)
239+ {
240+ // TODO: replace with more sophisticated threshold mechanism
241+
242+ constexpr uword threshold = uword (30 );
243+
244+ if ( (A_n_rows >= threshold) && (A_n_cols >= threshold) && (B_n_rows >= threshold) && (B_n_cols >= threshold) && (mp_thread_limit::in_parallel () == false ) )
245+ {
246+ gemm_emul_large_mp<do_trans_A, do_trans_B, use_alpha, use_beta>::apply (C,A,B,alpha,beta);
247+
248+ return ;
249+ }
250+ }
251+ #endif
252+
88253 if ( (do_trans_A == false ) && (do_trans_B == false ) )
89254 {
90255 arma_aligned podarray<eT> tmp (A_n_cols);
0 commit comments