google-research

Форк
0
310 строк · 10.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Example of training the SLiMPerformer on PennTreeBank and Enwik8 data, as well as the copy task."""
17
import collections
18
import gzip
19
import os
20
import time
21

22
from absl import app
23
from absl import flags
24
from absl import logging
25

26
import numpy as np
27
import torch
28

29
from performer.models.pytorch.slim_performer import slim_performer_model
30

31
FLAGS = flags.FLAGS

# Model parameters
flags.DEFINE_integer('batch_size', 10, 'Batch size for training.')
flags.DEFINE_float('learning_rate', 1e-4, 'Adam Optimizer learning rate.')
flags.DEFINE_integer(
    'step_size', -1,
    'Tradeoff between memory and parallel running time (C). -1 corresponds to naive FULL method in the paper.'
)
flags.DEFINE_integer('hidden_dim', 512, 'Feature dimension.')
flags.DEFINE_integer('n_layers', 6, 'Number of Attention layers.')
flags.DEFINE_integer('ffn_dim', 2048, 'MLP dimension in model.')
flags.DEFINE_integer('n_heads', 8, 'Number of heads for attention.')
flags.DEFINE_string(
    'feature_type', 'sqr',
    'Nonlinearity function for feature. Can be relu, elu+1, sqr, favor+, or favor+{int}.'
)
flags.DEFINE_enum(
    'compute_type', 'iter', ['iter', 'ps', 'parallel_ps'],
    'Which type of method to compute: iter = iterative algorithm from Appendix B, ps = implementation using torch.cumsum, parallel_ps = implementation using custom log prefix sum implementation.'
)
flags.DEFINE_float('weight_decay', 0.0, 'Weight decay for regularization.')

# Training parameters
flags.DEFINE_integer('iters_count', 100000, 'Number of training iterations.')
# Only relevant when step_size != -1: selects whether the first half of
# training uses the full method before switching to the step_size method.
flags.DEFINE_bool(
    'finetune', True,
    'Only used when step_size != -1. If True, train with the full method for '
    'the first half of iters_count and with the step_size method for the '
    'second half. If False, use the step_size method for all iterations.')
flags.DEFINE_bool('on_gptln', True, 'Use layer norm after attention or before.')
flags.DEFINE_string('gpu_id', '0', 'ID of GPU.')
# Valid non-default values index the 15 configurations built in
# set_default_flags().
flags.DEFINE_integer(
    'arg_code', -1,
    'If -1, uses user-defined FLAGS. Else uses predetermined flag values to reproduce paper results.'
)
flags.DEFINE_integer('random_seed', 42, 'Random seed for both Numpy and Torch.')
# Metrics are logged only on iterations that are multiples of BOTH val_step
# and print_step (print_step is checked inside the val_step branch).
flags.DEFINE_integer('val_step', 5, 'Interval to predict validation metrics.')
flags.DEFINE_integer('print_step', 100, 'Interval to print metrics.')

# Dataset parameters
flags.DEFINE_enum('dataset', 'ptb', ['ptb', 'enwik8', 'copy'],
                  'Dataset to use.')
flags.DEFINE_integer('seq_len', 512, 'Maximum sequence length (L).')
# The ptb/enwik8 loaders produce raw uint8 bytes, hence 256 by default.
flags.DEFINE_integer('vocab_size', 256, 'Vocabulary size of data.')
72

73

74
def get_batch(data, batch_size, seq_len, index):
  """Returns the `index`-th (batch_size, seq_len) chunk of `data`.

  `data` is split into consecutive, non-overlapping chunks of
  batch_size * seq_len elements. `index` wraps around modulo the number of
  complete chunks, so any integer index is accepted. Trailing elements that do
  not fill a whole chunk are never returned.

  Args:
    data: 1-D array to draw the batch from.
    batch_size: number of rows of the returned batch.
    seq_len: number of columns of the returned batch.
    index: batch index, wrapped modulo the number of complete chunks.

  Returns:
    `data[start:start + batch_size * seq_len]` reshaped to
    (batch_size, seq_len).

  Raises:
    ValueError: if batch_size * seq_len is not positive, or `data` is too
      short to fill a single batch (previously a ZeroDivisionError).
  """
  elems_in_batch = batch_size * seq_len
  if elems_in_batch <= 0:
    raise ValueError(
        'batch_size * seq_len must be positive, got {0} * {1}.'.format(
            batch_size, seq_len))

  batches_count = len(data) // elems_in_batch
  if batches_count == 0:
    raise ValueError(
        'Data has {0} elements, fewer than one batch of {1}.'.format(
            len(data), elems_in_batch))

  batch_start = elems_in_batch * (index % batches_count)
  batch_end = batch_start + elems_in_batch

  batch = data[batch_start:batch_end]
  return batch.reshape(batch_size, seq_len)
86

87

88
def get_batch_copy(vocab_size, batch_size, seq_len):
  """Draws a random batch for the copy task.

  Every row is `[0, x_1, ..., x_{h-1}, 0, x_1, ..., x_{h-1}]` with
  h = seq_len // 2. The tokens x_i are drawn uniformly from [1, vocab_size - 1],
  so token 0 appears only as the leading separator of each half.

  Args:
    vocab_size: vocabulary size; content tokens lie in [1, vocab_size - 1].
    batch_size: number of rows.
    seq_len: requested length; the actual row length is 2 * (seq_len // 2).

  Returns:
    A pair (batch, batch_mask). `batch` is an integer array of shape
    (batch_size, 2 * (seq_len // 2)). `batch_mask` is a boolean array of the
    same shape that is False on the first half and True on the second half.
  """
  half = seq_len // 2

  content = np.random.choice(vocab_size - 1, size=(batch_size, half - 1)) + 1
  separator = np.zeros((batch_size, 1), dtype=int)
  first_half = np.hstack((separator, content))
  batch = np.tile(first_half, (1, 2))

  batch_mask = np.zeros((batch_size, 2 * half), dtype=bool)
  batch_mask[:, half:] = True

  return batch, batch_mask
102

103

104
def get_enwik8():
  """Loads enwik8 as raw bytes and splits it into train and validation sets.

  Download here: http://prize.hutter1.net/ . The data is read from the
  gzip-compressed file ./data/enwik8.gz, relative to the current working
  directory.

  Returns:
    A pair (train_data, val_data) of 1-D writable np.uint8 arrays: the first
    90e6 bytes, and the following bytes up to a total of 95e6.
  """
  with gzip.open('./data/enwik8.gz') as f:
    # np.frombuffer replaces the deprecated binary mode of np.fromstring. The
    # buffer is read-only, so copy it to give downstream code (e.g.
    # torch.from_numpy on batches) a writable array.
    data = np.frombuffer(f.read(int(95e6)), dtype=np.uint8).copy()

  train_data, val_data = np.split(data, [int(90e6)])

  return train_data, val_data
112

113

114
def get_ptb():
  """Loads character-level Penn Treebank as raw bytes.

  Download here: https://github.com/wojzaremba/lstm/tree/master/data and put
  ptb.train.txt and ptb.valid.txt into the ./data/ folder, relative to the
  current working directory.

  Returns:
    A pair (train_data, val_data) of 1-D writable np.uint8 arrays holding the
    exact bytes of each file.
  """

  def _read_bytes(path):
    # Read in binary mode. Text mode would decode with the locale-dependent
    # default encoding and translate newlines before np.fromstring (deprecated)
    # re-encoded the result. The copy makes the frombuffer array writable.
    with open(path, 'rb') as f:
      return np.frombuffer(f.read(), dtype=np.uint8).copy()

  train_data = _read_bytes('./data/ptb.train.txt')
  val_data = _read_bytes('./data/ptb.valid.txt')

  return train_data, val_data
123

124

125
def set_default_flags(arg_code):
  """Overrides FLAGS with one of the 15 configurations used in the paper.

  `arg_code` indexes the configurations in this order:
    0-4:   ptb,    (step_size, finetune) = (-1, F), (512, F), (256, F),
                                           (512, T), (256, T)
    5-9:   enwik8, (step_size, finetune) = (-1, F), (2048, F), (1366, F),
                                           (2048, T), (1366, T)
    10-14: copy,   (step_size, finetune) = (-1, F), (128, F), (64, F),
                                           (128, T), (64, T)
  All configurations use feature_type='sqr', compute_type='iter' and
  weight_decay=0.0.

  Args:
    arg_code: configuration index in [0, 15).

  Raises:
    ValueError: if `arg_code` is out of range.
  """
  obj_class = collections.namedtuple(
      'obj',
      'dataset seq_len batch_size learning_rate step_size hidden_dim n_layers ffn_dim n_heads feature_type compute_type weight_decay iters_count finetune on_gptln'
  )

  # Per-dataset settings, paired with the (step_size, finetune) runs for it.
  dataset_runs = [
      (dict(dataset='ptb', seq_len=1024, batch_size=1, learning_rate=1e-4,
            hidden_dim=512, n_layers=3, ffn_dim=2048, n_heads=8,
            iters_count=30000, on_gptln=True),
       [(-1, False), (512, False), (256, False), (512, True), (256, True)]),
      (dict(dataset='enwik8', seq_len=4096, batch_size=1, learning_rate=2e-4,
            hidden_dim=1024, n_layers=3, ffn_dim=4096, n_heads=16,
            iters_count=100000, on_gptln=True),
       [(-1, False), (2048, False), (1366, False), (2048, True),
        (1366, True)]),
      (dict(dataset='copy', seq_len=512, batch_size=1, learning_rate=1e-2,
            hidden_dim=256, n_layers=1, ffn_dim=1024, n_heads=4,
            iters_count=15000, on_gptln=False),
       [(-1, False), (128, False), (64, False), (128, True), (64, True)]),
  ]

  possible_flags = [
      obj_class(step_size=step_size, finetune=finetune, feature_type='sqr',
                compute_type='iter', weight_decay=0.0, **settings)
      for settings, runs in dataset_runs
      for step_size, finetune in runs
  ]

  # Reject negative codes explicitly: Python indexing would otherwise silently
  # pick a configuration counted from the end of the list.
  if not 0 <= arg_code < len(possible_flags):
    raise ValueError('arg_code must be in [0, {0}), got {1}.'.format(
        len(possible_flags), arg_code))

  chosen_flags = possible_flags[arg_code]

  # dict.iteritems() does not exist in Python 3; .items() works everywhere.
  for flag_name, flag_value in chosen_flags._asdict().items():
    setattr(FLAGS, flag_name, flag_value)
200

201

202
def _use_step_size_loss(train_index):
  """Returns whether iteration `train_index` trains with the step_size method.

  With --step_size=-1 the full method is always used. Otherwise the step_size
  method (model.loss_with_grad) is used on every iteration when --finetune is
  off, and only for the second half of --iters_count when it is on.
  """
  if FLAGS.step_size == -1:
    return False
  return train_index >= FLAGS.iters_count // 2 or not FLAGS.finetune


def main(_):
  """Trains a SLiMPerformer on --dataset and periodically logs metrics.

  For ptb/enwik8, training uses contiguous batches of the corpus, and a
  validation batch is evaluated every --val_step iterations. For the copy
  task, batches are generated on the fly and training accuracy is logged
  instead.
  """
  # Set before any CUDA work below (model.cuda()) so that it takes effect.
  os.environ['CUDA_VISIBLE_DEVICES'] = FLAGS.gpu_id

  if FLAGS.arg_code != -1:
    set_default_flags(FLAGS.arg_code)

  np.random.seed(FLAGS.random_seed)
  torch.manual_seed(FLAGS.random_seed)

  is_copy_task = FLAGS.dataset == 'copy'

  # The copy task has no corpus. Previously train_data/val_data were left
  # undefined for it, and the logging call below raised NameError.
  train_data, val_data = None, None
  if FLAGS.dataset == 'enwik8':
    train_data, val_data = get_enwik8()
  elif FLAGS.dataset == 'ptb':
    train_data, val_data = get_ptb()

  if not is_copy_task:
    logging.info('Data loaded: %d train chars, %d val chars', len(train_data),
                 len(val_data))

  model = slim_performer_model.SLiMPerformer(FLAGS.vocab_size, FLAGS.hidden_dim,
                                             FLAGS.n_layers, FLAGS.ffn_dim,
                                             FLAGS.n_heads, FLAGS.feature_type,
                                             FLAGS.compute_type,
                                             FLAGS.on_gptln).cuda()

  optimizer = torch.optim.Adam(
      model.parameters(),
      lr=FLAGS.learning_rate,
      weight_decay=FLAGS.weight_decay)

  # Only the copy task decays the learning rate (x0.1 at iteration 10000).
  scheduler = None
  if is_copy_task:
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[10000], gamma=0.1)

  training_start = time.time()

  for train_index in range(FLAGS.iters_count):
    # Validation below switches to eval mode, so restore train mode every
    # iteration. The copy task never calls eval(), so this is a no-op there.
    model.train()

    use_step_size = _use_step_size_loss(train_index)
    acc = None

    if is_copy_task:
      train_batch, mask = get_batch_copy(FLAGS.vocab_size, FLAGS.batch_size,
                                         FLAGS.seq_len)
      mask = torch.from_numpy(mask).cuda()
      train_batch = torch.from_numpy(train_batch).cuda().long()
      if use_step_size:
        train_loss, acc = model.loss_with_grad(
            train_batch, FLAGS.step_size, return_acc=True, nonpad_mask=mask)
      else:
        train_loss, acc = model.full_loss(
            train_batch, with_grad=True, return_acc=True, nonpad_mask=mask)
    else:
      train_batch = get_batch(train_data, FLAGS.batch_size, FLAGS.seq_len,
                              train_index)
      train_batch = torch.from_numpy(train_batch).cuda().long()
      if use_step_size:
        train_loss = model.loss_with_grad(train_batch, FLAGS.step_size)
      else:
        train_loss = model.full_loss(train_batch, with_grad=True)

    # NOTE(review): no explicit backward() is called before optimizer.step(),
    # so loss_with_grad / full_loss(with_grad=True) presumably backpropagate
    # internally. Confirm in slim_performer_model.
    optimizer.step()
    optimizer.zero_grad()
    if scheduler is not None:
      scheduler.step()

    # Dividing by ln(2) converts the loss from nats to bits.
    train_bpd = train_loss.item() / np.log(2)
    gb_in_use = torch.cuda.max_memory_allocated(0) / (1024 * 1024 * 1024)

    # Metrics are computed every val_step iterations and logged only when
    # that iteration is also a multiple of print_step.
    if (train_index + 1) % FLAGS.val_step != 0:
      continue

    if is_copy_task:
      seconds = time.time() - training_start

      if (train_index + 1) % FLAGS.print_step == 0:
        logging.info(
            'iter#{0} sec={1:.1f} bpd={2:.4f} acc={3:.4f} gb={4:.4f}'.format(
                train_index + 1, seconds, train_bpd, acc.item(), gb_in_use))
    else:
      model.eval()

      val_batch = get_batch(val_data, FLAGS.batch_size, FLAGS.seq_len,
                            train_index // FLAGS.val_step)
      val_batch = torch.from_numpy(val_batch).cuda().long()

      with torch.no_grad():
        val_loss = model.full_loss(val_batch, with_grad=False)

      val_bpd = val_loss.item() / np.log(2)

      seconds = time.time() - training_start

      if (train_index + 1) % FLAGS.print_step == 0:
        logging.info(
            'iter#{0} sec={1:.1f} t_bpd={2:.4f} v_bpd={3:.4f} gb={4:.4f}'
            .format(train_index + 1, seconds, train_bpd, val_bpd, gb_in_use))


if __name__ == '__main__':
  app.run(main)
311

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.