# coding=utf-8
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
import tempfile
import unittest

from transformers import is_torch_available
from transformers.testing_utils import require_torch


if is_torch_available():
    import torch
    from torch import nn

    from transformers import (
        Adafactor,
        AdamW,
        get_constant_schedule,
        get_constant_schedule_with_warmup,
        get_cosine_schedule_with_warmup,
        get_cosine_with_hard_restarts_schedule_with_warmup,
        get_inverse_sqrt_schedule,
        get_linear_schedule_with_warmup,
        get_polynomial_decay_schedule_with_warmup,
    )


def unwrap_schedule(scheduler, num_steps=10):
    lrs = []
    for _ in range(num_steps):
        lrs.append(scheduler.get_lr()[0])
        scheduler.step()
    return lrs
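# Illustrative usage of the helper above (a sketch, not executed by the tests; the
# names model/opt/sched are hypothetical): with AdamW at lr=1.0 and a
# constant-with-warmup schedule over 2 warmup steps, the recorded trace ramps up
# and then stays flat, e.g.
#
#     model = nn.Linear(2, 2)
#     opt = AdamW(model.parameters(), lr=1.0)
#     sched = get_constant_schedule_with_warmup(opt, num_warmup_steps=2)
#     unwrap_schedule(sched, num_steps=4)  # -> [0.0, 0.5, 1.0, 1.0]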


def unwrap_and_save_reload_schedule(scheduler, num_steps=10):
    lrs = []
    for step in range(num_steps):
        lrs.append(scheduler.get_lr()[0])
        scheduler.step()
        if step == num_steps // 2:
            with tempfile.TemporaryDirectory() as tmpdirname:
                file_name = os.path.join(tmpdirname, "schedule.bin")
                torch.save(scheduler.state_dict(), file_name)

                state_dict = torch.load(file_name)
                scheduler.load_state_dict(state_dict)
    return lrs
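# A minimal sketch of what the round trip above checks (illustrative only, using
# hypothetical scheduler names): a schedule whose state survives save/reload should
# produce the same learning-rate trace with or without the mid-run reload, e.g.
#
#     lrs_plain = unwrap_schedule(scheduler_a, num_steps=10)
#     lrs_reloaded = unwrap_and_save_reload_schedule(scheduler_b, num_steps=10)
#     assert lrs_plain == lrs_reloaded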


@require_torch
class OptimizationTest(unittest.TestCase):
    def assertListAlmostEqual(self, list1, list2, tol):
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol)

    def test_adam_w(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        criterion = nn.MSELoss()
        # No warmup, constant schedule, no gradient clipping
        optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
        for _ in range(100):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            w.grad.detach_()  # No zero_grad() function on simple tensors. We do it ourselves.
            w.grad.zero_()
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)

    def test_adafactor(self):
        w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
        target = torch.tensor([0.4, 0.2, -0.5])
        criterion = nn.MSELoss()
        # No warmup, constant schedule, no gradient clipping
        optimizer = Adafactor(
            params=[w],
            lr=1e-2,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.0,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )
        for _ in range(1000):
            loss = criterion(w, target)
            loss.backward()
            optimizer.step()
            w.grad.detach_()  # No zero_grad() function on simple tensors. We do it ourselves.
            w.grad.zero_()
        self.assertListAlmostEqual(w.tolist(), [0.4, 0.2, -0.5], tol=1e-2)


@require_torch
class ScheduleInitTest(unittest.TestCase):
    m = nn.Linear(50, 50) if is_torch_available() else None
    optimizer = AdamW(m.parameters(), lr=10.0) if is_torch_available() else None
    num_steps = 10

    def assertListAlmostEqual(self, list1, list2, tol, msg=None):
        self.assertEqual(len(list1), len(list2))
        for a, b in zip(list1, list2):
            self.assertAlmostEqual(a, b, delta=tol, msg=msg)

    def test_schedulers(self):
        common_kwargs = {"num_warmup_steps": 2, "num_training_steps": 10}
        # scheduler dict format
        # function: (sched_args_dict, expected_learning_rates)
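        # How a row of expected values comes about (worked example, for orientation
        # only): with base lr 10.0, 2 warmup steps and 10 training steps, the linear
        # schedule ramps 0.0 -> 5.0 -> 10.0 over warmup, then decays linearly, so
        # step 3 gives 10.0 * (10 - 3) / (10 - 2) = 8.75, step 4 gives 7.5, and so on.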
        scheds = {
            get_constant_schedule: ({}, [10.0] * self.num_steps),
            get_constant_schedule_with_warmup: (
                {"num_warmup_steps": 4},
                [0.0, 2.5, 5.0, 7.5, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0],
            ),
            get_linear_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 8.75, 7.5, 6.25, 5.0, 3.75, 2.5, 1.25],
            ),
            get_cosine_schedule_with_warmup: (
                {**common_kwargs},
                [0.0, 5.0, 10.0, 9.61, 8.53, 6.91, 5.0, 3.08, 1.46, 0.38],
            ),
            get_cosine_with_hard_restarts_schedule_with_warmup: (
                {**common_kwargs, "num_cycles": 2},
                [0.0, 5.0, 10.0, 8.53, 5.0, 1.46, 10.0, 8.53, 5.0, 1.46],
            ),
            get_polynomial_decay_schedule_with_warmup: (
                {**common_kwargs, "power": 2.0, "lr_end": 1e-7},
                [0.0, 5.0, 10.0, 7.656, 5.625, 3.906, 2.5, 1.406, 0.625, 0.156],
            ),
            get_inverse_sqrt_schedule: (
                {"num_warmup_steps": 2},
                [0.0, 5.0, 10.0, 8.165, 7.071, 6.325, 5.774, 5.345, 5.0, 4.714],
            ),
        }

        for scheduler_func, data in scheds.items():
            kwargs, expected_learning_rates = data

            scheduler = scheduler_func(self.optimizer, **kwargs)
            self.assertEqual(len([scheduler.get_lr()[0]]), 1)
            lrs_1 = unwrap_schedule(scheduler, self.num_steps)
            self.assertListAlmostEqual(
                lrs_1,
                expected_learning_rates,
                tol=1e-2,
                msg=f"failed for {scheduler_func} in normal scheduler",
            )

            scheduler = scheduler_func(self.optimizer, **kwargs)
            if scheduler_func.__name__ != "get_constant_schedule":
                LambdaScheduleWrapper.wrap_scheduler(scheduler)  # wrap to test picklability of the schedule
            lrs_2 = unwrap_and_save_reload_schedule(scheduler, self.num_steps)
            self.assertListEqual(lrs_1, lrs_2, msg=f"failed for {scheduler_func} in save and reload")


class LambdaScheduleWrapper:
    """See https://github.com/huggingface/transformers/issues/21689"""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    @classmethod
    def wrap_scheduler(cls, scheduler):
        scheduler.lr_lambdas = list(map(cls, scheduler.lr_lambdas))
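

# A brief note on why the wrapper matters (a sketch under stated assumptions, not
# part of the test run): torch's LambdaLR.state_dict() only serializes lr_lambdas
# that are callable objects, not plain functions or lambdas, so wrapping each
# lr_lambda in LambdaScheduleWrapper makes the schedule's state picklable and lets
# unwrap_and_save_reload_schedule exercise that path, e.g.
#
#     scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=2, num_training_steps=10)
#     LambdaScheduleWrapper.wrap_scheduler(scheduler)
#     torch.save(scheduler.state_dict(), "schedule.bin")  # wrapped lambdas are saved too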