lama
28 строк · 689.0 Байт
1from enum import Enum
2
3import yaml
4from easydict import EasyDict as edict
5import torch.nn as nn
6import torch
7
8
def load_yaml(path):
    """Load a YAML file and return the parsed document wrapped in ``edict``.

    The file is opened in binary mode so that PyYAML detects the encoding
    itself (UTF-8 by default, UTF-16 when a BOM is present). Text mode without
    an explicit encoding would instead use the platform's locale default. That
    can mis-decode or reject non-ASCII configs (e.g. on Windows).

    Args:
        path: Path (str or path-like) of the YAML file to read.

    Returns:
        ``edict(...)`` of the document produced by ``yaml.safe_load``. Only
        standard YAML tags are constructed, so no arbitrary Python objects are
        created. Presumably the top-level node is a mapping; TODO confirm
        against callers.
    """
    # Binary stream -> PyYAML performs its own encoding detection.
    with open(path, 'rb') as f:
        return edict(yaml.safe_load(f))
12
13
def move_to_device(obj, device):
    """Recursively move a module, a tensor, or nested containers of them to ``device``.

    Args:
        obj: An ``nn.Module``, a tensor, or any nesting of tuples, lists and
            dicts whose leaves are modules or tensors.
        device: Target handed unchanged to each leaf's ``.to()`` call.

    Returns:
        For a module or tensor, the result of ``.to(device)``. For a tuple or
        list, a new ``list`` (tuples are NOT preserved as tuples). For a dict
        (including subclasses), a new plain ``dict`` with the same keys.

    Raises:
        ValueError: if ``obj`` or any nested leaf has an unsupported type.
    """
    # Leaves: both kinds expose ``.to(device)``.
    is_leaf = isinstance(obj, nn.Module) or torch.is_tensor(obj)
    if is_leaf:
        return obj.to(device)

    if isinstance(obj, (tuple, list)):
        moved_items = []
        for item in obj:
            moved_items.append(move_to_device(item, device))
        return moved_items

    if isinstance(obj, dict):
        moved_mapping = {}
        for key, value in obj.items():
            moved_mapping[key] = move_to_device(value, device)
        return moved_mapping

    raise ValueError(f'Unexpected type {type(obj)}')
24
25
class SmallMode(Enum):
    """Two-way choice whose members carry the string values ``"drop"`` / ``"upscale"``.

    NOTE(review): by name, this presumably selects what to do with inputs that
    are too small (discard them vs. enlarge them). That meaning is not visible
    here; confirm against the callers.
    """

    # Member order is part of the observable behavior (iteration order).
    DROP = "drop"  # looked up via SmallMode("drop")
    UPSCALE = "upscale"  # looked up via SmallMode("upscale")
29