pytorch-lightning

Форк
0
341 строка · 11.6 Кб
1
from unittest import mock
2
from unittest.mock import Mock, call
3

4
import pytest
5
import torch
6
from lightning.fabric import Fabric
7
from lightning.fabric.plugins import Precision
8
from lightning.fabric.utilities.throughput import (
9
    Throughput,
10
    ThroughputMonitor,
11
    _MonotonicWindow,
12
    get_available_flops,
13
    measure_flops,
14
)
15

16
from tests_fabric.helpers.runif import RunIf
17
from tests_fabric.test_fabric import BoringModel
18

19

20
@RunIf(min_torch="2.1")
def test_measure_flops():
    """Check that `measure_flops` returns integer counts and that passing a loss function yields a larger count."""
    # Model and input are created under the `meta` device context.
    with torch.device("meta"):
        model = BoringModel()
        x = torch.randn(2, 32)

    def model_fwd():
        return model(x)

    def model_loss(y):
        return y.sum()

    # forward only
    fwd_flops = measure_flops(model, model_fwd)
    assert isinstance(fwd_flops, int)

    # forward + loss: the count must be strictly larger than forward only
    fwd_and_bwd_flops = measure_flops(model, model_fwd, model_loss)
    assert isinstance(fwd_and_bwd_flops, int)
    assert fwd_flops < fwd_and_bwd_flops
34

35

36
def test_get_available_flops(xla_available):
    """Exercise `get_available_flops` for known and unknown CUDA device names and for mocked TPU environments.

    The ``xla_available`` fixture is expected to make ``torch_xla`` importable as a mock; this is asserted below.

    """
    # known CUDA device
    with mock.patch("torch.cuda.get_device_name", return_value="NVIDIA H100 PCIe"):
        flops = get_available_flops(torch.device("cuda"), torch.bfloat16)
    assert flops == 756e12

    # unknown CUDA device: warns and returns None
    with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"):
        assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

    # known CUDA device, unsupported dtype: warns and returns None
    with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch(
        "torch.cuda.get_device_name", return_value="t4"
    ):
        assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

    from torch_xla.experimental import tpu

    assert isinstance(tpu, Mock)

    # `try/finally` so the mock is reset even when one of the assertions below fails
    try:
        tpu.get_tpu_env.return_value = {"TYPE": "V4"}
        flops = get_available_flops(torch.device("xla"), torch.bfloat16)
        assert flops == 275e12

        tpu.get_tpu_env.return_value = {"TYPE": "V1"}
        with pytest.warns(match="not found for TPU 'V1'"):
            assert get_available_flops(torch.device("xla"), torch.bfloat16) is None

        # the TPU type can also be reported under the "ACCELERATOR_TYPE" key
        tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
        flops = get_available_flops(torch.device("xla"), torch.bfloat16)
        assert flops == 123e12
    finally:
        tpu.reset_mock()
66

67

