# This Python file uses the following encoding: utf-8
import copy
import warnings

import torch
import torch._dynamo
import torch.fx.experimental.optimization as optimization
from enum import IntFlag, IntEnum

from .nn import utils
from .optim._optimizer_utils import (
    optimizer_fusion,
    IPEX_FUSED_OPTIMIZER_LIST_CPU,
    IPEX_FUSED_OPTIMIZER_LIST_XPU,
)
from .utils.channels_last_1d import to_channels_last_1d
from .cpu.utils.linear_bn_folding import linear_bn_fuse
from .cpu.graph_capture import GraphCapture
from .nn.utils._lstm_convert import _LSTM, replace_lstm_with_ipex_lstm
from .nn.utils._weight_prepack import (
    _IPEXConv1d,
    _IPEXConv2d,
    _IPEXConv3d,
    _IPEXConvTranspose2d,
    _IPEXConvTranspose3d,
    _IPEXLinear,
)
from .nn.utils._weight_prepack import (
    weight_prepack_with_ipex,
    record_input_shape_for_prepack,
)
from .cpu._auto_kernel_selection import (
    _enable_dnnl,
    _disable_dnnl,
)
from .fx.concat_linear import _concat_linear

import intel_extension_for_pytorch._C as core


def _copy_model_and_optimizer(model, optimizer):
    new_model = copy.deepcopy(model)
    if optimizer is None:
        return new_model, optimizer
    else:
        new_optimizer = copy.deepcopy(optimizer)
        dic_param = {}
        dic_param_for_master_case = {}
        for k, value in zip(model.parameters(), new_model.parameters()):
            dic_param[k] = value
        if hasattr(optimizer, "params_attr"):
            params_attr = optimizer.params_attr
            param_key_pair = {}
            if len(params_attr) != 0:
                new_params_attr = copy.deepcopy(params_attr)
                for (k1, v1), (k2, v2) in zip(
                    params_attr.items(), new_params_attr.items()
                ):
                    if v1.master_parameter is None:
                        v2.parameter = dic_param[v1.parameter]
                    else:
                        dic_param_for_master_case[k1] = k2
                    param_key_pair[k1] = k2
                if len(dic_param_for_master_case) != 0:
                    dic_param = dic_param_for_master_case
                for k, v in param_key_pair.items():
                    new_params_attr[dic_param[k]] = new_params_attr.pop(v)
                setattr(new_optimizer, "params_attr", new_params_attr)  # noqa: B010

        new_optimizer.state.clear()
        # deep copy param_groups
        for group1, group2 in zip(optimizer.param_groups, new_optimizer.param_groups):
            for i, p in enumerate(group1["params"]):
                if p in dic_param:
                    new_model_param = dic_param[p]
                    group2["params"][i] = new_model_param
                    new_optimizer.state[new_model_param] = copy.deepcopy(
                        optimizer.state[p]
                    )

        def _attach_master_weight_split_attr(old_module, new_module):
            if hasattr(old_module, "master_weight_split"):
                setattr(  # noqa: B010
                    new_module, "master_weight_split", old_module.master_weight_split
                )
            for (_, old_child), (_, new_child) in zip(
                old_module.named_children(), new_module.named_children()
            ):
                _attach_master_weight_split_attr(old_child, new_child)

        _attach_master_weight_split_attr(model, new_model)
        return new_model, new_optimizer
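

# A hedged sketch of what the helper above provides (illustrative names;
# assumes the optimizer covers every model parameter and no master-weight
# split is in play): after the copy, the new optimizer's param_groups and
# state are keyed by the new model's parameters, not the original model's.
# >>> new_model, new_opt = _copy_model_and_optimizer(model, opt)
# >>> new_group_params = {p for g in new_opt.param_groups for p in g["params"]}
# >>> all(p in new_group_params for p in new_model.parameters())
# True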


class auto_channels_last_flag(IntFlag):
    AUTO = -1
    DISABLE = 0
    ENABLE = 1


auto_channels_last = auto_channels_last_flag.AUTO


def enable_auto_channels_last():
    global auto_channels_last
    auto_channels_last = auto_channels_last_flag.ENABLE


def disable_auto_channels_last():
    global auto_channels_last
    auto_channels_last = auto_channels_last_flag.DISABLE
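

# A minimal usage sketch (assuming this package is imported as ``ipex`` and
# ``model`` is a user-defined nn.Module):
# >>> import intel_extension_for_pytorch as ipex
# >>> ipex.enable_auto_channels_last()  # force channels_last for conv weights
# >>> optimized_model = ipex.optimize(model)
# Calling ``ipex.disable_auto_channels_last()`` instead keeps the contiguous
# (NCHW) weight format; the default AUTO setting decides per device.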


class _Properties(object):
    r"""
    This class establishes a set of default properties.

    """

    def __init__(self):
        self.opt_level = None
        self.conv_bn_folding = None
        self.weights_prepack = None
        self.remove_dropout = None
        # optimizer opt config
        self.split_master_weight_for_bf16 = None
        self.fuse_update_step = None
        self.auto_kernel_selection = None
        self.graph_mode = None


# O0 properties
class _O0:
    def __call__(self, properties):
        properties.opt_level = "O0"
        properties.conv_bn_folding = False
        properties.linear_bn_folding = False
        properties.weights_prepack = False
        properties.replace_dropout_with_identity = False
        properties.optimize_lstm = False
        properties.split_master_weight_for_bf16 = False
        properties.fuse_update_step = False
        properties.auto_kernel_selection = False
        properties.graph_mode = False
        properties.concat_linear = False
        return properties


# O1 properties
class _O1:
    def __call__(self, properties):
        properties.opt_level = "O1"
        properties.conv_bn_folding = True
        properties.linear_bn_folding = True
        properties.weights_prepack = True
        properties.replace_dropout_with_identity = True
        properties.optimize_lstm = True
        properties.split_master_weight_for_bf16 = True
        properties.fuse_update_step = True
        properties.auto_kernel_selection = False
        properties.graph_mode = False
        properties.concat_linear = False
        return properties


opt_levels = {"O0": _O0(), "O1": _O1()}
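
# A minimal sketch of how a level is resolved (``_O0``/``_O1`` are the
# callables defined above): the selected entry fills in a fresh
# ``_Properties`` object, and explicit ``optimize()`` knobs then override
# individual fields.
# >>> props = opt_levels["O1"](_Properties())
# >>> props.conv_bn_folding, props.auto_kernel_selection
# (True, False)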


def optimize(
    model,
    dtype=None,
    optimizer=None,
    level="O1",
    inplace=False,
    conv_bn_folding=None,
    linear_bn_folding=None,
    weights_prepack=None,
    replace_dropout_with_identity=None,
    optimize_lstm=None,
    split_master_weight_for_bf16=None,
    fuse_update_step=None,
    auto_kernel_selection=None,
    sample_input=None,
    graph_mode=None,
    concat_linear=None,
):
    r"""
    Apply optimizations at Python frontend to the given model (nn.Module), as
    well as the given optimizer (optional). If the optimizer is given,
    optimizations will be applied for training. Otherwise, optimization will be
    applied for inference. Optimizations include ``conv+bn`` folding (for
    inference only), weight prepacking and so on.

    Weight prepacking is a technique to accelerate performance of oneDNN
    operators. In order to achieve better vectorization and cache reuse, oneDNN
    uses a specific memory layout called ``blocked layout``. Although the
    calculation itself with ``blocked layout`` is fast enough, from a memory
    usage perspective it has drawbacks. Running with the ``blocked layout``,
    oneDNN splits one or several dimensions of data into blocks with fixed size
    each time the operator is executed. More detailed information about oneDNN
    memory formats is available in the `oneDNN manual
    <https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html>`_.
    To reduce this overhead, data will be converted to predefined block shapes
    prior to oneDNN operator execution. At runtime, if the data shape matches
    oneDNN operator execution requirements, oneDNN won't perform memory layout
    conversion but will go directly to calculation. Through this methodology,
    called ``weight prepacking``, it is possible to avoid runtime weight data
    format conversion and thus increase performance.

    Args:
        model (torch.nn.Module): User model to apply optimizations on.
        dtype (torch.dtype): Only works for ``torch.bfloat16`` and ``torch.half`` a.k.a. ``torch.float16``.
            Model parameters will be cast to ``torch.bfloat16`` or ``torch.half``
            according to the dtype setting. The default value is ``None``, meaning do nothing.
            Note: Data type conversion is only applied to ``nn.Conv2d``, ``nn.Linear``
            and ``nn.ConvTranspose2d`` for both training and inference cases. For
            inference mode, additional data type conversion is applied to the weights
            of ``nn.Embedding`` and ``nn.LSTM``.
        optimizer (torch.optim.Optimizer): User optimizer to apply optimizations
            on, such as SGD. The default value is ``None``, meaning inference case.
        level (string): ``"O0"`` or ``"O1"``. No optimizations are applied with
            ``"O0"``. The optimize function just returns the original model and
            optimizer. With ``"O1"``, the following optimizations are applied:
            conv+bn folding, weights prepack, dropout removal (inference model),
            master weight split and fused optimizer update step (training model).
            The optimization options can be further overridden by setting the
            following options explicitly. The default value is ``"O1"``.
        inplace (bool): Whether to perform inplace optimization. Default value is
            ``False``.
        conv_bn_folding (bool): Whether to perform ``conv_bn`` folding. It only
            works for inference model. The default value is ``None``. Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.
        linear_bn_folding (bool): Whether to perform ``linear_bn`` folding. It only
            works for inference model. The default value is ``None``. Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.
        weights_prepack (bool): Whether to perform weight prepack for convolution
            and linear to avoid oneDNN weights reorder. The default value is
            ``None``. Explicitly setting this knob overwrites the configuration
            set by the ``level`` knob. For now, XPU doesn't support weights prepack.
        replace_dropout_with_identity (bool): Whether to replace ``nn.Dropout``
            with ``nn.Identity``. If replaced, the ``aten::dropout`` won't be
            included in the JIT graph. This may provide more fusion opportunities
            on the graph. This only works for inference model. The default value
            is ``None``. Explicitly setting this knob overwrites the configuration
            set by the ``level`` knob.
        optimize_lstm (bool): Whether to replace ``nn.LSTM`` with ``IPEX LSTM``,
            which takes advantage of oneDNN kernels for better performance.
            The default value is ``None``. Explicitly setting this knob
            overwrites the configuration set by the ``level`` knob.
        split_master_weight_for_bf16 (bool): Whether to split master weights
            update for BF16 training. This saves memory compared to the master
            weight update solution. The split master weights update methodology
            doesn't support all optimizers. The default value is ``None``.
            Explicitly setting this knob overwrites the configuration set by
            the ``level`` knob.
        fuse_update_step (bool): Whether to use fused params update for training,
            which has better performance. It doesn't support all optimizers.
            The default value is ``None``. Explicitly setting this knob
            overwrites the configuration set by the ``level`` knob.
        sample_input (tuple or torch.Tensor): Sample input data to feed to
            ``ipex.optimize``. The shape of the input data will impact the block
            format of the packed weight. If no sample input is fed, Intel® Extension
            for PyTorch* will pack the weight per some predefined heuristics.
            If a sample input with the real input shape is fed, Intel® Extension
            for PyTorch* can pick the best block format.
        auto_kernel_selection (bool) [prototype]: Different backends may have
            different performances with different dtypes/shapes. Intel® Extension
            for PyTorch* will try to optimize the kernel selection for better
            performance if this knob is set to ``True``. You might get better
            performance at the cost of extra memory usage. The default value is
            ``None`` (the ``level`` default, ``False``, applies). Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.
        graph_mode (bool) [prototype]: Whether to automatically apply a
            combination of methods to generate a graph or multiple subgraphs.
            The default value is ``False``.
        concat_linear (bool): Whether to perform ``concat_linear``. It only
            works for inference model. The default value is ``None``. Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.

    Returns:
        Model and optimizer (if given) modified according to the ``level`` knob
        or other user settings. ``conv+bn`` folding may take place and
        ``dropout`` may be replaced by ``identity``. In inference scenarios,
        convolution, linear and LSTM will be replaced with the optimized
        counterparts in Intel® Extension for PyTorch* (weight prepack for
        convolution and linear) for good performance. In bfloat16 or float16
        scenarios, parameters of convolution and linear will be cast to bfloat16
        or float16 dtype.

    .. warning::

        Please invoke the ``optimize`` function BEFORE invoking DDP in
        distributed training scenarios.

        The ``optimize`` function deep-copies the original model. If DDP is
        invoked before the ``optimize`` function, DDP is applied on the original
        model, rather than the one returned from the ``optimize`` function. In
        this case, some operators in DDP, like allreduce, will not be invoked
        and thus may cause unpredictable accuracy loss.

    Examples:

        >>> # bfloat16 inference case.
        >>> model = ...
        >>> model.load_state_dict(torch.load(PATH))
        >>> model.eval()
        >>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
        >>> # running evaluation step.
        >>> # bfloat16 training case.
        >>> optimizer = ...
        >>> model.train()
        >>> optimized_model, optimized_optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
        >>> # running training step.

    ``torch.xpu.optimize()`` is an alternative to the optimize API in Intel®
    Extension for PyTorch*, provided for identical usage on XPU devices only.
    The motivation for adding this alias is to unify the coding style in user
    scripts based on the ``torch.xpu`` module.

    Examples:

        >>> # bfloat16 inference case.
        >>> model = ...
        >>> model.load_state_dict(torch.load(PATH))
        >>> model.eval()
        >>> optimized_model = torch.xpu.optimize(model, dtype=torch.bfloat16)
        >>> # running evaluation step.
        >>> # bfloat16 training case.
        >>> optimizer = ...
        >>> model.train()
        >>> optimized_model, optimized_optimizer = torch.xpu.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
        >>> # running training step.

    """
    if isinstance(model, torch.jit.ScriptModule):
        if optimizer is None:
            return model
        return model, optimizer

    if model.training:
        assert optimizer is not None, "The optimizer should be given for training mode"
    else:
        assert optimizer is None, "The optimizer should not be given for inference mode"

    opt_properties = _Properties()
    if level not in opt_levels:
        raise RuntimeError(
            f"Unexpected optimization level {level}. Options are 'O0', 'O1'."
        )
    else:
        opt_properties = opt_levels[level](opt_properties)

    device_type = "cpu"
    model_parameters_list = list(model.parameters())
    if len(model_parameters_list) and model_parameters_list[0].device.type == "xpu":
        if not all([param.device.type == "xpu" for param in model_parameters_list]):
            raise RuntimeError("The model is mixed with different device types")
        else:
            device_type = "xpu"

    global auto_channels_last

    def xpu_check_channel_last():
        global auto_channels_last
        if auto_channels_last.value == auto_channels_last_flag.ENABLE:
            return True
        elif (
            auto_channels_last.value == auto_channels_last_flag.AUTO
            and torch.xpu.has_2d_block_array()
        ):
            return True
        else:
            return False

    if device_type == "cpu" and (
        auto_channels_last.value != auto_channels_last_flag.DISABLE
    ):
        _convert_convNd_deconvNd_weight_memory_format(model)
    elif device_type == "xpu" and xpu_check_channel_last():
        _convert_convNd_deconvNd_weight_memory_format(model)

    if level is not None:
        opt_properties.opt_level = level
    if conv_bn_folding is not None:
        opt_properties.conv_bn_folding = conv_bn_folding
    if linear_bn_folding is not None:
        opt_properties.linear_bn_folding = linear_bn_folding
    if weights_prepack is not None:
        opt_properties.weights_prepack = weights_prepack
    if replace_dropout_with_identity is not None:
        opt_properties.replace_dropout_with_identity = replace_dropout_with_identity
    if optimize_lstm is not None:
        opt_properties.optimize_lstm = optimize_lstm
    if split_master_weight_for_bf16 is not None:
        opt_properties.split_master_weight_for_bf16 = split_master_weight_for_bf16
    if fuse_update_step is not None:
        opt_properties.fuse_update_step = fuse_update_step
    if auto_kernel_selection is not None:
        opt_properties.auto_kernel_selection = auto_kernel_selection
    if graph_mode is not None:
        opt_properties.graph_mode = graph_mode
    if concat_linear is not None:
        opt_properties.concat_linear = concat_linear

    _disable_dnnl()
    if opt_properties.auto_kernel_selection:
        _enable_dnnl()
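    # Hedged note: judging from the imported names in .cpu._auto_kernel_selection,
    # _disable_dnnl() restores the default kernel selection, while _enable_dnnl()
    # routes eligible ops through the oneDNN path when auto_kernel_selection is on.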

    # when on xpu, some features are not supported
    if device_type == "xpu":
        if opt_properties.auto_kernel_selection:
            warnings.warn(
                "For XPU device, auto kernel selection is unsupported, so it is disabled."
            )
            opt_properties.auto_kernel_selection = False
        if opt_properties.split_master_weight_for_bf16:
            # currently split master weight for xpu only supports SGD
            if type(optimizer) is torch.optim.SGD:
                opt_properties.split_master_weight_for_bf16 = True
            else:
                opt_properties.split_master_weight_for_bf16 = False
        if opt_properties.graph_mode:
            warnings.warn(
                "For XPU, the out-of-box solution for inference is to trace the model outside of torch.xpu.optimize,"
                + " so graph mode is temporarily disabled."
            )
            opt_properties.graph_mode = False
        if not inplace:
            warnings.warn(
                "For XPU device, to save valuable device memory, the optimization is applied to the model inplace,"
                + " so inplace is set to True."
            )
            inplace = True
        # for XPU, weight prepack is unsupported, so sample input is useless
        if opt_properties.weights_prepack:
            warnings.warn(
                "For XPU, weight prepack and sample input are disabled. The oneDNN layout"
                + " is chosen automatically."
            )
            opt_properties.weights_prepack = False
            sample_input = None
        if opt_properties.optimize_lstm:
            warnings.warn(
                "For XPU, optimize_lstm (replacing lstm with ipex_lstm) is unsupported, so it is disabled."
            )
            opt_properties.optimize_lstm = False

    if inplace:
        optimized_model = model
        optimized_optimizer = optimizer
    else:
        optimized_model, optimized_optimizer = _copy_model_and_optimizer(
            model, optimizer
        )

    if sample_input is not None:
        if isinstance(sample_input, torch.Tensor):
            sample_input = (sample_input,)
        record_input_shape_for_prepack(optimized_model, sample_input)
    params_attr = {}
    if not model.training:
        if opt_properties.conv_bn_folding:
            try:
                optimized_model = optimization.fuse(optimized_model, inplace=True)
            except:  # noqa E722
                warnings.warn(
                    "Conv BatchNorm folding failed during the optimize process."
                )
        if opt_properties.linear_bn_folding:
            try:
                optimized_model = linear_bn_fuse(optimized_model, inplace=True)
            except BaseException:
                warnings.warn(
                    "Linear BatchNorm folding failed during the optimize process."
                )
        if opt_properties.replace_dropout_with_identity:
            utils._model_convert.replace_dropout_with_identity(optimized_model)
        if opt_properties.concat_linear:
            optimized_model = _concat_linear(optimized_model, inplace=True)
        if dtype in (
            torch.bfloat16,
            torch.float16,
        ):
            params_attr, optimized_model = utils._model_convert.convert_model_data_type(
                optimized_model, dtype
            )

    if opt_properties.optimize_lstm:
        replace_lstm_with_ipex_lstm(optimized_model, optimized_optimizer)
        torch._dynamo.allow_in_graph(_LSTM)

    if (
        model.training
        and opt_properties.split_master_weight_for_bf16
        and dtype is torch.bfloat16
    ):
        if not opt_properties.fuse_update_step:
            opt_properties.split_master_weight_for_bf16 = False
            warnings.warn(
                "IPEX does not support non-fused split master weight for bf16 training; "
                + "the split_master_weight_for_bf16 flag has been reset to False. "
                + "If you want to use split_master_weight_for_bf16, "
                + "please set both split_master_weight_for_bf16 and fuse_update_step to True."
            )
        elif (
            type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST_CPU
            and device_type == "cpu"
        ):
            opt_properties.split_master_weight_for_bf16 = False
            opt_properties.fuse_update_step = False
            warnings.warn(
                "IPEX CPU does not support fused/fused split update for "
                + str(type(optimizer))
                + "; will use non-fused master weight update for bf16 training on CPU."
            )
        elif (
            type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST_XPU
            and device_type == "xpu"
        ):
            opt_properties.split_master_weight_for_bf16 = False
            opt_properties.fuse_update_step = False
            warnings.warn(
                "IPEX XPU does not support fused/fused split update for "
                + str(type(optimizer))
                + "; will use non-fused master weight update for bf16 training on XPU."
            )

    if model.training:
        if hasattr(optimized_optimizer, "params_attr"):
            params_attr = optimized_optimizer.params_attr
        if dtype == torch.float16:
            assert (
                device_type != "xpu"
            ), "For now, XPU device does not support model training with half precision."
            opt_properties.split_master_weight_for_bf16 = False
        if dtype in (torch.bfloat16, torch.float16):
            # convert optimizer for training case.
            (
                optimized_model,
                optimized_optimizer,
                params_attr,
            ) = utils._weight_cast.weight_dtype_convert_with_ipex(
                optimized_model,
                optimized_optimizer,
                params_attr,
                opt_properties.split_master_weight_for_bf16,
                dtype,
            )

    # Since TorchDynamo cannot handle custom operations yet, for the case of inference graph mode,
    # the weights prepacking here is temporarily cancelled, and it will be completed on the graph.
    if opt_properties.weights_prepack and device_type == "cpu":
        if dtype == torch.bfloat16:
            assert core.onednn_has_bf16_support(), (
                "BF16 weight prepack needs the CPU to support avx512bw, avx512vl and avx512dq, "
                + "but the desired instruction sets are not available. "
                + "Please set dtype to torch.float or set weights_prepack to False."
            )
        if dtype == torch.half:
            assert core.onednn_has_fp16_support(), (
                "FP16 weight prepack needs the CPU to support avx512_core_fp16, "
                + "but the desired instruction sets are not available. "
                + "Please set dtype to torch.float or set weights_prepack to False."
            )
        (
            optimized_model,
            optimized_optimizer,
            params_attr,
        ) = weight_prepack_with_ipex(
            optimized_model, optimized_optimizer, params_attr, "cpu"
        )
        torch._dynamo.allow_in_graph(_IPEXConv1d)
        torch._dynamo.allow_in_graph(_IPEXConv2d)
        torch._dynamo.allow_in_graph(_IPEXConv3d)
        torch._dynamo.allow_in_graph(_IPEXConvTranspose2d)
        torch._dynamo.allow_in_graph(_IPEXConvTranspose3d)
        torch._dynamo.allow_in_graph(_IPEXLinear)

    if opt_properties.graph_mode:
        _old_forward = optimized_model.forward
        wrapper = GraphCapture(
            optimized_model,
            optimizer is not None,
            dtype,
            opt_properties.weights_prepack,
        )
        optimized_model.forward = wrapper(_old_forward)

    if optimizer is None:
        return optimized_model

    # with an optimizer
    if opt_properties.fuse_update_step:
        optimized_optimizer = optimizer_fusion(
            optimized_optimizer,
            opt_properties.split_master_weight_for_bf16,
            device_type,
        )
    return optimized_model, optimized_optimizer
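

# A hedged usage sketch of overriding the ``level`` preset (``my_model`` and
# ``my_input`` are placeholders, not part of this module): explicit knobs win
# over the preset, and a sample input lets weight prepacking pick a block
# format matched to the real input shape.
# >>> my_model.eval()
# >>> optimized = optimize(my_model, dtype=torch.bfloat16,
# ...                      weights_prepack=True, sample_input=my_input)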


def _convert_convNd_deconvNd_weight_memory_format(module):
    # inspired by https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/memory_format.py
    if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
        weight_data = to_channels_last_1d(module.weight.detach().clone())
        module.weight.data = weight_data.resize_(weight_data.size())
    elif isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=torch.channels_last)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=torch.channels_last
        )
    elif isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        weight_data = (
            module.weight.detach()
            .clone()
            .contiguous(memory_format=torch.channels_last_3d)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=torch.channels_last_3d
        )

    for child in module.children():
        _convert_convNd_deconvNd_weight_memory_format(child)
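

# A minimal check of the conversion above (illustrative, CPU-only):
# >>> conv = torch.nn.Conv2d(3, 8, kernel_size=3)
# >>> _convert_convNd_deconvNd_weight_memory_format(conv)
# >>> conv.weight.is_contiguous(memory_format=torch.channels_last)
# True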


class FP32MathMode(IntEnum):
    FP32 = int(core.FP32MathMode.FP32)
    TF32 = int(core.FP32MathMode.TF32)
    BF32 = int(core.FP32MathMode.BF32)
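

# A short round-trip sketch (CPU path; assumes the ``core`` bindings behave as
# their names suggest):
# >>> set_fp32_math_mode(FP32MathMode.BF32, device="cpu")  # implicit FP32->BF16
# >>> mode = get_fp32_math_mode(device="cpu")  # reads the mode back from core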


def set_fp32_math_mode(mode=FP32MathMode.FP32, device="cpu"):
    r"""
    Enable or disable implicit data type conversion.

    Args:
        mode (FP32MathMode): ``FP32MathMode.FP32``, ``FP32MathMode.BF32`` or
            ``FP32MathMode.TF32`` (GPU ONLY). oneDNN fpmath mode will be disabled by default
            if dtype is set to ``FP32MathMode.FP32``. The implicit ``FP32`` to ``TF32`` data type
            conversion will be enabled if dtype is set to ``FP32MathMode.TF32``. The implicit
            ``FP32`` to ``BF16`` data type conversion will be enabled if dtype is set to
            ``FP32MathMode.BF32``.
        device (string): ``cpu``, ``xpu``

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to enable the implicit data type conversion
        >>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
        >>> # to disable the implicit data type conversion
        >>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)

    ``torch.xpu.set_fp32_math_mode()`` is an alternative function in Intel®
    Extension for PyTorch*, provided for identical usage on XPU devices only.
    The motivation for adding this alias is to unify the coding style in user
    scripts based on the ``torch.xpu`` module.

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to enable the implicit data type conversion
        >>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
        >>> # to disable the implicit data type conversion
        >>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)
    """

    if device == "cpu":
        if mode == FP32MathMode.BF32:
            core.set_fp32_math_mode(core.FP32MathMode.BF32)
        elif mode == FP32MathMode.FP32:
            core.set_fp32_math_mode(core.FP32MathMode.FP32)
        else:
            warnings.warn(
                "For CPU device, IPEX does not support modes other than "
                "FP32MathMode.FP32 and FP32MathMode.BF32 for fpmath_mode right now."
            )
    elif device == "xpu":
        if mode == FP32MathMode.BF32:
            torch.xpu.set_fp32_math_mode(torch.xpu.FP32MathMode.BF32)
        elif mode == FP32MathMode.FP32:
            torch.xpu.set_fp32_math_mode(torch.xpu.FP32MathMode.FP32)
        elif mode == FP32MathMode.TF32:
            torch.xpu.set_fp32_math_mode(torch.xpu.FP32MathMode.TF32)
        else:
            warnings.warn(
                "For XPU device, IPEX does not support modes other than "
                "FP32MathMode.FP32, FP32MathMode.BF32 and FP32MathMode.TF32 for fpmath_mode right now."
            )
    else:
        raise RuntimeError(
            "Unexpected device type {}. ".format(device) + "Supported are 'cpu', 'xpu'."
        )


def get_fp32_math_mode(device="cpu"):
    r"""
    Get the current fpmath_mode setting.

    Args:
        device (string): ``cpu``, ``xpu``

    Returns:
        Fpmath mode
        The value will be ``FP32MathMode.FP32``, ``FP32MathMode.BF32`` or ``FP32MathMode.TF32`` (GPU ONLY).
        oneDNN fpmath mode will be disabled by default if dtype is set to ``FP32MathMode.FP32``.
        The implicit ``FP32`` to ``TF32`` data type conversion will be enabled if dtype is set
        to ``FP32MathMode.TF32``. The implicit ``FP32`` to ``BF16`` data type conversion will be
        enabled if dtype is set to ``FP32MathMode.BF32``.

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to get the current fpmath mode
        >>> ipex.get_fp32_math_mode(device="xpu")

    ``torch.xpu.get_fp32_math_mode()`` is an alternative function in Intel®
    Extension for PyTorch*, provided for identical usage on XPU devices only.
    The motivation for adding this alias is to unify the coding style in user
    scripts based on the ``torch.xpu`` module.

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to get the current fpmath mode
        >>> torch.xpu.get_fp32_math_mode(device="xpu")
    """

    if device == "cpu":
        return core.get_fp32_math_mode()
    elif device == "xpu":
        return torch.xpu.get_fp32_math_mode()
    else:
        raise RuntimeError(
            "Unexpected device type {}. ".format(device) + "Supported are 'cpu', 'xpu'."
        )