intel-extension-for-pytorch

Форк
0
49 строк · 1016.0 Байт
1
import torch
2
import intel_extension_for_pytorch as ipex
3

4
torch_function = [
5
    "rand",
6
    "randint",
7
    "arange",
8
    "bartlett_window",
9
    "blackman_window",
10
    "empty",
11
    "_empty_affine_quantized",
12
    "_empty_per_channel_affine_quantized",
13
    "empty_strided",
14
    "eye",
15
    "full",
16
    "from_file",
17
    "from_numpy",
18
    "hann_window",
19
    "hamming_window",
20
    "linspace",
21
    "logspace",
22
    "ones",
23
    "scalar_tensor",
24
    "randn",
25
    "randperm",
26
    "range",
27
    "zeros",
28
    "sparse_coo_tensor",
29
    "tril_indices",
30
    "triu_indices",
31
    "normal",
32
    "tensor",
33
]
34

35

36
def make_hooked_func(torch_func):
37
    def hooked_func(*args, **kwargs):
38
        if "device" in kwargs:
39
            return torch_func(*args, **kwargs)
40
        else:
41
            return torch_func(*args, **kwargs).to(ipex.DEVICE)
42

43
    return hooked_func
44

45

46
for torch_func_name in torch_function:
47
    torch_fn = getattr(torch, torch_func_name)
48
    hooked_fn = make_hooked_func(torch_fn)
49
    setattr(torch, torch_func_name, hooked_fn)
50

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.