68
@pytest.mark.parametrize(
    "device_name",
    [
        # Hopper
        "h100-nvl",  # TODO: switch with `torch.cuda.get_device_name()` result
        "h100-hbm3",  # TODO: switch with `torch.cuda.get_device_name()` result
        "NVIDIA H100 PCIe",
        "h100-hbm2e",  # TODO: switch with `torch.cuda.get_device_name()` result
        # Ada
        "NVIDIA GeForce RTX 4090",
        "NVIDIA GeForce RTX 4080",
        "Tesla L40",
        "NVIDIA L4",
        # Ampere
        "NVIDIA A100 80GB PCIe",
        "NVIDIA A100-SXM4-40GB",
        "NVIDIA GeForce RTX 3090",
        "NVIDIA GeForce RTX 3090 Ti",
        "NVIDIA GeForce RTX 3080",
        "NVIDIA GeForce RTX 3080 Ti",
        "NVIDIA GeForce RTX 3070",
        # The entries below are expected to fail the `is not None` assertion (xfail).
        pytest.param("NVIDIA GeForce RTX 3070 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3060", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3060 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3050", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3050 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
        "NVIDIA A6000",
        "NVIDIA A40",
        "NVIDIA A10G",
        # Turing
        "NVIDIA GeForce RTX 2080 SUPER",
        "NVIDIA GeForce RTX 2080 Ti",
        "NVIDIA GeForce RTX 2080",
        "NVIDIA GeForce RTX 2070 Super",
        "Quadro RTX 5000 with Max-Q Design",
        "Tesla T4",
        "TITAN RTX",
        # Volta
        "Tesla V100-SXm2-32GB",
        "Tesla V100-PCIE-32GB",
        "Tesla V100S-PCIE-32GB",
    ],
)
@mock.patch("lightning.fabric.accelerators.cuda._is_ampere_or_later", return_value=False)
def test_get_available_flops_cuda_mapping_exists(_, device_name):
    """Tests `get_available_flops` against known device names.

    Each name is fed in through a patched ``torch.cuda.get_device_name`` and a float32 lookup must not return
    ``None``. ``_is_ampere_or_later`` is patched to return ``False`` (presumably so the result does not depend on the
    host GPU; confirm against the implementation). The first positional argument is the mock injected by
    ``mock.patch`` and is unused.

    """
    with mock.patch("lightning.fabric.utilities.throughput.torch.cuda.get_device_name", return_value=device_name):
        assert get_available_flops(device=torch.device("cuda"), dtype=torch.float32) is not None
116

117

118
def test_throughput():
    """Check `Throughput.update`/`compute` outputs and input validation for the supported argument combinations."""
    # required args only
    throughput = Throughput()
    throughput.update(time=2.0, batches=1, samples=2)
    assert throughput.compute() == {"time": 2.0, "batches": 1, "samples": 2}

    # different lengths and samples
    with pytest.raises(RuntimeError, match="same number of samples"):
        throughput.update(time=2.1, batches=2, samples=3, lengths=4)

    # lengths and samples
    throughput = Throughput(window_size=2)
    throughput.update(time=2, batches=1, samples=2, lengths=4)
    throughput.update(time=2.5, batches=2, samples=4, lengths=8)
    assert throughput.compute() == {
        "time": 2.5,
        "batches": 2,
        "samples": 4,
        "lengths": 8,
        "device/batches_per_sec": 2.0,
        "device/samples_per_sec": 4.0,
        "device/items_per_sec": 8.0,
    }

    # an update whose time does not increase is rejected
    with pytest.raises(ValueError, match="Expected the value to increase"):
        throughput.update(time=2.5, batches=3, samples=2, lengths=4)

    # flops
    throughput = Throughput(available_flops=50, window_size=2)
    throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
    throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
    assert throughput.compute() == {
        "time": 2,
        "batches": 2,
        "samples": 4,
        "lengths": 20,
        "device/batches_per_sec": 1.0,
        "device/flops_per_sec": 10.0,
        "device/items_per_sec": 10.0,
        "device/mfu": 0.2,
        "device/samples_per_sec": 2.0,
    }

    # flops without available: same metrics as above, minus "device/mfu"
    throughput.available_flops = None
    throughput.reset()
    throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
    throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
    assert throughput.compute() == {
        "time": 2,
        "batches": 2,
        "samples": 4,
        "lengths": 20,
        "device/batches_per_sec": 1.0,
        "device/flops_per_sec": 10.0,
        "device/items_per_sec": 10.0,
        "device/samples_per_sec": 2.0,
    }

    # samples must be >= batches
    throughput = Throughput(window_size=2)
    with pytest.raises(ValueError, match=r"samples.*to be greater or equal than batches"):
        throughput.update(time=0, batches=2, samples=1)

    # lengths must be >= samples
    throughput = Throughput(window_size=2)
    with pytest.raises(ValueError, match=r"lengths.*to be greater or equal than samples"):
        throughput.update(time=0, batches=2, samples=2, lengths=1)
183

184

185
def mock_train_loop(monitor):
    """Drive ``monitor`` through five fake training iterations (lit-gpt style loop).

    On every iteration the cumulative totals are passed to ``monitor.update`` and then ``monitor.compute_and_log()``
    is called.

    Args:
        monitor: Object exposing ``update(time=, batches=, samples=, lengths=, flops=)`` and ``compute_and_log()``.

    """
    total_lengths = 0
    total_t0 = 0.0  # fake times
    micro_batch_size = 3
    items_per_sample = 2  # presumably the per-sample length; lengths grow by 3 * 2 = 6 per iteration
    for iter_num in range(1, 6):
        # forward + backward + step + zero_grad ...
        t1 = iter_num + 0.5
        total_lengths += micro_batch_size * items_per_sample
        monitor.update(
            time=t1 - total_t0,
            batches=iter_num,
            samples=iter_num * micro_batch_size,
            lengths=total_lengths,
            flops=10,
        )
        monitor.compute_and_log()
202

203

204
def test_throughput_monitor():
    """Run the mock train loop through a `ThroughputMonitor` and check every `log_metrics` call on the logger.

    With ``window_size=4`` the first three calls carry only the raw totals; rate metrics appear from the fourth call
    on. The custom ``separator="|"`` shows up in the metric keys.

    """
    logger_mock = Mock()
    fabric = Fabric(devices=1, loggers=logger_mock)
    # the available FLOPs are patched to a fixed value for the monitor's construction
    with mock.patch("lightning.fabric.utilities.throughput.get_available_flops", return_value=100):
        monitor = ThroughputMonitor(fabric, window_size=4, separator="|")
    mock_train_loop(monitor)
    assert logger_mock.log_metrics.mock_calls == [
        call(metrics={"time": 1.5, "batches": 1, "samples": 3, "lengths": 6}, step=0),
        call(metrics={"time": 2.5, "batches": 2, "samples": 6, "lengths": 12}, step=1),
        call(metrics={"time": 3.5, "batches": 3, "samples": 9, "lengths": 18}, step=2),
        call(
            metrics={
                "time": 4.5,
                "batches": 4,
                "samples": 12,
                "lengths": 24,
                "device|batches_per_sec": 1.0,
                "device|samples_per_sec": 3.0,
                "device|items_per_sec": 6.0,
                "device|flops_per_sec": 10.0,
                "device|mfu": 0.1,
            },
            step=3,
        ),
        call(
            metrics={
                "time": 5.5,
                "batches": 5,
                "samples": 15,
                "lengths": 30,
                "device|batches_per_sec": 1.0,
                "device|samples_per_sec": 3.0,
                "device|items_per_sec": 6.0,
                "device|flops_per_sec": 10.0,
                "device|mfu": 0.1,
            },
            step=4,
        ),
    ]
243

244

245
def test_throughput_monitor_step():
    """Check that `ThroughputMonitor.step` auto-increments and can be overridden via ``compute_and_log(step=...)``."""
    fabric_mock = Mock()
    fabric_mock.world_size = 1
    fabric_mock.strategy.precision = Precision()
    monitor = ThroughputMonitor(fabric_mock)

    # automatic step increase
    assert monitor.step == -1
    monitor.update(time=0.5, batches=1, samples=3)
    metrics = monitor.compute_and_log()
    assert metrics == {"time": 0.5, "batches": 1, "samples": 3}
    assert monitor.step == 0

    # manual step
    monitor.update(time=1.5, batches=2, samples=4)
    metrics = monitor.compute_and_log(step=5)
    assert metrics == {"time": 1.5, "batches": 2, "samples": 4}
    assert monitor.step == 5

    # both calls were forwarded to `fabric.log_dict` with the matching step
    assert fabric_mock.log_dict.mock_calls == [
        call(metrics={"time": 0.5, "batches": 1, "samples": 3}, step=0),
        call(metrics={"time": 1.5, "batches": 2, "samples": 4}, step=5),
    ]
267

268

269
def test_throughput_monitor_world_size():
    """Check the logged metrics when the monitor's ``world_size`` is forced to 2.

    Besides the ``device/``-prefixed metrics, the expected calls contain unprefixed global rates that are twice the
    per-device values.

    """
    logger_mock = Mock()
    fabric = Fabric(devices=1, loggers=logger_mock)
    with mock.patch("lightning.fabric.utilities.throughput.get_available_flops", return_value=100):
        monitor = ThroughputMonitor(fabric, window_size=4)
        # simulate that there are 2 devices
        monitor.world_size = 2
    mock_train_loop(monitor)
    assert logger_mock.log_metrics.mock_calls == [
        call(metrics={"time": 1.5, "batches": 1, "samples": 3, "lengths": 6}, step=0),
        call(metrics={"time": 2.5, "batches": 2, "samples": 6, "lengths": 12}, step=1),
        call(metrics={"time": 3.5, "batches": 3, "samples": 9, "lengths": 18}, step=2),
        call(
            metrics={
                "time": 4.5,
                "batches": 4,
                "samples": 12,
                "lengths": 24,
                "device/batches_per_sec": 1.0,
                "device/samples_per_sec": 3.0,
                "batches_per_sec": 2.0,
                "samples_per_sec": 6.0,
                "items_per_sec": 12.0,
                "device/items_per_sec": 6.0,
                "flops_per_sec": 20.0,
                "device/flops_per_sec": 10.0,
                "device/mfu": 0.1,
            },
            step=3,
        ),
        call(
            metrics={
                "time": 5.5,
                "batches": 5,
                "samples": 15,
                "lengths": 30,
                "device/batches_per_sec": 1.0,
                "device/samples_per_sec": 3.0,
                "batches_per_sec": 2.0,
                "samples_per_sec": 6.0,
                "items_per_sec": 12.0,
                "device/items_per_sec": 6.0,
                "flops_per_sec": 20.0,
                "device/flops_per_sec": 10.0,
                "device/mfu": 0.1,
            },
            step=4,
        ),
    ]
318

319

320
def test_monotonic_window():
    """Check `_MonotonicWindow`: list-like reads, rejected item assignment, and increasing-only appends."""
    w = _MonotonicWindow(maxlen=3)
    assert w == []
    assert len(w) == 0

    w.append(1)
    w.append(2)
    w.append(3)
    assert w == [1, 2, 3]
    assert len(w) == 3
    assert w[1] == 2
    assert w[-2:] == [2, 3]

    # item and slice assignment are not supported
    with pytest.raises(NotImplementedError):
        w[1] = 123
    with pytest.raises(NotImplementedError):
        w[1:2] = [1, 2]

    # appending a value that does not increase is rejected
    with pytest.raises(ValueError, match="Expected the value to increase"):
        w.append(2)

    # after clearing, the same value is accepted again
    w.clear()
    w.append(2)
    assert w == [2]
342

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.