pytorch-lightning
341 lines · 11.6 KB
1from unittest import mock2from unittest.mock import Mock, call3
4import pytest5import torch6from lightning.fabric import Fabric7from lightning.fabric.plugins import Precision8from lightning.fabric.utilities.throughput import (9Throughput,10ThroughputMonitor,11_MonotonicWindow,12get_available_flops,13measure_flops,14)
15
16from tests_fabric.helpers.runif import RunIf17from tests_fabric.test_fabric import BoringModel18
19
@RunIf(min_torch="2.1")
def test_measure_flops():
    """Forward-only FLOPs are an int and strictly smaller than forward+backward FLOPs."""
    # instantiate on the meta device so no real memory/compute is needed
    with torch.device("meta"):
        model = BoringModel()
        x = torch.randn(2, 32)

        def run_forward():
            return model(x)

        def compute_loss(y):
            return y.sum()

    fwd_only = measure_flops(model, run_forward)
    assert isinstance(fwd_only, int)

    fwd_and_bwd = measure_flops(model, run_forward, compute_loss)
    assert isinstance(fwd_and_bwd, int)
    # the backward pass adds work on top of the forward pass
    assert fwd_only < fwd_and_bwd
35
def test_get_available_flops(xla_available):
    """Exercise the CUDA and XLA lookup paths of ``get_available_flops``."""
    # known CUDA device name: returns the table's peak bfloat16 FLOPs for an H100 PCIe
    with mock.patch("torch.cuda.get_device_name", return_value="NVIDIA H100 PCIe"):
        flops = get_available_flops(torch.device("cuda"), torch.bfloat16)
        assert flops == 756e12

    # unknown CUDA device name: warns and returns None
    with pytest.warns(match="not found for 'CocoNut"), mock.patch("torch.cuda.get_device_name", return_value="CocoNut"):
        assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

    # known device that does not support the requested dtype: warns and returns None
    with pytest.warns(match="t4' does not support torch.bfloat"), mock.patch(
        "torch.cuda.get_device_name", return_value="t4"
    ):
        assert get_available_flops(torch.device("cuda"), torch.bfloat16) is None

    from torch_xla.experimental import tpu

    # the `xla_available` fixture stubs out `torch_xla`, so `tpu` is a configurable Mock
    assert isinstance(tpu, Mock)

    # TPU type reported through the "TYPE" env key
    tpu.get_tpu_env.return_value = {"TYPE": "V4"}
    flops = get_available_flops(torch.device("xla"), torch.bfloat16)
    assert flops == 275e12

    # unknown TPU type: warns and returns None
    tpu.get_tpu_env.return_value = {"TYPE": "V1"}
    with pytest.warns(match="not found for TPU 'V1'"):
        assert get_available_flops(torch.device("xla"), torch.bfloat16) is None

    # TPU type reported through the "ACCELERATOR_TYPE" env key instead (e.g. "v3-8")
    tpu.get_tpu_env.return_value = {"ACCELERATOR_TYPE": "v3-8"}
    flops = get_available_flops(torch.device("xla"), torch.bfloat16)
    assert flops == 123e12

    # the stubbed module is shared state across tests; undo the configuration above
    tpu.reset_mock()
