Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -772,22 +772,9 @@ def __init__( # noqa C901

self.weights_precision = weights_precision

if torch.cuda.is_available() and torch.version.hip:
# NOTE: It was discovered that FP16 cache precision caused a 500x
# slowdown in performance of split_embedding_nobag_backward_codegen_rowwise_adagrad_unweighted_kernel_warp_per_row_1
# kernel on ROCm, so to work around this, we fix cache precision to
# be FP32 always for the ROCm environment case.
#
# See:
# https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
cache_precision = SparseType.FP32
self.log("Override cache_precision=SparseType.FP32 on ROCm")
else:
# NOTE: The changes from D65865527 are retained here until we can
# test that the hack also works for non-ROCm environments.
cache_precision = (
weights_precision if cache_precision is None else cache_precision
)
cache_precision = (
weights_precision if cache_precision is None else cache_precision
)

self.output_dtype: int = output_dtype.as_int()
assert (
Expand Down
9 changes: 6 additions & 3 deletions fbgemm_gpu/include/fbgemm_gpu/utils/float.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@ struct Half4 {

__device__ inline void store(at::Half* p) {
#ifdef USE_ROCM
*reinterpret_cast<unsigned int*>(p) = *reinterpret_cast<unsigned int*>(&a);
*reinterpret_cast<unsigned int*>(p + 2) =
*reinterpret_cast<unsigned int*>(&b);
const unsigned int lo = *reinterpret_cast<unsigned int*>(&a);
const unsigned int hi = *reinterpret_cast<unsigned int*>(&b);
const unsigned long long packed =
static_cast<unsigned long long>(lo) |
(static_cast<unsigned long long>(hi) << 32);
*reinterpret_cast<unsigned long long*>(p) = packed;
#else

#ifndef __HALF2_TO_UI
Expand Down