utils.py

import bisect
import functools
import logging
import numbers
import os
import platform
import signal
import sys
import traceback
import warnings

import torch
from pytorch_lightning import seed_everything

LOGGER = logging.getLogger(__name__)

# SIGUSR1 is a POSIX signal that does not exist on Windows; on non-Linux
# systems define a stand-in value so that the default argument of
# register_debug_signal_handlers below can still be evaluated.
if platform.system() != 'Linux':
    signal.SIGUSR1 = 1


def check_and_warn_input_range(tensor, min_value, max_value, name):
    # Warn (rather than fail) when tensor values fall outside the expected range.
    actual_min = tensor.min()
    actual_max = tensor.max()
    if actual_min < min_value or actual_max > max_value:
        warnings.warn(f'{name} must be in {min_value}..{max_value} range, but it ranges {actual_min}..{actual_max}')
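
# Illustrative usage (not part of the original module): sanity-check that
# images fed to a metric are already normalized to [0, 1].
# check_and_warn_input_range(img, 0, 1, 'img')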


def sum_dict_with_prefix(target, cur_dict, prefix, default=0):
    # Accumulate cur_dict into target in place, prefixing every key.
    for k, v in cur_dict.items():
        target_key = prefix + k
        target[target_key] = target.get(target_key, default) + v


def average_dicts(dict_list):
    result = {}
    # Start the denominator at a small epsilon so that an empty dict_list
    # does not lead to division by zero.
    norm = 1e-3
    for dct in dict_list:
        sum_dict_with_prefix(result, dct, '')
        norm += 1
    for k in list(result):
        result[k] /= norm
    return result
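
# Illustrative usage (not part of the original module): averaging per-batch
# metric dicts. Note that a key missing from some dicts is still divided by
# the total count:
# average_dicts([{'loss': 1.0, 'psnr': 30.0}, {'loss': 0.5}])
# -> {'loss': ~0.75, 'psnr': ~15.0}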


def add_prefix_to_keys(dct, prefix):
    return {prefix + k: v for k, v in dct.items()}


def set_requires_grad(module, value):
    # Freeze or unfreeze all parameters of a module at once.
    for param in module.parameters():
        param.requires_grad = value
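
# Illustrative usage (not part of the original module): the typical GAN
# pattern of freezing the discriminator while the generator is updated.
# set_requires_grad(discriminator, False)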


def flatten_dict(dct):
    # Recursively flatten nested dicts into a single level, joining keys
    # (and tuple keys) with underscores.
    result = {}
    for k, v in dct.items():
        if isinstance(k, tuple):
            k = '_'.join(k)
        if isinstance(v, dict):
            for sub_k, sub_v in flatten_dict(v).items():
                result[f'{k}_{sub_k}'] = sub_v
        else:
            result[k] = v
    return result
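
# Illustrative usage (not part of the original module):
# flatten_dict({'a': {'b': 1, 'c': 2}, ('x', 'y'): 3})
# -> {'a_b': 1, 'a_c': 2, 'x_y': 3}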


class LinearRamp:
    def __init__(self, start_value=0, end_value=1, start_iter=-1, end_iter=0):
        self.start_value = start_value
        self.end_value = end_value
        self.start_iter = start_iter
        self.end_iter = end_iter

    def __call__(self, i):
        if i < self.start_iter:
            return self.start_value
        if i >= self.end_iter:
            return self.end_value
        # Linearly interpolate between the two values inside the ramp window.
        part = (i - self.start_iter) / (self.end_iter - self.start_iter)
        return self.start_value * (1 - part) + self.end_value * part
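
# Illustrative usage (not part of the original module): ramp a loss weight
# from 0 to 1 between iterations 1000 and 2000.
# ramp = LinearRamp(start_value=0, end_value=1, start_iter=1000, end_iter=2000)
# ramp(500) -> 0; ramp(1500) -> 0.5; ramp(2500) -> 1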


class LadderRamp:
    def __init__(self, start_iters, values):
        # start_iters must be sorted; its N thresholds split the iteration
        # axis into N + 1 segments, one value per segment.
        self.start_iters = start_iters
        self.values = values
        assert len(values) == len(start_iters) + 1, (len(values), len(start_iters))

    def __call__(self, i):
        # bisect_right finds the segment that iteration i falls into.
        segment_i = bisect.bisect_right(self.start_iters, i)
        return self.values[segment_i]
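
# Illustrative usage (not part of the original module): a stepwise schedule
# returning 0 before iteration 10, 0.5 from 10 to 99, and 1.0 from 100 on.
# ramp = LadderRamp(start_iters=[10, 100], values=[0, 0.5, 1.0])
# ramp(5) -> 0; ramp(10) -> 0.5; ramp(100) -> 1.0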


def get_ramp(kind='ladder', **kwargs):
    if kind == 'linear':
        return LinearRamp(**kwargs)
    if kind == 'ladder':
        return LadderRamp(**kwargs)
    raise ValueError(f'Unexpected ramp kind: {kind}')
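
# Illustrative usage (not part of the original module): kwargs are forwarded
# to the chosen ramp class, typically straight from a config section.
# ramp = get_ramp(kind='linear', start_iter=0, end_iter=1000)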


def print_traceback_handler(sig, frame):
    # Signal handler that logs the current stack trace of the process.
    LOGGER.warning(f'Received signal {sig}')
    bt = ''.join(traceback.format_stack())
    LOGGER.warning(f'Requested stack trace:\n{bt}')


def register_debug_signal_handlers(sig=signal.SIGUSR1, handler=print_traceback_handler):
    LOGGER.warning(f'Setting signal {sig} handler {handler}')
    signal.signal(sig, handler)
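
# Illustrative usage (not part of the original module): after calling
# register_debug_signal_handlers() at startup, a live stack trace can be
# requested from a running training process without stopping it:
#     kill -USR1 <pid>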


def handle_deterministic_config(config):
    # Seed python, numpy and torch RNGs via pytorch_lightning if the config
    # provides a seed; report whether seeding took place.
    seed = dict(config).get('seed', None)
    if seed is None:
        return False

    seed_everything(seed)
    return True
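
# Illustrative usage (not part of the original module):
# deterministic = handle_deterministic_config(config)  # True if seeding happened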


def get_shape(t):
    # Recursively describe the shape of a (possibly nested) structure of tensors.
    if torch.is_tensor(t):
        return tuple(t.shape)
    elif isinstance(t, dict):
        return {n: get_shape(q) for n, q in t.items()}
    elif isinstance(t, (list, tuple)):
        return [get_shape(q) for q in t]
    elif isinstance(t, numbers.Number):
        return type(t)
    else:
        raise ValueError(f'unexpected type {type(t)}')
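
# Illustrative usage (not part of the original module):
# get_shape({'image': torch.zeros(2, 3, 256, 256), 'scale': 1.0})
# -> {'image': (2, 3, 256, 256), 'scale': <class 'float'>}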


def get_has_ddp_rank():
    # PyTorch Lightning sets these variables when it spawns DDP worker
    # processes; their presence means we are running inside such a worker.
    ddp_vars = ('MASTER_PORT', 'NODE_RANK', 'LOCAL_RANK', 'WORLD_SIZE')
    return any(os.environ.get(v) is not None for v in ddp_vars)


def handle_ddp_subprocess():
    def main_decorator(main_func):
        @functools.wraps(main_func)
        def new_main(*args, **kwargs):
            # Lightning's Trainer sets MASTER_PORT, NODE_RANK, LOCAL_RANK, WORLD_SIZE
            parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
            has_parent = parent_cwd is not None
            has_rank = get_has_ddp_rank()
            assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

            if has_parent:
                # we are in a DDP worker: reuse the parent's hydra working directory
                sys.argv.extend([
                    f'hydra.run.dir={parent_cwd}',
                    # 'hydra/hydra_logging=disabled',
                    # 'hydra/job_logging=disabled'
                ])
            # do nothing if this is a top-level process;
            # TRAINING_PARENT_WORK_DIR is set in handle_ddp_parent_process after hydra initialization

            # return the wrapped function's result instead of swallowing it
            return main_func(*args, **kwargs)
        return new_main
    return main_decorator


def handle_ddp_parent_process():
    parent_cwd = os.environ.get('TRAINING_PARENT_WORK_DIR', None)
    has_parent = parent_cwd is not None
    has_rank = get_has_ddp_rank()
    assert has_parent == has_rank, f'Inconsistent state: has_parent={has_parent}, has_rank={has_rank}'

    if parent_cwd is None:
        # top-level process: record the hydra working directory for DDP workers
        os.environ['TRAINING_PARENT_WORK_DIR'] = os.getcwd()

    return has_parent
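
# Illustrative usage (not part of the original module), sketching how the two
# DDP helpers cooperate in a hydra entry point; `train` and the config names
# are hypothetical:
#
# @handle_ddp_subprocess()
# @hydra.main(config_path='configs', config_name='train')
# def train(config):
#     is_ddp_subprocess = handle_ddp_parent_process()
#     ...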