# This Python file uses the following encoding: utf-8
import copy
import warnings

import torch
import torch._dynamo
import torch.fx.experimental.optimization as optimization
from enum import IntFlag, IntEnum

from .nn import utils
from .optim._optimizer_utils import (
    optimizer_fusion,
    IPEX_FUSED_OPTIMIZER_LIST_CPU,
    IPEX_FUSED_OPTIMIZER_LIST_XPU,
)
from .utils.channels_last_1d import to_channels_last_1d
from .cpu.utils.linear_bn_folding import linear_bn_fuse
from .cpu.graph_capture import GraphCapture
from .nn.utils._lstm_convert import _LSTM, replace_lstm_with_ipex_lstm
from .nn.utils._weight_prepack import (
    _IPEXConv1d,
    _IPEXConv2d,
    _IPEXConv3d,
    _IPEXConvTranspose2d,
    _IPEXConvTranspose3d,
    _IPEXLinear,
)
from .nn.utils._weight_prepack import (
    weight_prepack_with_ipex,
    record_input_shape_for_prepack,
)
from .cpu._auto_kernel_selection import (
    _enable_dnnl,
    _disable_dnnl,
)
from .fx.concat_linear import _concat_linear

import intel_extension_for_pytorch._C as core


def _copy_model_and_optimizer(model, optimizer):
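    # Deep-copy the model and (optionally) the optimizer, remapping optimizer
    # state and IPEX ``params_attr`` entries from the original parameters to
    # the corresponding parameters of the copied model.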
    new_model = copy.deepcopy(model)
    if optimizer is None:
        return new_model, optimizer
    else:
        new_optimizer = copy.deepcopy(optimizer)
        dic_param = {}
        dic_param_for_master_case = {}
        for k, value in zip(model.parameters(), new_model.parameters()):
            dic_param[k] = value
        if hasattr(optimizer, "params_attr"):
            params_attr = optimizer.params_attr
            param_key_pair = {}
            if len(params_attr) != 0:
                new_params_attr = copy.deepcopy(params_attr)
                for (k1, v1), (k2, v2) in zip(
                    params_attr.items(), new_params_attr.items()
                ):
                    if v1.master_parameter is None:
                        v2.parameter = dic_param[v1.parameter]
                    else:
                        dic_param_for_master_case[k1] = k2
                    param_key_pair[k1] = k2
                if len(dic_param_for_master_case) != 0:
                    dic_param = dic_param_for_master_case
                for k, v in param_key_pair.items():
                    new_params_attr[dic_param[k]] = new_params_attr.pop(v)
                setattr(new_optimizer, "params_attr", new_params_attr)  # noqa: B010

        new_optimizer.state.clear()
        # deep copy param_groups
        for group1, group2 in zip(optimizer.param_groups, new_optimizer.param_groups):
            for i, p in enumerate(group1["params"]):
                if p in dic_param:
                    new_model_param = dic_param[p]
                    group2["params"][i] = new_model_param
                    new_optimizer.state[new_model_param] = copy.deepcopy(
                        optimizer.state[p]
                    )

        def _attach_master_weight_split_attr(old_module, new_module):
            if hasattr(old_module, "master_weight_split"):
                setattr(  # noqa: B010
                    new_module, "master_weight_split", old_module.master_weight_split
                )
            for (_, old_child), (_, new_child) in zip(
                old_module.named_children(), new_module.named_children()
            ):
                _attach_master_weight_split_attr(old_child, new_child)

        _attach_master_weight_split_attr(model, new_model)
        return new_model, new_optimizer


class auto_channels_last_flag(IntFlag):
    AUTO = -1
    DISABLE = 0
    ENABLE = 1


auto_channels_last = auto_channels_last_flag.AUTO


def enable_auto_channels_last():
    global auto_channels_last
    auto_channels_last = auto_channels_last_flag.ENABLE


def disable_auto_channels_last():
    global auto_channels_last
    auto_channels_last = auto_channels_last_flag.DISABLE
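
# Usage sketch (assuming these helpers are re-exported at the package level,
# e.g. ipex.enable_auto_channels_last / ipex.disable_auto_channels_last):
#
#   import intel_extension_for_pytorch as ipex
#   ipex.disable_auto_channels_last()  # keep conv weights in contiguous format
#   model = ipex.optimize(model)       # no channels-last conversion is applied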


class _Properties(object):
    r"""
    This class establishes a set of default properties.

    """

    def __init__(self):
        self.opt_level = None
        self.conv_bn_folding = None
        self.weights_prepack = None
        self.remove_dropout = None
        # optimizer opt config
        self.split_master_weight_for_bf16 = None
        self.fuse_update_step = None
        self.auto_kernel_selection = None
        self.graph_mode = None


# O0 properties
class _O0:
    def __call__(self, properties):
        properties.opt_level = "O0"
        properties.conv_bn_folding = False
        properties.linear_bn_folding = False
        properties.weights_prepack = False
        properties.replace_dropout_with_identity = False
        properties.optimize_lstm = False
        properties.split_master_weight_for_bf16 = False
        properties.fuse_update_step = False
        properties.auto_kernel_selection = False
        properties.graph_mode = False
        properties.concat_linear = False
        return properties


# O1 properties
class _O1:
    def __call__(self, properties):
        properties.opt_level = "O1"
        properties.conv_bn_folding = True
        properties.linear_bn_folding = True
        properties.weights_prepack = True
        properties.replace_dropout_with_identity = True
        properties.optimize_lstm = True
        properties.split_master_weight_for_bf16 = True
        properties.fuse_update_step = True
        properties.auto_kernel_selection = False
        properties.graph_mode = False
        properties.concat_linear = False
        return properties


opt_levels = {"O0": _O0(), "O1": _O1()}


def optimize(
    model,
    dtype=None,
    optimizer=None,
    level="O1",
    inplace=False,
    conv_bn_folding=None,
    linear_bn_folding=None,
    weights_prepack=None,
    replace_dropout_with_identity=None,
    optimize_lstm=None,
    split_master_weight_for_bf16=None,
    fuse_update_step=None,
    auto_kernel_selection=None,
    sample_input=None,
    graph_mode=None,
    concat_linear=None,
):
    r"""
    Apply optimizations at Python frontend to the given model (nn.Module), as
    well as the given optimizer (optional). If the optimizer is given,
    optimizations will be applied for training. Otherwise, optimizations will be
    applied for inference. Optimizations include ``conv+bn`` folding (for
    inference only), weight prepacking and so on.

    Weight prepacking is a technique to accelerate performance of oneDNN
    operators. In order to achieve better vectorization and cache reuse, oneDNN
    uses a specific memory layout called ``blocked layout``. Although the
    calculation itself with ``blocked layout`` is fast enough, from a memory usage
    perspective it has drawbacks. Running with the ``blocked layout``, oneDNN
    splits one or several dimensions of data into blocks of fixed size each
    time the operator is executed. More detailed information about the oneDNN data
    memory format is available in the `oneDNN manual
    <https://oneapi-src.github.io/oneDNN/dev_guide_understanding_memory_formats.html>`_.
    To reduce this overhead, data will be converted to predefined block shapes
    prior to oneDNN operator execution. At runtime, if the data
    shape matches oneDNN operator execution requirements, oneDNN won't perform a
    memory layout conversion but will go directly to calculation. Through this
    methodology, called ``weight prepacking``, it is possible to avoid runtime
    weight data format conversion and thus increase performance.

    Args:
        model (torch.nn.Module): User model to apply optimizations on.
        dtype (torch.dtype): Only works for ``torch.bfloat16`` and ``torch.half`` a.k.a ``torch.float16``.
            Model parameters will be cast to ``torch.bfloat16`` or ``torch.half``
            according to the dtype setting. The default value is None, meaning do nothing.
            Note: Data type conversion is only applied to ``nn.Conv2d``, ``nn.Linear``
            and ``nn.ConvTranspose2d`` for both training and inference cases. For
            inference mode, additional data type conversion is applied to the weights
            of ``nn.Embedding`` and ``nn.LSTM``.
        optimizer (torch.optim.Optimizer): User optimizer to apply optimizations
            on, such as SGD. The default value is ``None``, meaning inference case.
        level (string): ``"O0"`` or ``"O1"``. No optimizations are applied with
            ``"O0"``. The ``optimize`` function just returns the original model and
            optimizer. With ``"O1"``, the following optimizations are applied:
            conv+bn folding, weights prepack, dropout removal (inference mode),
            master weight split and fused optimizer update step (training mode).
            The optimization options can be further overridden by setting the
            following options explicitly. The default value is ``"O1"``.
        inplace (bool): Whether to perform inplace optimization. Default value is
            ``False``.
        conv_bn_folding (bool): Whether to perform ``conv_bn`` folding. It only
            works for an inference model. The default value is ``None``. Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.
        linear_bn_folding (bool): Whether to perform ``linear_bn`` folding. It only
            works for an inference model. The default value is ``None``. Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.
        weights_prepack (bool): Whether to perform weight prepack for convolution
            and linear to avoid oneDNN weights reorder. The default value is
            ``None``. Explicitly setting this knob overwrites the configuration
            set by the ``level`` knob. For now, XPU doesn't support weights prepack.
        replace_dropout_with_identity (bool): Whether to replace ``nn.Dropout``
            with ``nn.Identity``. If replaced, the ``aten::dropout`` won't be
            included in the JIT graph. This may provide more fusion opportunities
            on the graph. This only works for an inference model. The default value
            is ``None``. Explicitly setting this knob overwrites the configuration
            set by the ``level`` knob.
        optimize_lstm (bool): Whether to replace ``nn.LSTM`` with ``IPEX LSTM``,
            which takes advantage of oneDNN kernels to get better performance.
            The default value is ``None``. Explicitly setting this knob
            overwrites the configuration set by the ``level`` knob.
        split_master_weight_for_bf16 (bool): Whether to split the master weight
            update for BF16 training. This saves memory compared to the master
            weight update solution. The split master weight update methodology
            doesn't support all optimizers. The default value is ``None``.
            Explicitly setting this knob overwrites the configuration set by the
            ``level`` knob.
        fuse_update_step (bool): Whether to use a fused params update for training,
            which has better performance. It doesn't support all optimizers.
            The default value is ``None``. Explicitly setting this knob
            overwrites the configuration set by the ``level`` knob.
        sample_input (tuple or torch.Tensor): Sample input data to feed to ``ipex.optimize``. The shape of the
            input data will impact the block format of the packed weight. If no sample
            input is fed, Intel® Extension for PyTorch* will pack the weight per some predefined heuristics.
            If a sample input with the real input shape is fed, Intel® Extension for PyTorch* can get the
            best block format.
        auto_kernel_selection (bool) [prototype]: Different backends may have
            different performance with different dtypes/shapes. Default value
            is False. Intel® Extension for PyTorch* will try to optimize the
            kernel selection for better performance if this knob is set to
            ``True``. You might get better performance at the cost of extra memory usage.
            The default value is ``None``. Explicitly setting this knob overwrites the
            configuration set by the ``level`` knob.
        graph_mode (bool) [prototype]: Whether to automatically apply a combination of methods
            to generate a graph or multiple subgraphs if ``True``. The default value is ``False``.
        concat_linear (bool): Whether to perform ``concat_linear``. It only
            works for an inference model. The default value is ``None``. Explicitly
            setting this knob overwrites the configuration set by the ``level`` knob.

    Returns:
        Model and optimizer (if given) modified according to the ``level`` knob
        or other user settings. ``conv+bn`` folding may take place and
        ``dropout`` may be replaced by ``identity``. In inference scenarios,
        convolution, linear and lstm will be replaced with the optimized
        counterparts in Intel® Extension for PyTorch* (weight prepack for
        convolution and linear) for good performance. In bfloat16 or float16 scenarios,
        parameters of convolution and linear will be cast to bfloat16 or float16 dtype.

    .. warning::

        Please invoke the ``optimize`` function BEFORE invoking DDP in a distributed
        training scenario.

        The ``optimize`` function deep-copies the original model. If DDP is invoked
        before the ``optimize`` function, DDP is applied on the original model, rather
        than the one returned from the ``optimize`` function. In this case, some
        operators in DDP, like allreduce, will not be invoked and thus may cause
        unpredictable accuracy loss.
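
        A sketch of the intended ordering (``DistributedDataParallel`` setup
        details are omitted):

        >>> model, optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
        >>> model = torch.nn.parallel.DistributedDataParallel(model)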

    Examples:

        >>> # bfloat16 inference case.
        >>> model = ...
        >>> model.load_state_dict(torch.load(PATH))
        >>> model.eval()
        >>> optimized_model = ipex.optimize(model, dtype=torch.bfloat16)
        >>> # running evaluation step.
        >>> # bfloat16 training case.
        >>> optimizer = ...
        >>> model.train()
        >>> optimized_model, optimized_optimizer = ipex.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
        >>> # running training step.
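
        A sketch of providing ``sample_input`` so that the prepacked weight block
        format can match the real input shape (the tensor shape below is only an
        illustration for a 4-D vision workload):

        >>> # fp32 inference case with a sample input.
        >>> sample_input = torch.randn(1, 3, 224, 224)
        >>> optimized_model = ipex.optimize(model, sample_input=sample_input)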

    `torch.xpu.optimize()` is an alternative to the optimize API in Intel® Extension for PyTorch*,
    provided to give identical usage for the XPU device only. The motivation for adding this alias is
    to unify the coding style in user scripts based on the torch.xpu module.

    Examples:

        >>> # bfloat16 inference case.
        >>> model = ...
        >>> model.load_state_dict(torch.load(PATH))
        >>> model.eval()
        >>> optimized_model = torch.xpu.optimize(model, dtype=torch.bfloat16)
        >>> # running evaluation step.
        >>> # bfloat16 training case.
        >>> optimizer = ...
        >>> model.train()
        >>> optimized_model, optimized_optimizer = torch.xpu.optimize(model, dtype=torch.bfloat16, optimizer=optimizer)
        >>> # running training step.

    """
    if isinstance(model, torch.jit.ScriptModule):
        if optimizer is None:
            return model
        return model, optimizer

    if model.training:
        assert optimizer is not None, "The optimizer should be given for training mode"
    else:
        assert optimizer is None, "The optimizer should not be given for inference mode"

    opt_properties = _Properties()
    if level not in opt_levels:
        raise RuntimeError(
            f"Unexpected optimization level {level}. Options are 'O0', 'O1'."
        )
    else:
        opt_properties = opt_levels[level](opt_properties)

    device_type = "cpu"
    model_parameters_list = list(model.parameters())
    if len(model_parameters_list) and model_parameters_list[0].device.type == "xpu":
        if not all([param.device.type == "xpu" for param in model_parameters_list]):
            raise RuntimeError("The model is mixed with different device types")
        else:
            device_type = "xpu"

    global auto_channels_last

    def xpu_check_channel_last():
        global auto_channels_last
        if auto_channels_last.value == auto_channels_last_flag.ENABLE:
            return True
        elif (
            auto_channels_last.value == auto_channels_last_flag.AUTO
            and torch.xpu.has_2d_block_array()
        ):
            return True
        else:
            return False

    if device_type == "cpu" and (
        auto_channels_last.value != auto_channels_last_flag.DISABLE
    ):
        _convert_convNd_deconvNd_weight_memory_format(model)
    elif device_type == "xpu" and xpu_check_channel_last():
        _convert_convNd_deconvNd_weight_memory_format(model)

    if level is not None:
        opt_properties.opt_level = level
    if conv_bn_folding is not None:
        opt_properties.conv_bn_folding = conv_bn_folding
    if linear_bn_folding is not None:
        opt_properties.linear_bn_folding = linear_bn_folding
    if weights_prepack is not None:
        opt_properties.weights_prepack = weights_prepack
    if replace_dropout_with_identity is not None:
        opt_properties.replace_dropout_with_identity = replace_dropout_with_identity
    if optimize_lstm is not None:
        opt_properties.optimize_lstm = optimize_lstm
    if split_master_weight_for_bf16 is not None:
        opt_properties.split_master_weight_for_bf16 = split_master_weight_for_bf16
    if fuse_update_step is not None:
        opt_properties.fuse_update_step = fuse_update_step
    if auto_kernel_selection is not None:
        opt_properties.auto_kernel_selection = auto_kernel_selection
    if graph_mode is not None:
        opt_properties.graph_mode = graph_mode
    if concat_linear is not None:
        opt_properties.concat_linear = concat_linear

    _disable_dnnl()
    if opt_properties.auto_kernel_selection:
        _enable_dnnl()

    # when on xpu, some features are not supported
    if device_type == "xpu":
        if opt_properties.auto_kernel_selection:
            warnings.warn(
                "For XPU device, the auto kernel selection is unsupported, so it is disabled."
            )
            opt_properties.auto_kernel_selection = False
        if opt_properties.split_master_weight_for_bf16:
            # currently split master weight for xpu only supports SGD
            if type(optimizer) is torch.optim.SGD:
                opt_properties.split_master_weight_for_bf16 = True
            else:
                opt_properties.split_master_weight_for_bf16 = False
        if opt_properties.graph_mode:
            warnings.warn(
                "For XPU, the out-of-box solution for inference is to trace the model outside of torch.xpu.optimize,"
                + " so graph mode is temporarily disabled."
            )
            opt_properties.graph_mode = False
        if not inplace:
            warnings.warn(
                "For XPU device, to save valuable device memory, the optimization is done on the model inplace,"
                + " so inplace is set to True."
            )
            inplace = True
        # for XPU, weight prepack is unsupported, so sample input is useless
        if opt_properties.weights_prepack:
            warnings.warn(
                "For XPU, the weight prepack and sample input are disabled. The oneDNN layout"
                + " is chosen automatically."
            )
            opt_properties.weights_prepack = False
            sample_input = None
        if opt_properties.optimize_lstm:
            warnings.warn(
                "For XPU, optimize_lstm (replacing lstm with ipex_lstm) is unsupported, so it is disabled."
            )
            opt_properties.optimize_lstm = False

    if inplace:
        optimized_model = model
        optimized_optimizer = optimizer
    else:
        optimized_model, optimized_optimizer = _copy_model_and_optimizer(
            model, optimizer
        )

    if sample_input is not None:
        if isinstance(sample_input, torch.Tensor):
            sample_input = (sample_input,)
        record_input_shape_for_prepack(optimized_model, sample_input)
    params_attr = {}
    if not model.training:
        if opt_properties.conv_bn_folding:
            try:
                optimized_model = optimization.fuse(optimized_model, inplace=True)
            except:  # noqa E722
                warnings.warn(
                    "Conv BatchNorm folding failed during the optimize process."
                )
        if opt_properties.linear_bn_folding:
            try:
                optimized_model = linear_bn_fuse(optimized_model, inplace=True)
            except BaseException:
                warnings.warn(
                    "Linear BatchNorm folding failed during the optimize process."
                )
        if opt_properties.replace_dropout_with_identity:
            utils._model_convert.replace_dropout_with_identity(optimized_model)
        if opt_properties.concat_linear:
            optimized_model = _concat_linear(optimized_model, inplace=True)
        if dtype in (
            torch.bfloat16,
            torch.float16,
        ):
            params_attr, optimized_model = utils._model_convert.convert_model_data_type(
                optimized_model, dtype
            )

    if opt_properties.optimize_lstm:
        replace_lstm_with_ipex_lstm(optimized_model, optimized_optimizer)
        torch._dynamo.allow_in_graph(_LSTM)

    if (
        model.training
        and opt_properties.split_master_weight_for_bf16
        and dtype is torch.bfloat16
    ):
        if not opt_properties.fuse_update_step:
            opt_properties.split_master_weight_for_bf16 = False
            warnings.warn(
                "IPEX does not support non-fused split master weight for bf16 training; "
                + "the split_master_weight_for_bf16 flag has been reset to False. "
                + "If you want to use split_master_weight_for_bf16, "
                + "please set both split_master_weight_for_bf16 and fuse_update_step to True."
            )
        elif (
            type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST_CPU
            and device_type == "cpu"
        ):
            opt_properties.split_master_weight_for_bf16 = False
            opt_properties.fuse_update_step = False
            warnings.warn(
                "IPEX CPU does not support fused/fused split update for "
                + str(type(optimizer))
                + "; will use non-fused master weight update for bf16 training on CPU."
            )
        elif (
            type(optimizer) not in IPEX_FUSED_OPTIMIZER_LIST_XPU
            and device_type == "xpu"
        ):
            opt_properties.split_master_weight_for_bf16 = False
            opt_properties.fuse_update_step = False
            warnings.warn(
                "IPEX XPU does not support fused/fused split update for "
                + str(type(optimizer))
                + "; will use non-fused master weight update for bf16 training on XPU."
            )

    if model.training:
        if hasattr(optimized_optimizer, "params_attr"):
            params_attr = optimized_optimizer.params_attr
        if dtype == torch.float16:
            assert (
                device_type != "xpu"
            ), "For now, XPU device does not support model training with half precision."
            opt_properties.split_master_weight_for_bf16 = False
        if dtype in (torch.bfloat16, torch.float16):
            # convert optimizer for training case.
            (
                optimized_model,
                optimized_optimizer,
                params_attr,
            ) = utils._weight_cast.weight_dtype_convert_with_ipex(
                optimized_model,
                optimized_optimizer,
                params_attr,
                opt_properties.split_master_weight_for_bf16,
                dtype,
            )

    # Since TorchDynamo cannot handle custom operations yet, for the case of inference graph mode,
    # the weights prepacking here is temporarily skipped, and it will be completed on the graph.
    if opt_properties.weights_prepack and device_type == "cpu":
        if dtype == torch.bfloat16:
            assert core.onednn_has_bf16_support(), (
                "BF16 weight prepack needs the CPU to support avx512bw, avx512vl and avx512dq, "
                + "but the desired instruction sets are not available. "
                + "Please set dtype to torch.float or set weights_prepack to False."
            )
        if dtype == torch.half:
            assert core.onednn_has_fp16_support(), (
                "FP16 weight prepack needs the CPU to support avx512_core_fp16, "
                + "but the desired instruction sets are not available. "
                + "Please set dtype to torch.float or set weights_prepack to False."
            )
        (
            optimized_model,
            optimized_optimizer,
            params_attr,
        ) = weight_prepack_with_ipex(
            optimized_model, optimized_optimizer, params_attr, "cpu"
        )
        torch._dynamo.allow_in_graph(_IPEXConv1d)
        torch._dynamo.allow_in_graph(_IPEXConv2d)
        torch._dynamo.allow_in_graph(_IPEXConv3d)
        torch._dynamo.allow_in_graph(_IPEXConvTranspose2d)
        torch._dynamo.allow_in_graph(_IPEXConvTranspose3d)
        torch._dynamo.allow_in_graph(_IPEXLinear)

    if opt_properties.graph_mode:
        _old_forward = optimized_model.forward
        wrapper = GraphCapture(
            optimized_model,
            optimizer is not None,
            dtype,
            opt_properties.weights_prepack,
        )
        optimized_model.forward = wrapper(_old_forward)

    if optimizer is None:
        return optimized_model

    # with an optimizer
    if opt_properties.fuse_update_step:
        optimized_optimizer = optimizer_fusion(
            optimized_optimizer,
            opt_properties.split_master_weight_for_bf16,
            device_type,
        )
    return optimized_model, optimized_optimizer


def _convert_convNd_deconvNd_weight_memory_format(module):
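    # Recursively convert the weights of (transposed) convolution modules to
    # the channels-last memory format matching their dimensionality
    # (channels_last_1d / channels_last / channels_last_3d).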
    # inspired from https://github.com/pytorch/pytorch/blob/master/torch/nn/utils/memory_format.py
    if isinstance(module, (torch.nn.Conv1d, torch.nn.ConvTranspose1d)):
        weight_data = to_channels_last_1d(module.weight.detach().clone())
        module.weight.data = weight_data.resize_(weight_data.size())
    elif isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        weight_data = (
            module.weight.detach().clone().contiguous(memory_format=torch.channels_last)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=torch.channels_last
        )
    elif isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        weight_data = (
            module.weight.detach()
            .clone()
            .contiguous(memory_format=torch.channels_last_3d)
        )
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=torch.channels_last_3d
        )

    for child in module.children():
        _convert_convNd_deconvNd_weight_memory_format(child)


class FP32MathMode(IntEnum):
    FP32 = int(core.FP32MathMode.FP32)
    TF32 = int(core.FP32MathMode.TF32)
    BF32 = int(core.FP32MathMode.BF32)


def set_fp32_math_mode(mode=FP32MathMode.FP32, device="cpu"):
    r"""
    Enable or disable implicit data type conversion.

    Args:
        mode (FP32MathMode): ``FP32MathMode.FP32``, ``FP32MathMode.BF32`` or
            ``FP32MathMode.TF32`` (GPU ONLY). oneDNN fpmath mode will be disabled by default if ``mode``
            is set to ``FP32MathMode.FP32``. The implicit ``FP32`` to ``TF32`` data type conversion
            will be enabled if ``mode`` is set to ``FP32MathMode.TF32``. The implicit ``FP32``
            to ``BF16`` data type conversion will be enabled if ``mode`` is set to ``FP32MathMode.BF32``.
        device (string): ``cpu``, ``xpu``

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to enable the implicit data type conversion
        >>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
        >>> # to disable the implicit data type conversion
        >>> ipex.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)

    ``torch.xpu.set_fp32_math_mode()`` is an alternative function in Intel® Extension for PyTorch*
    that provides identical usage for the XPU device only. The motivation for adding this alias is
    to unify the coding style in user scripts based on the ``torch.xpu`` module.

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to enable the implicit data type conversion
        >>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.BF32)
        >>> # to disable the implicit data type conversion
        >>> torch.xpu.set_fp32_math_mode(device="xpu", mode=ipex.FP32MathMode.FP32)
    """

    if device == "cpu":
        if mode == FP32MathMode.BF32:
            core.set_fp32_math_mode(core.FP32MathMode.BF32)
        elif mode == FP32MathMode.FP32:
            core.set_fp32_math_mode(core.FP32MathMode.FP32)
        else:
            warnings.warn(
                "For CPU device, IPEX does not support modes other than \
                    FP32MathMode.FP32 and FP32MathMode.BF32 for fpmath_mode right now."
            )
    elif device == "xpu":
        if mode == FP32MathMode.BF32:
            torch.xpu.set_fp32_math_mode(torch.xpu.FP32MathMode.BF32)
        elif mode == FP32MathMode.FP32:
            torch.xpu.set_fp32_math_mode(torch.xpu.FP32MathMode.FP32)
        elif mode == FP32MathMode.TF32:
            torch.xpu.set_fp32_math_mode(torch.xpu.FP32MathMode.TF32)
        else:
            warnings.warn(
                "For XPU device, IPEX does not support modes other than \
                    FP32MathMode.FP32, FP32MathMode.BF32 and FP32MathMode.TF32 for fpmath_mode right now."
            )
    else:
        raise RuntimeError(
            "Unexpected device type {}. ".format(device) + "Supported are 'cpu', 'xpu'."
        )


def get_fp32_math_mode(device="cpu"):
    r"""
    Get the current fpmath_mode setting.

    Args:
        device (string): ``cpu``, ``xpu``

    Returns:
        The current fpmath mode.
        The value will be ``FP32MathMode.FP32``, ``FP32MathMode.BF32`` or ``FP32MathMode.TF32`` (GPU ONLY).
        oneDNN fpmath mode is disabled by default when the mode is ``FP32MathMode.FP32``.
        The implicit ``FP32`` to ``TF32`` data type conversion is enabled when the mode is
        ``FP32MathMode.TF32``. The implicit ``FP32`` to ``BF16`` data type conversion is
        enabled when the mode is ``FP32MathMode.BF32``.

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to get the current fpmath mode
        >>> ipex.get_fp32_math_mode(device="xpu")

    ``torch.xpu.get_fp32_math_mode()`` is an alternative function in Intel® Extension for PyTorch*
    that provides identical usage for the XPU device only. The motivation for adding this alias is
    to unify the coding style in user scripts based on the ``torch.xpu`` module.

    Examples:

        >>> import intel_extension_for_pytorch as ipex
        >>> # to get the current fpmath mode
        >>> torch.xpu.get_fp32_math_mode(device="xpu")
    """

    if device == "cpu":
        return core.get_fp32_math_mode()
    elif device == "xpu":
        return torch.xpu.get_fp32_math_mode()
    else:
        raise RuntimeError(
            "Unexpected device type {}. ".format(device) + "Supported are 'cpu', 'xpu'."
        )