Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,6 +42,10 @@
#include "torch/csrc/autograd/record_function_ops.h"
#include "torch/csrc/autograd/record_function_ops.h"

#ifdef USE_ROCM
#include "fbgemm_gpu/rocm/cdna_guard.h"
#endif

{%- if ssd %}
#include "pt2_arg_utils_ssd.h"
{%- else %}
Expand Down Expand Up @@ -1067,13 +1071,24 @@ static torch::autograd::variable_list backward(
(not ssd) and
(not is_index_select) and
(not dense) %}
int32_t total_L = indices.numel();
const int32_t total_L = indices.numel();
const auto T = weights_offsets.sym_numel();
auto total_B = (offsets.size(0) - 1);
const auto B = total_B / T;
const bool cached = weights_uvm.numel() > 0 || weights_lxu_cache.numel() > 0;
const bool use_hip_kernel = config::is_feature_enabled(
config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL);
const auto weights_type = weights_dev.scalar_type();
const auto grad_type = grad_outputs[0].scalar_type();
const bool supported_weights_type =
(weights_type == at::kFloat) || (weights_type == at::kHalf);
const bool supported_grad_type =
(grad_type == at::kFloat) || (grad_type == at::kHalf);
const bool same_precision = (weights_type == grad_type);
const static bool supported_platform = rocm::is_supported_cdna();
{%- for kDimSize in [64, 128, 160, 192, 256, 320] %}
if(!mixed_D && total_L / total_B > 1 && (max_D == {{ kDimSize }}))
{
if (use_hip_kernel && total_L / total_B > 1 && !mixed_D && !cached && supported_weights_type &&
supported_grad_type && same_precision && supported_platform &&
(max_D == {{ kDimSize }})) {
max_segment_length_per_warp = 16384;
}
{%- endfor %}
Expand Down