# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common code from different mains."""

import jax.numpy as jnp
import numpy as np


STEPS_PER_EPOCH = 4500


def create_learning_rate_scheduler(
    factors='constant * linear_warmup * rsqrt_decay',
    base_learning_rate=0.5,
    warmup_steps=1000,
    decay_factor=0.5,
    steps_per_decay=20000,
    steps_per_cycle=100000,
    init_step=0,
    finetune_lr=False):
  """Creates learning rate schedule.

  Interprets factors in the factors string which can consist of:
  * constant: interpreted as the constant value,
  * linear_warmup: interpreted as linear warmup until warmup_steps,
  * rsqrt_decay: divide by square root of max(step, warmup_steps),
  * rsqrt_normalized_decay: divide by square root of max(step/warmup_steps, 1),
  * decay_every: every k steps decay the learning rate by decay_factor,
  * cosine_decay: cyclic cosine decay, uses steps_per_cycle parameter.

  Args:
    factors: string, factors separated by "*" that defines the schedule.
    base_learning_rate: float, the starting constant for the lr schedule.
    warmup_steps: int, how many steps to warm up for in the warmup schedule.
    decay_factor: float, the amount to decay the learning rate by.
    steps_per_decay: int, how often to decay the learning rate.
    steps_per_cycle: int, steps per cycle when using cosine decay.
    init_step: int, first step of this run. Used with finetune_lr.
    finetune_lr: bool, modify step count for finetuning smaller datasets.

  Returns:
    a function learning_rate(step): int -> float32 scalar, the step-dependent
    learning rate.
  """
  factors = [n.strip() for n in factors.split('*')]

  def step_fn(step):
    """Step to learning rate function."""
    ret = 1.0
    if finetune_lr:
      # Rescale the step count so that finetuning on a smaller dataset
      # traverses the schedule at the same per-epoch rate.
      steps_this_run = step - init_step
      multiplier = STEPS_PER_EPOCH / steps_per_cycle
      finetune_steps = steps_this_run * multiplier
      step = init_step + finetune_steps

    for name in factors:
      if name == 'constant':
        ret *= base_learning_rate
      elif name == 'linear_warmup':
        ret *= jnp.minimum(1.0, step / warmup_steps)
      elif name == 'rsqrt_decay':
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'rsqrt_normalized_decay':
        ret *= jnp.sqrt(warmup_steps)
        ret /= jnp.sqrt(jnp.maximum(step, warmup_steps))
      elif name == 'decay_every':
        ret *= (decay_factor**(step // steps_per_decay))
      elif name == 'cosine_decay':
        progress = jnp.maximum(0.0,
                               (step - warmup_steps) / float(steps_per_cycle))
        ret *= jnp.maximum(0.0,
                           0.5 * (1.0 + jnp.cos(jnp.pi * (progress % 1.0))))
      else:
        raise ValueError('Unknown factor %s.' % name)
    return jnp.asarray(ret, dtype=jnp.float32)

  return step_fn
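

# Illustrative usage sketch (not part of the original module): evaluates the
# returned step_fn at a few steps. The factor string and hyperparameters below
# are assumed values for demonstration, not settings taken from any training
# config in this repository.
def _example_schedule_usage():
  """Illustrative only: builds a warmup + rsqrt schedule and samples it."""
  schedule = create_learning_rate_scheduler(
      factors='constant * linear_warmup * rsqrt_decay',
      base_learning_rate=0.5,
      warmup_steps=1000)
  # The learning rate ramps linearly during warmup, then decays as 1/sqrt(step).
  return [float(schedule(step)) for step in (1, 500, 1000, 10000)]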


def pad_examples(x, desired_batch_size):
  """Expand batch to desired size by repeating last slice."""
  batch_pad = desired_batch_size - x.shape[0]
  return np.concatenate([x, np.tile(x[-1], (batch_pad, 1))], axis=0)
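

# Illustrative usage sketch (an assumption, not original code): pads a partial
# final evaluation batch of 2-D examples up to the fixed batch size a
# device-sharded computation expects. The shapes here are made up.
def _example_pad_usage():
  """Illustrative only: pads a (3, 4) batch up to batch size 8."""
  batch = np.arange(12, dtype=np.float32).reshape(3, 4)
  padded = pad_examples(batch, desired_batch_size=8)
  # The last row is repeated five times, so padded.shape == (8, 4).
  return padded.shape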


def tohost(x):
  """Collect batches from all devices to host and flatten batch dimensions."""
  n_device, n_batch, *remaining_dims = x.shape
  return np.array(x).reshape((n_device * n_batch,) + tuple(remaining_dims))
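

# Illustrative usage sketch (an assumption, not original code): collapses the
# leading device and per-device batch dimensions of a sharded array, such as
# the output of a pmapped computation, into a single host-side batch dimension.
def _example_tohost_usage():
  """Illustrative only: flattens a (2, 4, 3) sharded array to shape (8, 3)."""
  sharded = np.zeros((2, 4, 3), dtype=np.float32)
  flat = tohost(sharded)
  return flat.shape  # (8, 3)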