intel-extension-for-pytorch
57 строк · 1.7 Кб
1import torch
2import intel_extension_for_pytorch as ipex
3from functools import wraps
4
5
class AutoMixPrecision:
    """Context manager that switches IPEX auto mixed precision for a scope.

    On entry, auto mixed precision is enabled with ``torch.bfloat16`` (and the
    requested ``train`` flag) when ``enable_or_not`` is true; otherwise it is
    set with ``mixed_dtype=None``. On exit, the setting recorded by the
    constructor is applied again in the same way.

    Args:
        enable_or_not: Whether to turn bfloat16 auto mixed precision on
            inside the ``with`` body.
        train: Value forwarded as ``train=`` when enabling on entry.
    """

    def __init__(self, enable_or_not=False, train=False):
        # NOTE: the "previous" state is captured here, at construction time,
        # not when the ``with`` block is entered.
        previous_enabled = ipex.get_auto_mix_precision()
        previous_train = ipex.get_train()
        self.enable_or_not = enable_or_not
        self.train = train
        self.old_value = previous_enabled
        self.train_old_value = previous_train

    @staticmethod
    def _apply(enabled, train):
        """Enable bfloat16 mixed precision with ``train``, or reset the dtype to None."""
        if not enabled:
            # The ``train`` keyword is intentionally not passed on this path.
            ipex.enable_auto_mixed_precision(mixed_dtype=None)
            return
        ipex.enable_auto_mixed_precision(mixed_dtype=torch.bfloat16, train=train)

    def __enter__(self):
        self._apply(self.enable_or_not, self.train)

    def __exit__(self, *args, **kwargs):
        # Re-apply whatever was recorded by __init__; exceptions raised in the
        # ``with`` body are not suppressed (nothing truthy is returned).
        self._apply(self.old_value, self.train_old_value)
28
29
class AutoDNNL:
    """Context manager that turns IPEX auto-DNNL on or off for a scope.

    On entry, ``ipex.core.enable_auto_dnnl()`` is called when ``enable_or_not``
    is true and ``ipex.core.disable_auto_dnnl()`` otherwise. On exit, the same
    choice is made based on the value recorded by the constructor.

    Args:
        enable_or_not: Whether auto-DNNL should be enabled inside the
            ``with`` body.
    """

    def __init__(self, enable_or_not=False):
        # NOTE: the "previous" state is captured here, at construction time,
        # not when the ``with`` block is entered.
        previous_state = ipex._get_auto_optimization()
        self.enable_or_not = enable_or_not
        self.old_value = previous_state

    @staticmethod
    def _switch(enabled):
        """Call the enable or disable entry point depending on ``enabled``."""
        toggle = ipex.core.enable_auto_dnnl if enabled else ipex.core.disable_auto_dnnl
        toggle()

    def __enter__(self):
        self._switch(self.enable_or_not)

    def __exit__(self, *args, **kwargs):
        # Re-apply the state recorded by __init__; exceptions raised in the
        # ``with`` body are not suppressed (nothing truthy is returned).
        self._switch(self.old_value)
46
47
def runtime_thread_affinity_test_env(func):
    """Decorate ``func`` so the current CPU pool is restored after it runs.

    The wrapper records ``ipex._C.get_current_cpu_pool()`` before calling
    ``func`` and passes that value back to ``ipex._C.set_cpu_pool()``
    afterwards. The restore happens in a ``finally`` clause, so it also runs
    when ``func`` raises (e.g. a failing assertion in a test).

    Args:
        func: The callable to wrap.

    Returns:
        A wrapper with ``func``'s metadata that forwards all positional and
        keyword arguments to ``func`` and returns its result.
    """

    @wraps(func)
    def wrapTheFunction(*args, **kwargs):
        # In some cases, the affinity of main thread may be changed: MultiStreamModule of stream 1
        # Ensure, we restore the affinity of main thread
        previous_cpu_pool = ipex._C.get_current_cpu_pool()
        try:
            return func(*args, **kwargs)
        finally:
            # Runs on both normal return and exception, so a failing call
            # cannot leave the changed pool in place for later callers.
            ipex._C.set_cpu_pool(previous_cpu_pool)

    return wrapTheFunction
58