pytorch

Форк
0
47 строк · 1.9 Кб
1
import torch
2

3

4
class WrapperModule:
5
    """Wraps the instance of wrapped_type.
6
    For graph_mode traces the instance of wrapped_type.
7
    Randomaly initializes num_params tensors with single float element.
8
    Args:
9
        wrapped_type:
10
            - Object type to be wrapped.
11
                Expects the wrapped_type to:
12
                   - be constructed with pt_fn specified in module_config.
13
                   - provide forward method that takes module_config.num_params args.
14
        module_config:
15
            - Specified pt_fn to construct wrapped_type with, whether graph_mode
16
              is enabled, and number of parameters wrapped_type's forward method
17
              takes.
18
        debug:
19
            - Whether debug mode is enabled.
20
        save:
21
            - In graph mode, whether graph is to be saved.
22
    """
23

24
    def __init__(self, wrapped_type, module_config, debug, save=False):
25
        pt_fn = module_config.pt_fn
26
        self.module = wrapped_type(pt_fn)
27
        self.tensor_inputs = []
28
        self.module_name = wrapped_type.__name__
29
        for _ in range(module_config.num_params):
30
            self.tensor_inputs.append(torch.randn(1))
31
        if module_config.graph_mode:
32
            self.module = torch.jit.trace(self.module, self.tensor_inputs)
33
            if save:
34
                file_name = self.module_name + "_" + pt_fn.__name__ + ".pt"
35
                torch.jit.save(self.module, file_name)
36
                print(f"Generated graph is saved in {file_name}")
37
        print(
38
            f"Benchmarking module {self.module_name} with fn {pt_fn.__name__}: Graph mode:{module_config.graph_mode}"
39
        )
40
        if debug and isinstance(self.module, torch.jit.ScriptModule):
41
            print(self.module.graph)
42
            print(self.module.code)
43

44
    def forward(self, niters):
45
        with torch.no_grad():
46
            for _ in range(niters):
47
                self.module.forward(*self.tensor_inputs)
48

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.