Commit 2f9048e

skip bf16 unit tests on CPUs that lack the required instruction (#3385)
* skip bf16 unit tests on CPUs that lack the required instruction
* fix flake
* refine code
* Skip test_nocast_since_shared and test_noprepack_since_shared if not supported; add fp32 case for noprepack ut
* refine code
* revert test_mha.py changes
* pass test_mha.py
* remove unused code
* pass test_shared_param.py
* revert logic of selecting fp16
* fix flake
1 parent ae09c58 commit 2f9048e

13 files changed: +360 −146 lines
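Across test_distributed_merged_emb.py, test_fake_tensor.py, and test_fx_optimization.py the change follows one pattern: build the dtype list at runtime and append torch.bfloat16 only when torch.ops.mkldnn._is_mkldnn_bf16_supported() reports CPU support, instead of hard-coding it into the parameter product. A minimal standalone sketch of that pattern (the test class and body below are illustrative placeholders, not code from this commit):

import unittest

import torch


class DtypeGatedTest(unittest.TestCase):
    def test_matmul_dtypes(self):
        # fp32 is always exercised; bf16 is added only when oneDNN
        # reports that the CPU has the required instruction support.
        dtypes = [torch.float32]
        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
            dtypes.append(torch.bfloat16)
        for dtype in dtypes:
            x = torch.randn(4, 8).to(dtype)
            w = torch.randn(8, 8).to(dtype)
            self.assertEqual((x @ w).dtype, dtype)


if __name__ == "__main__":
    unittest.main()

This keeps fp32 coverage unconditional on every machine, while the bf16 cases drop out of the parameter product on unsupported CPUs instead of failing.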

tests/cpu/test_distributed_merged_emb.py

Lines changed: 4 additions & 5 deletions

@@ -87,11 +87,10 @@ def env2int(env_list, default=-1):
             )
             for i in range(NUM_TABLE)
         ]
-        for dtype in [
-            torch.bfloat16,
-            torch.float32,
-            torch.float64,
-        ]:
+        dtypes = [torch.float32, torch.float64]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        for dtype in dtypes:
             for NUM_DIM in [64, 65, 128, 256]:
                 emb_list = EmbeddingBagList(
                     NUM_TABLE,

tests/cpu/test_fake_tensor.py

Lines changed: 19 additions & 4 deletions

@@ -109,14 +109,19 @@ def test_conv_inference(self):
             channels_last = torch.channels_last
         elif dim == 3:
             channels_last = torch.channels_last_3d
+        dtypes = [
+            torch.float32,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
         if dim == 1:
             options = itertools.product(
                 [True, False],
                 [1, 2],
                 [1, 4],
                 [True, False],
                 [torch.contiguous_format],
-                [torch.float32, torch.bfloat16],
+                dtypes,
             )
         else:
             options = itertools.product(
@@ -125,7 +130,7 @@ def test_conv_inference(self):
                 [1, 4],
                 [True, False],
                 [torch.contiguous_format, channels_last],
-                [torch.float32, torch.bfloat16],
+                dtypes,
             )
         for (
             bias,
@@ -182,12 +187,17 @@ def test_linear_inference(self):
         in_features = torch.randint(3, 10, (1,)).item()

         input_shapes = [(8, in_features), (2, 4, in_features), (2, 2, 2, in_features)]
+        dtypes = [
+            torch.float32,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
         options = itertools.product(
             [True, False],
             input_shapes,
             [True, False],
             [True, False],
-            [torch.float32, torch.bfloat16],
+            dtypes,
         )
         for bias, x_shape, feed_sample_input, auto_kernel_selection, dtype in options:
             x = torch.randn(x_shape, dtype=torch.float32)
@@ -235,6 +245,11 @@ def test_deconv_inference(self):
         input_channel_per_group = 15
         output_channel_per_group = 3
         kernel_size = 3
+        dtypes = [
+            torch.float32,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
         options = itertools.product(
             [True, False],
             [1, 2],
@@ -243,7 +258,7 @@ def test_deconv_inference(self):
             [1, 2],
             [True, False],
             [torch.contiguous_format, channels_last],
-            [torch.float32, torch.bfloat16],
+            dtypes,
         )
         for (
             bias,

tests/cpu/test_fx_optimization.py

Lines changed: 17 additions & 3 deletions

@@ -76,7 +76,11 @@ def test_concat_linear(self):
         _bias = [True, False]
         _inplace = [True, False]
         _in_feature = [16, 129]
-        _dtype = [torch.float, torch.bfloat16]
+        _dtype = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            _dtype.append(torch.bfloat16)
         options = itertools.product(_bias, _inplace, _in_feature, _dtype)
         for bias, inplace, in_feature, dtype in options:
             x = torch.randn(100, in_feature, dtype=dtype)
@@ -126,7 +130,12 @@ def test_automatically_apply_concat_linear_with_ipex_optimize(self):
         config = AutoConfig.from_pretrained(loc + "/bert-base-config.json")
         base_model = AutoModelForCausalLM.from_config(config).eval()
         inputs = torch.load(loc + "/bert-inputs.pt", weights_only=False)
-        for dtype in [torch.float, torch.bfloat16]:
+        dtypes = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        for dtype in dtypes:
             for inplace in [True, False]:
                 model = copy.deepcopy(base_model)
                 auto_cast = dtype == torch.bfloat16
@@ -191,7 +200,12 @@ def check_unet_concated(model):
             torch.tensor(921),
             torch.randn(2, 77, 768),
         )
-        for dtype in [torch.float, torch.bfloat16]:
+        dtypes = [
+            torch.float,
+        ]
+        if torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            dtypes.append(torch.bfloat16)
+        for dtype in dtypes:
             for inplace in [True, False]:
                 model1 = copy.deepcopy(base_model)
                 model2 = copy.deepcopy(base_model)

tests/cpu/test_graph_capture.py

Lines changed: 21 additions & 0 deletions

@@ -101,6 +101,9 @@ def test_inference_graph_mode_torchdynamo(self):
         y2 = model(x)
         self.assertEqual(y1, y2)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     def test_inference_graph_mode_jit_autocast(self):
         model = Conv_Bn_Relu().to(memory_format=torch.channels_last).eval()
         x = torch.randn(3, 6, 10, 10).to(memory_format=torch.channels_last)
@@ -113,6 +116,9 @@ def test_inference_graph_mode_jit_autocast(self):
         self.assertEqual(y1, y2_bf16, prec=0.01)
         self.assertTrue(y2_bf16.dtype == torch.bfloat16)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     def test_inference_graph_mode_torchdynamo_autocast(self):
         model = Conv_IF_Relu().to(memory_format=torch.channels_last).eval()
         x = torch.randn(3, 6, 10, 10).to(memory_format=torch.channels_last)
@@ -341,6 +347,9 @@ def test_throughput_benchmark_graph_mode_torchdynamo(self):
         y = model(x)
         self.assertEqual(y, y_bench)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     def test_throughput_benchmark_graph_mode_jit_autocast(self):
         model = Conv_Bn_Relu().to(memory_format=torch.channels_last)
         model.eval()
@@ -360,6 +369,9 @@ def test_throughput_benchmark_graph_mode_jit_autocast(self):
         self.assertEqual(y, y_bench)
         self.assertTrue(y_bench.dtype == torch.bfloat16)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     def test_throughput_benchmark_graph_mode_torchdynamo_autocast(self):
         model = Conv_IF_Relu().to(memory_format=torch.channels_last)
         model.eval()
@@ -394,6 +406,9 @@ def test_resnet50(self):
         y = model(data)
         self.assertTrue(y.dtype == torch.float32)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     @skipIfNoTorchVision
     def test_resnet50_autocast(self):
         model = torchvision.models.resnet50(pretrained=False)
@@ -441,6 +456,9 @@ def test_training_graph_mode_fallback(self):
         self.assertEqual(y1, y2)
         self.assertEqual(x1.grad, x2.grad)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     def test_training_graph_mode_jit_autocast(self):
         model = Conv_Bn_Relu().to(memory_format=torch.channels_last).train()
         x = torch.randn(3, 6, 10, 10).to(memory_format=torch.channels_last)
@@ -461,6 +479,9 @@ def test_training_graph_mode_jit_autocast(self):
         self.assertEqual(x1.grad, x2.grad, prec=0.01)
         self.assertTrue(y2.dtype == torch.bfloat16)

+    @unittest.skipIf(
+        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
+    )
     def test_training_graph_mode_fallback_autocast(self):
         model = Conv_IF_Relu().to(memory_format=torch.channels_last).train()
         x = torch.randn(3, 6, 10, 10).to(memory_format=torch.channels_last)
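Where an entire test only makes sense with bf16 (the autocast variants above), the commit skips the method outright with unittest.skipIf rather than trimming a dtype list, so the runner reports a skip instead of silently passing. A minimal sketch of the same idiom (the test class and layer are illustrative, not from this commit):

import unittest

import torch


class AutocastSkipTest(unittest.TestCase):
    @unittest.skipIf(
        not torch.ops.mkldnn._is_mkldnn_bf16_supported(), "not supported bf16"
    )
    def test_linear_autocast_bf16(self):
        # Runs only when the CPU reports bf16 support; skipped otherwise.
        with torch.cpu.amp.autocast(dtype=torch.bfloat16):
            y = torch.nn.Linear(8, 8)(torch.randn(2, 8))
        self.assertEqual(y.dtype, torch.bfloat16)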

tests/cpu/test_ipex_llm_module.py

Lines changed: 15 additions & 3 deletions

@@ -308,7 +308,11 @@ def test_linearfusion_args0(self):
             ipex.llm.modules.LinearRelu,
             ipex.llm.modules.Linear2SiluMul,
         ]
-        dtypes = [torch.float32, torch.bfloat16]
+        dtypes = [
+            torch.float32,
+        ]
+        if core.onednn_has_bf16_support():
+            dtypes.append(torch.bfloat16)
         if core.onednn_has_fp16_support():
             dtypes.append(torch.float16)
         with torch.no_grad():
@@ -353,7 +357,11 @@ def test_linearfusion_args1(self):
             ipex.llm.modules.LinearAdd,
             ipex.llm.modules.LinearSiluMul,
         ]
-        dtypes = [torch.float32, torch.bfloat16]
+        dtypes = [
+            torch.float32,
+        ]
+        if core.onednn_has_bf16_support():
+            dtypes.append(torch.bfloat16)
         if core.onednn_has_fp16_support():
             dtypes.append(torch.float16)
         with torch.no_grad():
@@ -390,7 +398,11 @@ def test_linearfusion_args2(self):
         x2 = copy.deepcopy(x1)
         ref_scope = [Linear_add_add]
         ipex_scope = [ipex.llm.modules.LinearAddAdd]
-        dtypes = [torch.float32, torch.bfloat16]
+        dtypes = [
+            torch.float32,
+        ]
+        if core.onednn_has_bf16_support():
+            dtypes.append(torch.bfloat16)
         if core.onednn_has_fp16_support():
             dtypes.append(torch.float16)
         with torch.no_grad():
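This file already gated fp16 on core.onednn_has_fp16_support(); the commit gives bf16 the matching core.onednn_has_bf16_support() check, so both reduced-precision dtypes are opt-in. A sketch of the combined gating, assuming core is the test suite's alias for the IPEX C extension (the import line is an assumption, not shown in this diff):

import torch

# Assumed alias; the actual import in the test file may differ.
import intel_extension_for_pytorch._C as core

dtypes = [torch.float32]
if core.onednn_has_bf16_support():
    dtypes.append(torch.bfloat16)
if core.onednn_has_fp16_support():
    dtypes.append(torch.float16)
# dtypes now holds exactly the precisions this machine can run.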
