import bisect
import functools
import logging
import numbers
import os
import platform
import signal
import sys
import traceback
import warnings

import torch
from pytorch_lightning import seed_everything

LOGGER = logging.getLogger(__name__)

# signal.SIGUSR1 is not defined on Windows; assign a dummy value on non-Linux
# platforms so the debug-signal helpers below remain importable
if platform.system() != 'Linux':
    signal.SIGUSR1 = 1

def check_and_warn_input_range(tensor, min_value, max_value, name):
    actual_min = tensor.min()
    actual_max = tensor.max()
    if actual_min < min_value or actual_max > max_value:
        warnings.warn(f"{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}")


def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    for k, v in cur_dict.items():
        target_key = prefix + k
        target[target_key] = target.get(target_key, default) + v


def average_dicts(dict_list):
    result = {}
    norm = 1e-3
    for dct in dict_list:
        sum_dict_with_prefix(result, dct, '')
        norm += 1
    for k in list(result):
        result[k] /= norm
    return result

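# Illustrative usage (not part of the original module): averaging a list of
# per-step metric dicts. Because `norm` starts at 1e-3 rather than 0, the
# result is a very slightly damped mean.
#
#   average_dicts([{'loss': 1.0, 'psnr': 30.0}, {'loss': 3.0, 'psnr': 32.0}])
#   # -> {'loss': ~2.0, 'psnr': ~31.0}
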
def add_prefix_to_keys(dct, prefix):
    return {prefix + k: v for k, v in dct.items()}


def set_requires_grad(module, value):
    for param in module.parameters():
        param.requires_grad = value


def flatten_dict(dct):
    result = {}
    for k, v in dct.items():
        if isinstance(k, tuple):
            k = '_'.join(k)
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v).items():
                result[f'{k}_{sub_k}'] = sub_v
        else:
            result[k] = v
    return result

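# Illustrative usage (not part of the original module): flatten_dict joins
# nested keys (and tuple keys) with underscores into a single-level dict.
#
#   flatten_dict({'gen': {'loss': 0.5, 'lr': 1e-4}, ('val', 'ssim'): 0.9})
#   # -> {'gen_loss': 0.5, 'gen_lr': 0.0001, 'val_ssim': 0.9}
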
class LinearRamp:
    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part


class LadderRamp:
    def __init__(self, start_iters, values):
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]


def get_ramp(kind='ladder', **kwargs):
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')

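# Illustrative usage (not part of the original module): both ramp kinds map an
# iteration index to a scalar, e.g. a loss weight schedule.
#
#   linear = get_ramp(kind='linear', start_value=0, end_value=1, start_iter=1000, end_iter=2000)
#   linear(1500)  # -> 0.5
#
#   ladder = get_ramp(kind='ladder', start_iters=[1000, 5000], values=[0.0, 0.5, 1.0])
#   ladder(0), ladder(2000), ladder(7000)  # -> 0.0, 0.5, 1.0
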
def print_traceback_handler(sig, frame):
    LOGGER.warning(f'Received signal {sig}')
    bt = ''.join(traceback.format_stack())
    LOGGER.warning(f'Requested stack trace:\n{bt}')


def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler):
    LOGGER.warning(f'Setting signal {sig} handler {handler}')
    signal.signal(sig, handler)

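# Illustrative usage (not part of the original module): register the handler
# early in the process, then request a stack dump of a live run from a shell
# with `kill -USR1 <pid>` (only meaningful on POSIX systems).
#
#   register_debug_signal_handlers()
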
def handle_deterministic_config(config):
    seed = dict(config).get('seed', None)
    if seed is None:
        return False

    seed_everything(seed)
    return True

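# Illustrative usage (not part of the original module), assuming a dict-like
# training config:
#
#   handle_deterministic_config({'seed': 42})  # seeds all RNGs, returns True
#   handle_deterministic_config({})            # no-op, returns False
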
def get_shape(t):
    if torch.is_tensor(t):
        return tuple(t.shape)
    elif isinstance(t, dict):
        return {n: get_shape(q) for n, q in t.items()}
    elif isinstance(t, (list, tuple)):
        return [get_shape(q) for q in t]
    elif isinstance(t, numbers.Number):
        return type(t)
    else:
        raise ValueError('unexpected type {}'.format(type(t)))

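# Illustrative usage (not part of the original module): recursively inspect the
# shapes of a nested batch structure.
#
#   batch = {'image': torch.zeros(1, 3, 256, 256), 'mask': torch.zeros(1, 1, 256, 256)}
#   get_shape(batch)  # -> {'image': (1, 3, 256, 256), 'mask': (1, 1, 256, 256)}
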
def get_has_ddp_rank():
    master_port = os.environ.get('MASTER_PORT', None)
    node_rank = os.environ.get('NODE_RANK', None)
    local_rank = os.environ.get('LOCAL_RANK', None)
    world_size = os.environ.get('WORLD_SIZE', None)
    has_rank = master_port is not None or node_rank is not None or local_rank is not None or world_size is not None
    return has_rank

def handle_ddp_subprocess():
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in the worker
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            main_func(*args, **kwargs)
        return new_main
    return main_decorator

def handle_ddp_parent_process():
    parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
    has_parent = parent_cwd is not None
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if parent_cwd is None:
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent

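# Illustrative sketch (not part of the original module) of how the DDP helpers
# might be combined in a hydra-based entry point; the config path, config name
# and main() body below are assumptions, not the project's actual CLI.
#
#   @handle_ddp_subprocess()
#   @hydra.main(config_path='configs', config_name='train')
#   def main(config):
#       is_ddp_subprocess = handle_ddp_parent_process()
#       ...  # build the training model and run trainer.fit(...)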