"""
Helper functions for performing coord check.
"""
from itertools import product

import pandas as pd
import torch
import torch.nn.functional as F

from mup import coord_check as mup_coord_check
from megatron.training import train_step


def _get_coord_data(
    neox_args,
    timers,
    lr_scheduler,
    models,
    dataloader,
    optcls,
    nsteps=3,
    dict_in_out=False,
    flatten_input=False,
    flatten_output=False,
    output_name="loss",
    lossfn="xent",
    filter_module_by_name=None,
    fix_data=True,
    cuda=True,
    nseeds=1,
    output_fdict=None,
    input_fdict=None,
    param_fdict=None,
    show_progress=True,
    one_hot_target=False,
):
    # NOTE: the defaults for `nsteps`, `nseeds`, and `output_name` are assumed;
    # the remaining defaults follow the docstring of `get_coord_data` below.
    df = []

    for i in range(nseeds):
        torch.manual_seed(i)  # repeat the run under a different seed each time
        for width, model in models.items():
            model = model()  # entries of `models` are lazy constructors
            optimizer = optcls(model)
            for step in range(nsteps + 1):
                remove_hooks = []
                # attach forward hooks that record coordinate statistics
                for name, module in model.named_modules():
                    if filter_module_by_name and not filter_module_by_name(name):
                        continue
                    remove_hooks.append(
                        module.register_forward_hook(
                            mup_coord_check._record_coords(
                                df,
                                width,
                                name,
                                step + 1,
                                output_fdict=output_fdict,
                                input_fdict=input_fdict,
                                param_fdict=param_fdict,
                            )
                        )
                    )
                # run one training step; the hooks fire during its forward pass
                loss_dict, skipped_iter = train_step(
                    neox_args=neox_args,
                    timers=timers,
                    data_iterator=dataloader,
                    model=model,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                )
                # detach the hooks so the next step starts clean
                for handle in remove_hooks:
                    handle.remove()

    return pd.DataFrame(df)
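
# For intuition: `mup_coord_check._record_coords` returns a forward hook, so the
# recording loop above is the standard PyTorch hook pattern. A minimal sketch of
# that pattern (`make_hook` is hypothetical; the real hook computes the
# configured `output_fdict`/`input_fdict`/`param_fdict` statistics, with `l1`,
# the mean absolute coordinate, as the default):
#
#     def make_hook(records, width, name, t):
#         def hook(module, inputs, output):
#             records.append(dict(width=width, module=name, t=t,
#                                 l1=output.abs().mean().item()))
#         return hook
#
#     handle = module.register_forward_hook(make_hook(df, width, name, step + 1))
#     ...  # the forward pass fires the hook
#     handle.remove()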


def get_coord_data(
    neox_args,
    timers,
    lr_scheduler,
    models,
    dataloader,
    optimizer="sgd",
    lr=None,
    mup=True,
    filter_trainable_by_name=None,
    **kwargs,
):
    """Get coord data for coord check.

    Train the models in `models` with data from `dataloader` and optimizer
    specified by `optimizer` and `lr` for `nsteps` steps, and record coordinate
    statistics specified by `output_fdict`, `input_fdict`, `param_fdict`. By
    default, only `l1` is computed for output activations of each module.
    This function wraps around `_get_coord_data`, with the main difference
    being that the user can specify common optimizers via a more convenient
    interface.

    Inputs:
        models:
            a dict of lazy models, where the keys are numbers indicating width.
            Each entry of `models` is a function that instantiates a model
            given no arguments.
        dataloader:
            an iterator whose elements are either Huggingface-style dicts, if
            `dict_in_out` is True, or (input, label) pairs. If `fix_data` is
            True (which is the default), then only the first element of
            `dataloader` is used in a loop and the rest of `dataloader` is
            ignored.
        optimizer:
            a string in `['sgd', 'adam', 'adamw']`, with default being `'sgd'`.
        lr:
            learning rate. By default is 0.1 for `'sgd'` and 1e-3 for others.
        mup:
            If True, then use the optimizer from `mup.optim`; otherwise, use
            the one from `torch.optim`.
        filter_trainable_by_name:
            a function that returns a bool given parameter names (from
            `model.named_parameters()`), or None. If not None, then only
            parameters whose name yields True will be trained.
        nsteps:
            number of steps to train the model.
        dict_in_out:
            whether the data loader contains Huggingface-style dict input and
            output. Default: False
        flatten_input:
            if not `dict_in_out`, reshape the input to be
            `input.view(input.shape[0], -1)`. Typically used for testing MLPs.
        flatten_output:
            if not `dict_in_out`, reshape the label to be
            `label.view(-1, input.shape[-1])`.
        output_name:
            if `dict_in_out`, this is the key for the loss value if the output
            is a dict. If the output is not a dict, then we assume the first
            element of the output is the loss.
        lossfn:
            loss function to use if not `dict_in_out`. Can be either a string
            from `['xent', 'mse', 'nll', 'l1']` or a python `callable` such
            that `lossfn(output, target)` returns the loss value. Examples of
            valid `callable`s are `F.cross_entropy`, `F.mse_loss`, etc., where
            `F` is `torch.nn.functional`. Default: 'xent'
        filter_module_by_name:
            a function that returns a bool given module names (from
            `model.named_modules()`), or None. If not None, then only modules
            whose name yields True will be recorded.
        cuda:
            whether to use cuda or not. Default: True
        nseeds:
            number of times to repeat the training, each with different seeds.
        output_fdict, input_fdict, param_fdict:
            function dicts to be used in `_record_coords`. By default, only
            `l1` is computed for output activations of each module.
        show_progress:
            show progress using tqdm. Default: True
        one_hot_target:
            convert target label into a one-hot vector. This typically is only
            used for `'mse'` or `'l1'` losses in classification tasks.
    Output:
        a pandas DataFrame containing recorded results. The column names are
        `'width', 'module', 't'` as well as names of statistics recorded, such
        as `'l1'` (see `FDICT` for other premade statistics that can be
        collected).

    Breaking change in v1.1.0:
        In v1.0.0, when `lossfn == 'mse'`, the target is automatically
        converted to a one-hot vector before loss computation. Starting in
        v1.1.0, this behavior is turned off, and the user needs to explicitly
        turn it on by setting `one_hot_target=True`.
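
    Example:
        a hypothetical sketch; `make_model(width)` is a stand-in factory, and
        `neox_args`, `timers`, `lr_scheduler`, `dataloader` come from the
        surrounding NeoX setup::

            models = {w: (lambda w=w: make_model(w)) for w in [256, 512, 1024]}
            df = get_coord_data(
                neox_args, timers, lr_scheduler, models, dataloader,
                optimizer="adam",
            )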
    """
    if lr is None:
        lr = 0.1 if optimizer == "sgd" else 1e-3
    if mup:
        from mup.optim import MuAdam as Adam
        from mup.optim import MuAdamW as AdamW
        from mup.optim import MuSGD as SGD
    else:
        from torch.optim import SGD, Adam, AdamW

    def get_trainable(model):
        params = model.parameters()
        if filter_trainable_by_name is not None:
            params = []
            for name, p in model.named_parameters():
                if filter_trainable_by_name(name):
                    params.append(p)
        return params

    if optimizer == "sgd":
        optcls = lambda model: SGD(get_trainable(model), lr=lr)
    elif optimizer == "adam":
        optcls = lambda model: Adam(get_trainable(model), lr=lr)
    elif optimizer == "adamw":
        optcls = lambda model: AdamW(get_trainable(model), lr=lr)
    else:
        # raise on anything unrecognized rather than silently leaving
        # `optcls` undefined
        raise ValueError("optimizer should be sgd|adam|adamw or a custom function")

    data = _get_coord_data(
        neox_args, timers, lr_scheduler, models, dataloader, optcls, **kwargs
    )
    data["optimizer"] = optimizer
    data["lr"] = lr
    return data