gpt-neox / mup_substitute.py (212 lines, 7.6 KB)
"""
Helper functions for performing coord check.
"""
import os
from copy import copy
from itertools import product

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

from mup import coord_check as mup_coord_check
from megatron.training import train_step


def _get_coord_data(
    neox_args,
    timers,
    lr_scheduler,
    models,
    dataloader,
    optcls,
    nsteps=3,
    dict_in_out=False,
    flatten_input=False,
    flatten_output=False,
    output_name="loss",
    lossfn="xent",
    filter_module_by_name=None,
    fix_data=True,
    cuda=True,
    nseeds=1,
    output_fdict=None,
    input_fdict=None,
    param_fdict=None,
    show_progress=True,
    one_hot_target=False,
):
    # Note: several arguments (`dict_in_out`, `flatten_input`, `flatten_output`,
    # `output_name`, `lossfn`, `fix_data`, `cuda`, `show_progress`,
    # `one_hot_target`) are kept for signature compatibility with the
    # corresponding helpers in `mup.coord_check` but are not used below, since
    # data handling and the loss are delegated to megatron's `train_step`.
    df = []

    for i in range(nseeds):
        torch.manual_seed(i)
        for width, model in models.items():
            model = model()
            model.train()
            optimizer = optcls(model)
            for step in range(nsteps + 1):
                remove_hooks = []
                # add hooks that record coordinate statistics for each module's
                # forward output
                for name, module in model.named_modules():
                    if filter_module_by_name and not filter_module_by_name(name):
                        continue
                    remove_hooks.append(
                        module.register_forward_hook(
                            mup_coord_check._record_coords(
                                df,
                                width,
                                name,
                                step + 1,
                                output_fdict=output_fdict,
                                input_fdict=input_fdict,
                                param_fdict=param_fdict,
                            )
                        )
                    )

                # train for a step
                loss_dict, skipped_iter = train_step(
                    neox_args=neox_args,
                    timers=timers,
                    data_iterator=dataloader,
                    model=model,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                )

                # remove hooks
                for handle in remove_hooks:
                    handle.remove()

            import gc

            del model
            gc.collect()

    return pd.DataFrame(df)
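
# With the default fdicts, each recording hook appends one dict per module call
# to `df`, so a row of the returned DataFrame looks roughly like
#     {"width": 1024, "module": "<module name>", "t": 1, "l1": 0.83}
# (the `l1` value here is illustrative). The frame can then be grouped by
# module and step and compared across widths.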


def get_coord_data(
    neox_args,
    timers,
    lr_scheduler,
    models,
    dataloader,
    optimizer="sgd",
    lr=None,
    mup=True,
    filter_trainable_by_name=None,
    **kwargs
):
    """Get coord data for coord check.
    Train the models in `models` with data from `dataloader` and optimizer
    specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate
    statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By
    default, only `l1` is computed for output activations of each module.
    This function wraps around `_get_coord_data`, with the main difference being
    that the user can specify common optimizers via a more convenient interface.
    Inputs:
        models:
            a dict of lazy models, where the keys are numbers indicating width.
            Each entry of `models` is a function that instantiates a model given
            nothing.
        dataloader:
            an iterator whose elements are either Huggingface-style dicts, if
            `dict_in_out` is True, or (input, label) pairs. If `fix_data` is True
            (which is the default), then only the first element of `dataloader`
            is used in a loop and the rest of `dataloader` is ignored.
        optimizer:
            a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`.
        lr:
            learning rate. By default it is 0.1 for `'sgd'` and 1e-3 for others.
        mup:
            If True, then use the optimizer from `mup.optim`; otherwise, use the
            one from `torch.optim`.
        filter_trainable_by_name:
            a function that returns a bool given module names (from
            `model.named_modules()`), or None. If not None, then only modules
            whose name yields True will be trained.
        nsteps:
            number of steps to train the model.
        dict_in_out:
            whether the data loader contains Huggingface-style dict input and
            output. Default: False
        flatten_input:
            if not `dict_in_out`, reshape the input to be
            `input.view(input.shape[0], -1)`. Typically used for testing MLPs.
        flatten_output:
            if not `dict_in_out`, reshape the label to be `label.view(-1,
            input.shape[-1])`.
        output_name:
            if `dict_in_out`, this is the key for the loss value if the output
            is a dict. If the output is not a dict, then we assume the first
            element of the output is the loss.
        lossfn:
            loss function to use if not `dict_in_out`. Can be either a string in
            `['xent', 'mse', 'nll', 'l1']` or a python `callable` such that
            `lossfn(output, target)` returns the loss value. Examples of valid
            `callable`s are `F.cross_entropy`, `F.mse_loss`, etc., where `F` is
            `torch.nn.functional`. Default: 'xent'
        filter_module_by_name:
            a function that returns a bool given module names (from
            `model.named_modules()`), or None. If not None, then only modules
            whose name yields True will be recorded.
        cuda:
            whether to use cuda or not. Default: True
        nseeds:
            number of times to repeat the training, each with a different seed.
        output_fdict, input_fdict, param_fdict:
            function dicts to be used in `_record_coords`. By default, only `l1`
            is computed for output activations of each module.
        show_progress:
            show progress using tqdm. Default: True
        one_hot_target:
            convert the target label into a one-hot vector. This is typically
            only used for `'mse'` or `'l1'` losses in classification tasks.
            Default: False
    Output:
        a pandas DataFrame containing recorded results. The column names are
        `'width', 'module', 't'` as well as names of statistics recorded, such
        as `'l1'` (see `FDICT` for other premade statistics that can be
        collected).

    Breaking Changes:
        In v1.0.0, when `lossfn=='mse'`, the target is automatically converted
        to a one-hot vector before loss computation. Starting in v1.1.0, this
        behavior is turned off, and the user needs to explicitly turn on this
        behavior by setting `one_hot_target=True`.
    """
    if lr is None:
        lr = 0.1 if optimizer == "sgd" else 1e-3
    if mup:
        from mup.optim import MuAdam as Adam
        from mup.optim import MuAdamW as AdamW
        from mup.optim import MuSGD as SGD
    else:
        from torch.optim import SGD, Adam, AdamW

    def get_trainable(model):
        params = model.parameters()
        if filter_trainable_by_name is not None:
            params = []
            for name, p in model.named_parameters():
                if filter_trainable_by_name(name):
                    params.append(p)
        return params

    if optimizer == "sgd":
        optcls = lambda model: SGD(get_trainable(model), lr=lr)
    elif optimizer == "adam":
        optcls = lambda model: Adam(get_trainable(model), lr=lr)
    elif optimizer == "adamw":
        optcls = lambda model: AdamW(get_trainable(model), lr=lr)
    elif callable(optimizer):
        # allow a custom factory: optimizer(model) -> torch.optim.Optimizer
        optcls = optimizer
    else:
        raise ValueError("optimizer should be sgd|adam|adamw or a custom function")

    data = _get_coord_data(
        neox_args, timers, lr_scheduler, models, dataloader, optcls, **kwargs
    )
    data["optimizer"] = optimizer
    data["lr"] = lr
    return data
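

# Illustrative usage sketch: `build_model_fn`, the widths, and the output
# filename below are placeholders, and `neox_args`, `timers`, `lr_scheduler`,
# and the training data iterator are assumed to be set up by the caller as in
# normal NeoX training. `plot_coord_data` is the plotting helper shipped with
# the `mup` package.
def _example_coord_check(
    neox_args, timers, lr_scheduler, train_data_iterator, build_model_fn
):
    from mup.coord_check import plot_coord_data

    # Lazy models keyed by width: each value is a zero-argument constructor.
    models = {
        width: (lambda w=width: build_model_fn(hidden_size=w))
        for width in (128, 256, 512, 1024)
    }
    df = get_coord_data(
        neox_args,
        timers,
        lr_scheduler,
        models,
        train_data_iterator,
        optimizer="adam",
        lr=1e-3,
        mup=True,
        nsteps=3,
        nseeds=3,
    )
    # Under muP, per-module activation scales should stay roughly flat as width
    # grows; under standard parametrization they typically grow with width.
    plot_coord_data(df, save_to="coord_check_mup.png")
    return df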