diff --git a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp index 4660e6ad0f..11b57285fc 100644 --- a/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp +++ b/fbgemm_gpu/codegen/training/pt2/embedding_split_host_pt2_autograd_template.cpp @@ -42,6 +42,10 @@ #include "torch/csrc/autograd/record_function_ops.h" #include "torch/csrc/autograd/record_function_ops.h" +#ifdef USE_ROCM +#include "fbgemm_gpu/rocm/cdna_guard.h" +#endif + {%- if ssd %} #include "pt2_arg_utils_ssd.h" {%- else %} @@ -1067,13 +1071,24 @@ static torch::autograd::variable_list backward( (not ssd) and (not is_index_select) and (not dense) %} - int32_t total_L = indices.numel(); + const int32_t total_L = indices.numel(); const auto T = weights_offsets.sym_numel(); auto total_B = (offsets.size(0) - 1); - const auto B = total_B / T; + const bool cached = weights_uvm.numel() > 0 || weights_lxu_cache.numel() > 0; + const bool use_hip_kernel = config::is_feature_enabled( + config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); + const auto weights_type = weights_dev.scalar_type(); + const auto grad_type = grad_outputs[0].scalar_type(); + const bool supported_weights_type = + (weights_type == at::kFloat) || (weights_type == at::kHalf); + const bool supported_grad_type = + (grad_type == at::kFloat) || (grad_type == at::kHalf); + const bool same_precision = (weights_type == grad_type); + const static bool supported_platform = rocm::is_supported_cdna(); {%- for kDimSize in [64, 128, 160, 192, 256, 320] %} - if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }})) - { + if (use_hip_kernel && total_L / total_B > 1 && !mixed_D && !cached && supported_weights_type && + supported_grad_type && same_precision && supported_platform && + (max_D == {{ kDimSize }})) { max_segment_length_per_warp = 16384; } {%- endfor %}