# pytorch-lightning
# 241 lines · 8.9 KB
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import logging
15import os
16from argparse import Namespace
17from unittest import mock
18from unittest.mock import Mock
19
20import numpy as np
21import pytest
22import torch
23from lightning.fabric.loggers import TensorBoardLogger
24from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
25from lightning.fabric.wrappers import _FabricModule
26
27from tests_fabric.test_fabric import BoringModel
28
29
def test_tensorboard_automatic_versioning(tmp_path):
    """Verify that automatic versioning works."""
    # Pre-create two numeric version folders plus distractors that must be ignored.
    root_dir = tmp_path / "tb_versioning"
    root_dir.mkdir()
    for folder in ("version_0", "version_1", "version_nonumber", "other"):
        (root_dir / folder).mkdir()

    logger = TensorBoardLogger(root_dir=tmp_path, name="tb_versioning")
    # Next free version after 0 and 1; non-numeric suffixes do not count.
    assert logger.version == 2
41
42
def test_tensorboard_manual_versioning(tmp_path):
    """Verify that manual versioning works."""
    root_dir = tmp_path / "tb_versioning"
    root_dir.mkdir()
    for folder in ("version_0", "version_1", "version_2"):
        (root_dir / folder).mkdir()

    # An explicitly requested version wins over the auto-detected next one.
    logger = TensorBoardLogger(root_dir=tmp_path, name="tb_versioning", version=1)
    assert logger.version == 1
53
54
def test_tensorboard_named_version(tmp_path):
    """Verify that manual versioning works for string versions, e.g. '2020-02-05-162402'."""
    name = "tb_versioning"
    (tmp_path / name).mkdir()
    expected_version = "2020-02-05-162402"

    logger = TensorBoardLogger(root_dir=tmp_path, name=name, version=expected_version)
    logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})  # Force data to be written

    assert logger.version == expected_version
    # The string version becomes the only sub-directory, and it is non-empty.
    assert os.listdir(tmp_path / name) == [expected_version]
    assert os.listdir(tmp_path / name / expected_version)
67
68
@pytest.mark.parametrize("name", ["", None])
def test_tensorboard_no_name(tmp_path, name):
    """Verify that None or empty name works."""
    logger = TensorBoardLogger(root_dir=tmp_path, name=name)
    logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})  # Force data to be written
    # Without a name, logs land directly under root_dir (normpath handles a trailing /).
    assert os.path.normpath(logger.root_dir) == str(tmp_path)
    assert os.listdir(tmp_path / "version_0")
76
77
def test_tensorboard_log_sub_dir(tmp_path):
    """Verify ``log_dir`` composition with and without an explicit ``sub_dir``."""
    root_dir = tmp_path / "logs"

    # Without sub_dir the log dir is root/name/version.
    logger = TensorBoardLogger(root_dir, name="name", version="version")
    assert logger.log_dir == os.path.join(root_dir, "name", "version")

    # With sub_dir it is appended as the final path component.
    logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
    assert logger.log_dir == os.path.join(root_dir, "name", "version", "sub_dir")
87
88
def test_tensorboard_expand_home():
    """Test that the home dir (`~`) gets expanded properly."""
    root_dir = "~/tmp"
    logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
    # root_dir is stored verbatim; only log_dir resolves the user's home directory.
    assert logger.root_dir == root_dir
    assert logger.log_dir == os.path.join(os.path.expanduser(root_dir), "name", "version", "sub_dir")
96
97
@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"})
def test_tensorboard_expand_env_vars():
    """Test that the env vars in path names (`$`) get handled properly."""
    logger = TensorBoardLogger("$TEST_ENV_DIR/tmp", name="name", version="version", sub_dir="sub_dir")
    # The env var must be substituted before the path components are joined.
    expanded_root = f"{os.environ['TEST_ENV_DIR']}/tmp"
    assert logger.log_dir == os.path.join(expanded_root, "name", "version", "sub_dir")
106
107
@pytest.mark.parametrize("step_idx", [10, None])
def test_tensorboard_log_metrics(tmp_path, step_idx):
    """Log a mix of Python scalars and tensors, with and without a step index."""
    logger = TensorBoardLogger(tmp_path)
    logger.log_metrics(
        {"float": 0.3, "int": 1, "FloatTensor": torch.tensor(0.1), "IntTensor": torch.tensor(1)},
        step_idx,
    )
113
114
def test_tensorboard_log_hyperparams(tmp_path):
    """Hyperparameters of many different types can be logged without error."""
    logger = TensorBoardLogger(tmp_path)
    logger.log_hyperparams({
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {"a": {"b": "c"}},
        "list": [1, 2, 3],
        "namespace": Namespace(foo=Namespace(bar="buzz")),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.empty(2, 2, 2),
        "array": np.empty([2, 2, 2]),
    })
