# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common code from different mains."""

import jax.numpy as jnp
import numpy as np


STEPS_PER_EPOCH = 4500


def create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=1000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000,
    init_step=0,
    finetune_lr=False):
  """Creates a learning rate schedule.

  Interprets factors in the factors string which can consist of:
  * constant: interpreted as the constant value,
  * linear_warmup: interpreted as linear warmup until warmup_steps,
  * rsqrt_decay: divide by square root of max(step, warmup_steps),
  * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1),
  * decay_every: every k steps decay the learning rate by decay_factor,
  * cosine_decay: cyclic cosine decay, uses steps_per_cycle parameter.

  Args:
    factors: string, factors separated by "*" that define the schedule.
    base_learning_rate: float, the starting constant for the lr schedule.
    warmup_steps: int, how many steps to warm up for in the warmup schedule.
    decay_factor: float, the amount to decay the learning rate by.
    steps_per_decay: int, how often to decay the learning rate.
    steps_per_cycle: int, steps per cycle when using cosine decay.
    init_step: int, first step of this run. Used with finetune_lr.
    finetune_lr: bool, rescale the step count when fine-tuning on smaller
      datasets.

  Returns:
    a function learning_rate(step): int -> float, the step-dependent learning
    rate.
  """
  factors = [n.strip() for n in factors.split('*')]

  def step_fn(step):
    """Step to learning rate function."""
    ret = 1.0
    if finetune_lr:
      # When fine-tuning, rescale the steps taken in this run by
      # STEPS_PER_EPOCH / steps_per_cycle before applying the schedule.
      steps_this_run = step - init_step
      multiplier = STEPS_PER_EPOCH / steps_per_cycle
      finetune_steps = steps_this_run * multiplier
      step = init_step + finetune_steps

    for name in factors:
      if name == 'constant':
        ret *= base_learning_rate
      elif name == 'linear_warmup':
        ret *= jnp.minimum(1.0, step / warmup_steps)
      elif name == 'rsqrt_decay':
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'rsqrt_normalized_decay':
        ret *= jnp.sqrt(warmup_steps)
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'decay_every':
        ret *= (decay_factor**(step // steps_per_decay))
      elif name == 'cosine_decay':
        progress = jnp.maximum(0.0,
                               (step - warmup_steps) / float(steps_per_cycle))
        ret *= jnp.maximum(0.0,
                           0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
      else:
        raise ValueError('Unknown factor %s.' % name)
    return jnp.asarray(ret, dtype=jnp.float32)

  return step_fn
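

# Illustrative usage sketch (not part of the original module): the default
# factor string composes a constant base rate, a linear warmup, and an
# inverse-square-root decay.  The returned function maps a step number to a
# scalar learning rate.
def _example_schedule_usage():
  """Hypothetical helper showing how a schedule is built and queried."""
  lr_fn = create_learning_rate_scheduler(
      factors='constant * linear_warmup * rsqrt_decay',
      base_learning_rate=0.5,
      warmup_steps=1000)
  # During warmup the rate ramps linearly; afterwards it decays as 1/sqrt(step).
  warmup_lr = lr_fn(100)    # 0.5 * (100 / 1000) / sqrt(1000) ~= 1.6e-3
  peak_lr = lr_fn(1000)     # 0.5 / sqrt(1000)                ~= 1.6e-2
  decayed_lr = lr_fn(4000)  # 0.5 / sqrt(4000)                ~= 7.9e-3
  return warmup_lr, peak_lr, decayed_lr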


def pad_examples(x, desired_batch_size):
  """Expand batch to desired size by repeating last slice."""
  # Assumes x is 2-D (batch, features): the last row is repeated until the
  # batch reaches desired_batch_size.
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)
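

# Illustrative sketch (not part of the original module): pad a 2-D batch of
# four examples up to a batch size of six by repeating the last row, e.g. so
# that a final partial evaluation batch can be sharded evenly across devices.
def _example_pad_examples():
  """Hypothetical helper showing pad_examples on a 2-D batch."""
  x = np.arange(12, dtype=np.float32).reshape(4, 3)  # batch of 4 examples
  padded = pad_examples(x, desired_batch_size=6)     # shape (6, 3)
  # The two padding rows are copies of x[-1].
  return padded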


def tohost(x):
  """Collect batches from all devices to host and flatten batch dimensions."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))
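

# Illustrative sketch (not part of the original module): output gathered from
# a pmapped computation typically carries a leading (n_device, per_device_batch)
# prefix; tohost folds it into a single batch dimension on the host.
def _example_tohost():
  """Hypothetical helper showing how a device-sharded array is flattened."""
  sharded = np.zeros((8, 16, 32))  # 8 devices x 16 examples x 32 features
  flat = tohost(sharded)           # shape (128, 32)
  return flat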