67
@pytest.mark.parametrize(
    "device_name",
    [
        # Hopper
        "h100-nvl",  # TODO: switch with `torch.cuda.get_device_name()` result
        "h100-hbm3",  # TODO: switch with `torch.cuda.get_device_name()` result
        "NVIDIA H100 PCIe",
        "h100-hbm2e",  # TODO: switch with `torch.cuda.get_device_name()` result
        # Ada
        "NVIDIA GeForce RTX 4090",
        "NVIDIA GeForce RTX 4080",
        "Tesla L40",
        "NVIDIA L4",
        # Ampere
        "NVIDIA A100 80GB PCIe",
        "NVIDIA A100-SXM4-40GB",
        "NVIDIA GeForce RTX 3090",
        "NVIDIA GeForce RTX 3090 Ti",
        "NVIDIA GeForce RTX 3080",
        "NVIDIA GeForce RTX 3080 Ti",
        "NVIDIA GeForce RTX 3070",
        # the xfail'ed names below are expected to be absent from the FLOPs mapping,
        # i.e. `get_available_flops` presumably returns None for them — the inner assert fails
        pytest.param("NVIDIA GeForce RTX 3070 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3060", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3060 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3050", marks=pytest.mark.xfail(raises=AssertionError)),
        pytest.param("NVIDIA GeForce RTX 3050 Ti", marks=pytest.mark.xfail(raises=AssertionError)),
        "NVIDIA A6000",
        "NVIDIA A40",
        "NVIDIA A10G",
        # Turing
        "NVIDIA GeForce RTX 2080 SUPER",
        "NVIDIA GeForce RTX 2080 Ti",
        "NVIDIA GeForce RTX 2080",
        "NVIDIA GeForce RTX 2070 Super",
        "Quadro RTX 5000 with Max-Q Design",
        "Tesla T4",
        "TITAN RTX",
        # Volta
        "Tesla V100-SXm2-32GB",
        "Tesla V100-PCIE-32GB",
        "Tesla V100S-PCIE-32GB",
    ],
)
@mock.patch("lightning.fabric.accelerators.cuda._is_ampere_or_later", return_value=False)
def test_get_available_flops_cuda_mapping_exists(_, device_name):
    """Tests `get_available_flops` against known device names."""
    # NOTE(review): `_is_ampere_or_later` is forced to False — presumably to keep the
    # float32 lookup off any capability-dependent (e.g. TF32) path; confirm in the impl
    with mock.patch("lightning.fabric.utilities.throughput.torch.cuda.get_device_name", return_value=device_name):
        assert get_available_flops(device=torch.device("cuda"), dtype=torch.float32) is not None
117
def test_throughput():
    """Unit-test the ``Throughput`` accumulator with hand-picked numbers."""
    # required args only: only the raw counters are reported before a window fills
    throughput = Throughput()
    throughput.update(time=2.0, batches=1, samples=2)
    assert throughput.compute() == {"time": 2.0, "batches": 1, "samples": 2}

    # different lengths and samples: mixing updates with and without `lengths`
    # is rejected (error message mentions "same number of samples")
    with pytest.raises(RuntimeError, match="same number of samples"):
        throughput.update(time=2.1, batches=2, samples=3, lengths=4)

    # lengths and samples
    throughput = Throughput(window_size=2)
    throughput.update(time=2, batches=1, samples=2, lengths=4)
    throughput.update(time=2.5, batches=2, samples=4, lengths=8)
    # rates over the full 2-entry window, e.g. (2 - 1) batches / (2.5 - 2.0) s = 2.0
    assert throughput.compute() == {
        "time": 2.5,
        "batches": 2,
        "samples": 4,
        "lengths": 8,
        "device/batches_per_sec": 2.0,
        "device/samples_per_sec": 4.0,
        "device/items_per_sec": 8.0,
    }

    # tracked counters must be monotonically increasing
    with pytest.raises(ValueError, match="Expected the value to increase"):
        throughput.update(time=2.5, batches=3, samples=2, lengths=4)

    # flops: with `available_flops` set, "device/mfu" = flops_per_sec / available = 10 / 50
    throughput = Throughput(available_flops=50, window_size=2)
    throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
    throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
    assert throughput.compute() == {
        "time": 2,
        "batches": 2,
        "samples": 4,
        "lengths": 20,
        "device/batches_per_sec": 1.0,
        "device/flops_per_sec": 10.0,
        "device/items_per_sec": 10.0,
        "device/mfu": 0.2,
        "device/samples_per_sec": 2.0,
    }

    # flops without available: same metrics, but no "device/mfu" entry
    throughput.available_flops = None
    throughput.reset()
    throughput.update(time=1, batches=1, samples=2, flops=10, lengths=10)
    throughput.update(time=2, batches=2, samples=4, flops=10, lengths=20)
    assert throughput.compute() == {
        "time": 2,
        "batches": 2,
        "samples": 4,
        "lengths": 20,
        "device/batches_per_sec": 1.0,
        "device/flops_per_sec": 10.0,
        "device/items_per_sec": 10.0,
        "device/samples_per_sec": 2.0,
    }

    # sanity validation of counter magnitudes: samples >= batches, lengths >= samples
    throughput = Throughput(window_size=2)
    with pytest.raises(ValueError, match=r"samples.*to be greater or equal than batches"):
        throughput.update(time=0, batches=2, samples=1)
    throughput = Throughput(window_size=2)
    with pytest.raises(ValueError, match=r"lengths.*to be greater or equal than samples"):
        throughput.update(time=0, batches=2, samples=2, lengths=1)
184
def mock_train_loop(monitor):
    """Drive ``monitor`` through five fake training iterations (lit-gpt style loop)."""
    micro_batch_size = 3
    lengths_per_iter = 3 * 2  # tokens contributed per iteration
    start_time = 0.0  # fake clock origin
    running_lengths = 0
    for step in range(1, 6):
        # pretend forward + backward + step + zero_grad happened here
        now = step + 0.5
        running_lengths += lengths_per_iter
        monitor.update(
            time=now - start_time,
            batches=step,
            samples=step * micro_batch_size,
            lengths=running_lengths,
            flops=10,
        )
        monitor.compute_and_log()
203
def test_throughput_monitor():
    """End-to-end: computed metrics are forwarded to the Fabric loggers with a custom separator."""
    logger_mock = Mock()
    fabric = Fabric(devices=1, loggers=logger_mock)
    # pin available flops so "device|mfu" is deterministic: 10 flops/sec / 100 = 0.1
    with mock.patch("lightning.fabric.utilities.throughput.get_available_flops", return_value=100):
        monitor = ThroughputMonitor(fabric, window_size=4, separator="|")
    mock_train_loop(monitor)
    # the first 3 logs carry only raw counters; rates appear once the 4-entry window fills
    assert logger_mock.log_metrics.mock_calls == [
        call(metrics={"time": 1.5, "batches": 1, "samples": 3, "lengths": 6}, step=0),
        call(metrics={"time": 2.5, "batches": 2, "samples": 6, "lengths": 12}, step=1),
        call(metrics={"time": 3.5, "batches": 3, "samples": 9, "lengths": 18}, step=2),
        call(
            metrics={
                "time": 4.5,
                "batches": 4,
                "samples": 12,
                "lengths": 24,
                "device|batches_per_sec": 1.0,
                "device|samples_per_sec": 3.0,
                "device|items_per_sec": 6.0,
                "device|flops_per_sec": 10.0,
                "device|mfu": 0.1,
            },
            step=3,
        ),
        call(
            metrics={
                "time": 5.5,
                "batches": 5,
                "samples": 15,
                "lengths": 30,
                "device|batches_per_sec": 1.0,
                "device|samples_per_sec": 3.0,
                "device|items_per_sec": 6.0,
                "device|flops_per_sec": 10.0,
                "device|mfu": 0.1,
            },
            step=4,
        ),
    ]
244
def test_throughput_monitor_step():
    """The step counter auto-increments unless an explicit step is passed to ``compute_and_log``."""
    fake_fabric = Mock()
    fake_fabric.world_size = 1
    fake_fabric.strategy.precision = Precision()
    monitor = ThroughputMonitor(fake_fabric)

    # before any logging the counter sits at -1; a log call bumps it to 0
    assert monitor.step == -1
    monitor.update(time=0.5, batches=1, samples=3)
    first_metrics = monitor.compute_and_log()
    assert first_metrics == {"time": 0.5, "batches": 1, "samples": 3}
    assert monitor.step == 0

    # an explicit step overrides the automatic counter
    monitor.update(time=1.5, batches=2, samples=4)
    second_metrics = monitor.compute_and_log(step=5)
    assert second_metrics == {"time": 1.5, "batches": 2, "samples": 4}
    assert monitor.step == 5

    # both logs were forwarded to the fabric with the steps used above
    expected_log_calls = [
        call(metrics={"time": 0.5, "batches": 1, "samples": 3}, step=0),
        call(metrics={"time": 1.5, "batches": 2, "samples": 4}, step=5),
    ]
    assert fake_fabric.log_dict.mock_calls == expected_log_calls
268
def test_throughput_monitor_world_size():
    """With world_size > 1, global (un-prefixed) rates are logged as device rates times world size."""
    logger_mock = Mock()
    fabric = Fabric(devices=1, loggers=logger_mock)
    # pin available flops so "device/mfu" is deterministic: 10 flops/sec / 100 = 0.1
    with mock.patch("lightning.fabric.utilities.throughput.get_available_flops", return_value=100):
        monitor = ThroughputMonitor(fabric, window_size=4)
    # simulate that there are 2 devices
    monitor.world_size = 2
    mock_train_loop(monitor)
    # rates appear once the 4-entry window fills; e.g. batches_per_sec = 2 * device/batches_per_sec
    assert logger_mock.log_metrics.mock_calls == [
        call(metrics={"time": 1.5, "batches": 1, "samples": 3, "lengths": 6}, step=0),
        call(metrics={"time": 2.5, "batches": 2, "samples": 6, "lengths": 12}, step=1),
        call(metrics={"time": 3.5, "batches": 3, "samples": 9, "lengths": 18}, step=2),
        call(
            metrics={
                "time": 4.5,
                "batches": 4,
                "samples": 12,
                "lengths": 24,
                "device/batches_per_sec": 1.0,
                "device/samples_per_sec": 3.0,
                "batches_per_sec": 2.0,
                "samples_per_sec": 6.0,
                "items_per_sec": 12.0,
                "device/items_per_sec": 6.0,
                "flops_per_sec": 20.0,
                "device/flops_per_sec": 10.0,
                "device/mfu": 0.1,
            },
            step=3,
        ),
        call(
            metrics={
                "time": 5.5,
                "batches": 5,
                "samples": 15,
                "lengths": 30,
                "device/batches_per_sec": 1.0,
                "device/samples_per_sec": 3.0,
                "batches_per_sec": 2.0,
                "samples_per_sec": 6.0,
                "items_per_sec": 12.0,
                "device/items_per_sec": 6.0,
                "flops_per_sec": 20.0,
                "device/flops_per_sec": 10.0,
                "device/mfu": 0.1,
            },
            step=4,
        ),
    ]
319
def test_monotonic_window():
    """``_MonotonicWindow`` acts like a bounded list that only accepts increasing values."""
    window = _MonotonicWindow(maxlen=3)
    assert window == []
    assert len(window) == 0

    # regular list reads work: equality, length, indexing, slicing
    for value in (1, 2, 3):
        window.append(value)
    assert window == [1, 2, 3]
    assert len(window) == 3
    assert window[1] == 2
    assert window[-2:] == [2, 3]

    # item assignment is disallowed, both single-index and slice forms
    with pytest.raises(NotImplementedError):
        window[1] = 123
    with pytest.raises(NotImplementedError):
        window[1:2] = [1, 2]

    # a non-increasing append raises, but clearing resets the constraint
    with pytest.raises(ValueError, match="Expected the value to increase"):
        window.append(2)
    window.clear()
    window.append(2)