# pytorch
# 47 lines · 1.9 KB
import torch


class WrapperModule:
    """Wrap an instance of ``wrapped_type`` so its ``forward`` can be benchmarked.

    The instance is constructed as ``wrapped_type(module_config.pt_fn)``.
    ``module_config.num_params`` random input tensors are then created once,
    each holding a single float element (``torch.randn(1)``), and reused for
    every call. When ``module_config.graph_mode`` is true, the instance is
    replaced by the result of ``torch.jit.trace`` run on those inputs.

    Args:
        wrapped_type: Type to wrap. It must be constructible from
            ``module_config.pt_fn`` and provide a ``forward`` method taking
            ``module_config.num_params`` positional tensor arguments. In graph
            mode it must also be traceable by ``torch.jit.trace``.
        module_config: Object exposing ``pt_fn`` (the function passed to
            ``wrapped_type``), ``graph_mode`` (whether to trace the instance)
            and ``num_params`` (how many arguments ``forward`` takes).
        debug: If true and the resulting module is a
            ``torch.jit.ScriptModule``, print its graph and generated code.
        save: Only used in graph mode. If true, save the traced module with
            ``torch.jit.save`` to the relative path
            ``"<wrapped_type name>_<pt_fn name>.pt"``.
    """

    def __init__(self, wrapped_type, module_config, debug, save=False):
        pt_fn = module_config.pt_fn
        self.module = wrapped_type(pt_fn)
        self.module_name = wrapped_type.__name__
        # Inputs are created after construction (same RNG order as before) and
        # only once, so every benchmark iteration feeds identical tensors.
        self.tensor_inputs = [torch.randn(1) for _ in range(module_config.num_params)]

        if module_config.graph_mode:
            self.module = torch.jit.trace(self.module, self.tensor_inputs)
            if save:
                file_name = f"{self.module_name}_{pt_fn.__name__}.pt"
                torch.jit.save(self.module, file_name)
                print(f"Generated graph is saved in {file_name}")

        print(
            f"Benchmarking module {self.module_name} with fn {pt_fn.__name__}: "
            f"Graph mode:{module_config.graph_mode}"
        )
        # Traced modules are ScriptModules; eager instances have no graph/code.
        if debug and isinstance(self.module, torch.jit.ScriptModule):
            print(self.module.graph)
            print(self.module.code)

    def forward(self, niters):
        """Call the wrapped module's ``forward`` ``niters`` times.

        Runs under ``torch.no_grad()`` with the inputs created in ``__init__``.
        Outputs are discarded. The loop is intentionally left unoptimized
        because it is the code being measured.

        Args:
            niters: Number of ``forward`` calls to make.
        """
        with torch.no_grad():
            for _ in range(niters):
                self.module.forward(*self.tensor_inputs)