130
131
def test_tensorboard_log_hparams_and_metrics(tmp_path):
    """Hyperparameters can be logged together with an initial metrics dict."""
    # default_hp_metric=False so the provided metrics are used instead of the placeholder.
    logger = TensorBoardLogger(tmp_path, default_hp_metric=False)
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": {"a": {"b": "c"}},
        "list": [1, 2, 3],
        "namespace": Namespace(foo=Namespace(bar="buzz")),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.empty(2, 2, 2),
        "array": np.empty([2, 2, 2]),
    }
    logger.log_hyperparams(hparams, {"abc": torch.tensor([0.54])})
148
149
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
def test_tensorboard_log_graph(tmp_path, example_input_array):
    """Test that log graph works with both model.example_input_array and if array is passed externally."""
    # TODO(fabric): Test both nn.Module and LightningModule
    # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
    net = BoringModel()
    if example_input_array is not None:
        net.example_input_array = None  # the array passed to log_graph must be the only source

    tb_logger = TensorBoardLogger(tmp_path)
    tb_logger._experiment = Mock()

    # plain module
    tb_logger.log_graph(net, example_input_array)
    if example_input_array is not None:
        tb_logger.experiment.add_graph.assert_called_with(net, example_input_array)
    tb_logger._experiment.reset_mock()

    # module wrapped in `FabricModule`: the graph is recorded for the original module
    fabric_module = _FabricModule(net, strategy=Mock())
    tb_logger.log_graph(fabric_module, example_input_array)
    if example_input_array is not None:
        tb_logger.experiment.add_graph.assert_called_with(net, example_input_array)
171
172
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path):
    """Test that log graph throws warning if model.example_input_array is None."""
    model = BoringModel()
    model.example_input_array = None
    logger = TensorBoardLogger(tmp_path, log_graph=True)

    # No example input available at all -> warn and skip graph logging.
    missing_input_pattern = (
        "Could not log computational graph to TensorBoard: The `model.example_input_array` .* was not given"
    )
    with pytest.warns(UserWarning, match=missing_input_pattern):
        logger.log_graph(model)

    # An input of a type that cannot be traced -> warn as well.
    model.example_input_array = {"x": 1, "y": 2}
    untraceable_pattern = "Could not log computational graph to TensorBoard: .* can't be traced by TensorBoard"
    with pytest.warns(UserWarning, match=untraceable_pattern):
        logger.log_graph(model)
190
191
def test_tensorboard_finalize(monkeypatch, tmp_path):
    """Test that the SummaryWriter closes in finalize."""
    # Patch whichever SummaryWriter implementation the logger will pick up.
    if _TENSORBOARD_AVAILABLE:
        import torch.utils.tensorboard as tb
    else:
        import tensorboardX as tb

    monkeypatch.setattr(tb, "SummaryWriter", Mock())

    # Finalize before anything was logged: no experiment gets created, nothing to flush.
    logger = TensorBoardLogger(root_dir=tmp_path)
    assert logger._experiment is None
    logger.finalize("any")
    logger.experiment.assert_not_called()

    # Logging a metric instantiates the experiment; finalize must flush and close it.
    logger = TensorBoardLogger(root_dir=tmp_path)
    logger.log_metrics({"flush_me": 11.1})
    logger.finalize("any")
    logger.experiment.flush.assert_called()
    logger.experiment.close.assert_called()
214
215
@mock.patch("lightning.fabric.loggers.tensorboard.log")
def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
    """Tests a specific failure case when tensorboard logger is used with empty name, symbolic link ``save_dir``, and
    relative paths."""
    monkeypatch.chdir(tmp_path)  # relative paths are required to trigger the issue
    source = os.path.join(".", "lightning_logs")
    dest = os.path.join(".", "sym_lightning_logs")
    os.makedirs(source, exist_ok=True)
    os.symlink(source, dest)

    logger = TensorBoardLogger(root_dir=dest, name="")
    _ = logger.version  # resolving the version must not emit any warning

    log.warning.assert_not_called()
231
232
def test_tensorboard_missing_folder_warning(tmp_path, caplog):
    """Verify that the logger throws a warning for invalid directory."""
    logger = TensorBoardLogger(root_dir=tmp_path, name="fake_dir")

    with caplog.at_level(logging.WARNING):
        # The folder does not exist yet, so the version falls back to 0 and a warning is logged.
        assert logger.version == 0

    assert "Missing logger folder:" in caplog.text
242