pytorch-lightning

Форк
0
241 строка · 8.9 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import logging
15
import os
16
from argparse import Namespace
17
from unittest import mock
18
from unittest.mock import Mock
19

20
import numpy as np
21
import pytest
22
import torch
23
from lightning.fabric.loggers import TensorBoardLogger
24
from lightning.fabric.loggers.tensorboard import _TENSORBOARD_AVAILABLE
25
from lightning.fabric.wrappers import _FabricModule
26

27
from tests_fabric.test_fabric import BoringModel
28

29

30
def test_tensorboard_automatic_versioning(tmp_path):
    """Check that the logger reports version 2 when ``version_0`` and ``version_1`` folders already exist."""
    experiment_name = "tb_versioning"
    experiment_dir = tmp_path / experiment_name
    experiment_dir.mkdir()
    # Two numbered version folders plus two folders whose names carry no version number.
    for entry in ("version_0", "version_1", "version_nonumber", "other"):
        (experiment_dir / entry).mkdir()

    tb_logger = TensorBoardLogger(root_dir=tmp_path, name=experiment_name)
    assert tb_logger.version == 2
41

42

43
def test_tensorboard_manual_versioning(tmp_path):
    """Check that an explicitly requested version is reported even though other version folders exist."""
    experiment_dir = tmp_path / "tb_versioning"
    experiment_dir.mkdir()
    # Pre-create version_0, version_1 and version_2.
    for index in range(3):
        (experiment_dir / f"version_{index}").mkdir()

    tb_logger = TensorBoardLogger(root_dir=tmp_path, name="tb_versioning", version=1)
    assert tb_logger.version == 1
53

54

55
def test_tensorboard_named_version(tmp_path):
    """Check that a string version, e.g. '2020-02-05-162402', is kept as-is and used as the folder name."""
    name = "tb_versioning"
    experiment_dir = tmp_path / name
    experiment_dir.mkdir()
    expected_version = "2020-02-05-162402"

    tb_logger = TensorBoardLogger(root_dir=tmp_path, name=name, version=expected_version)
    # Logging hyperparameters forces data to be written.
    tb_logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})

    assert tb_logger.version == expected_version
    # The experiment folder holds exactly one entry (the named version), and that entry is not empty.
    assert os.listdir(experiment_dir) == [expected_version]
    assert os.listdir(experiment_dir / expected_version)
67

68

69
@pytest.mark.parametrize("name", ["", None])
def test_tensorboard_no_name(tmp_path, name):
    """Check that an empty or ``None`` name puts ``version_0`` directly under the root directory."""
    tb_logger = TensorBoardLogger(root_dir=tmp_path, name=name)
    # Logging hyperparameters forces data to be written.
    tb_logger.log_hyperparams({"a": 1, "b": 2, 123: 3, 3.5: 4, 5j: 5})

    # normpath takes care of a possible trailing separator before comparing
    normalized_root = os.path.normpath(tb_logger.root_dir)
    assert normalized_root == str(tmp_path)
    assert os.listdir(tmp_path / "version_0")
76

77

78
def test_tensorboard_log_sub_dir(tmp_path):
    """Check how ``log_dir`` is composed with and without a ``sub_dir`` argument."""
    root_dir = tmp_path / "logs"
    base_parts = (root_dir, "name", "version")

    # without sub_dir: root / name / version
    plain_logger = TensorBoardLogger(root_dir, name="name", version="version")
    assert plain_logger.log_dir == os.path.join(*base_parts)

    # with sub_dir: the sub directory is appended after the version
    nested_logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")
    assert nested_logger.log_dir == os.path.join(*base_parts, "sub_dir")
87

88

89
def test_tensorboard_expand_home():
    """Check that ``~`` is expanded in ``log_dir`` while ``root_dir`` keeps the value that was passed in."""
    root_dir = "~/tmp"
    expected_log_dir = os.path.join(os.path.expanduser(root_dir), "name", "version", "sub_dir")

    tb_logger = TensorBoardLogger(root_dir, name="name", version="version", sub_dir="sub_dir")

    assert tb_logger.root_dir == root_dir
    assert tb_logger.log_dir == expected_log_dir
96

97

98
@mock.patch.dict(os.environ, {"TEST_ENV_DIR": "some_directory"})
def test_tensorboard_expand_env_vars():
    """Check that environment variables (``$NAME``) in the root directory are expanded in ``log_dir``."""
    env_value = os.environ["TEST_ENV_DIR"]
    expected_log_dir = os.path.join(f"{env_value}/tmp", "name", "version", "sub_dir")

    tb_logger = TensorBoardLogger("$TEST_ENV_DIR/tmp", name="name", version="version", sub_dir="sub_dir")

    assert tb_logger.log_dir == expected_log_dir
106

107

108
@pytest.mark.parametrize("step_idx", [10, None])
def test_tensorboard_log_metrics(tmp_path, step_idx):
    """Smoke test: log Python scalars and tensors as metrics, with and without an explicit step."""
    tb_logger = TensorBoardLogger(tmp_path)
    metrics = {
        "float": 0.3,
        "int": 1,
        "FloatTensor": torch.tensor(0.1),
        "IntTensor": torch.tensor(1),
    }
    # No assertion follows: the test only checks that the call completes without raising.
    tb_logger.log_metrics(metrics, step_idx)
113

114

115
def test_tensorboard_log_hyperparams(tmp_path):
    """Smoke test: log hyperparameters covering scalars, containers, a namespace, a class, a tensor and an array."""
    nested_dict = {"a": {"b": "c"}}
    nested_namespace = Namespace(foo=Namespace(bar="buzz"))
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": nested_dict,
        "list": [1, 2, 3],
        "namespace": nested_namespace,
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.empty(2, 2, 2),
        "array": np.empty([2, 2, 2]),
    }

    tb_logger = TensorBoardLogger(tmp_path)
    # No assertion follows: the test only checks that the call completes without raising.
    tb_logger.log_hyperparams(hparams)
130

131

132
def test_tensorboard_log_hparams_and_metrics(tmp_path):
    """Smoke test: log hyperparameters together with a metrics dict, with ``default_hp_metric`` disabled."""
    tb_logger = TensorBoardLogger(tmp_path, default_hp_metric=False)

    nested_dict = {"a": {"b": "c"}}
    nested_namespace = Namespace(foo=Namespace(bar="buzz"))
    hparams = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "dict": nested_dict,
        "list": [1, 2, 3],
        "namespace": nested_namespace,
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.empty(2, 2, 2),
        "array": np.empty([2, 2, 2]),
    }
    metrics = {"abc": torch.tensor([0.54])}

    # No assertion follows: the test only checks that the call completes without raising.
    tb_logger.log_hyperparams(hparams, metrics)
148

149

150
@pytest.mark.parametrize("example_input_array", [None, torch.rand(2, 32)])
def test_tensorboard_log_graph(tmp_path, example_input_array):
    """Exercise ``log_graph`` with an externally passed input array and without one, for plain and wrapped models."""
    # TODO(fabric): Test both nn.Module and LightningModule
    # TODO(fabric): Assert _apply_batch_transfer_handler is calling the batch transfer hooks
    model = BoringModel()
    input_given = example_input_array is not None
    if input_given:
        # Clear the model's own attribute so only the explicitly passed array is available.
        model.example_input_array = None

    tb_logger = TensorBoardLogger(tmp_path)
    tb_logger._experiment = Mock()  # stand-in for the real experiment object

    # plain model
    tb_logger.log_graph(model, example_input_array)
    if input_given:
        tb_logger.experiment.add_graph.assert_called_with(model, example_input_array)
    tb_logger._experiment.reset_mock()

    # model wrapped in `_FabricModule`: `add_graph` is still expected to receive the inner model
    fabric_wrapped = _FabricModule(model, strategy=Mock())
    tb_logger.log_graph(fabric_wrapped, example_input_array)
    if input_given:
        tb_logger.experiment.add_graph.assert_called_with(model, example_input_array)
171

172

173
@pytest.mark.skipif(not _TENSORBOARD_AVAILABLE, reason=str(_TENSORBOARD_AVAILABLE))
def test_tensorboard_log_graph_warning_no_example_input_array(tmp_path):
    """Check the warnings emitted by ``log_graph`` when the input array is missing or cannot be traced."""
    missing_input_pattern = (
        "Could not log computational graph to TensorBoard: The `model.example_input_array` .* was not given"
    )
    untraceable_input_pattern = "Could not log computational graph to TensorBoard: .* can't be traced by TensorBoard"

    model = BoringModel()
    tb_logger = TensorBoardLogger(tmp_path, log_graph=True)

    # case 1: no example input at all
    model.example_input_array = None
    with pytest.warns(UserWarning, match=missing_input_pattern):
        tb_logger.log_graph(model)

    # case 2: a dict example input
    model.example_input_array = {"x": 1, "y": 2}
    with pytest.warns(UserWarning, match=untraceable_input_pattern):
        tb_logger.log_graph(model)
190

191

192
def test_tensorboard_finalize(monkeypatch, tmp_path):
    """Check that ``finalize`` flushes and closes the experiment once something has been logged."""
    # Pick the module that provides `SummaryWriter`, depending on what is installed.
    if not _TENSORBOARD_AVAILABLE:
        import tensorboardX as tb
    else:
        import torch.utils.tensorboard as tb

    # Replace the writer class with a mock so calls on it can be inspected.
    monkeypatch.setattr(tb, "SummaryWriter", Mock())

    # Scenario 1: finalize without any prior log call.
    tb_logger = TensorBoardLogger(root_dir=tmp_path)
    assert tb_logger._experiment is None
    tb_logger.finalize("any")
    # no log calls, no experiment created -> nothing to flush
    tb_logger.experiment.assert_not_called()

    # Scenario 2: finalize after a metric was logged.
    tb_logger = TensorBoardLogger(root_dir=tmp_path)
    tb_logger.log_metrics({"flush_me": 11.1})  # trigger creation of an experiment
    tb_logger.finalize("any")
    # finalize flushes to experiment directory and closes the writer
    experiment = tb_logger.experiment
    experiment.flush.assert_called()
    experiment.close.assert_called()
214

215

216
@mock.patch("lightning.fabric.loggers.tensorboard.log")
def test_tensorboard_with_symlink(log, tmp_path, monkeypatch):
    """Regression test: empty name, a symbolic link as root directory and relative paths must not trigger a
    warning when the version is resolved."""
    monkeypatch.chdir(tmp_path)  # the scenario requires relative paths
    real_dir = os.path.join(".", "lightning_logs")
    link_dir = os.path.join(".", "sym_lightning_logs")

    os.makedirs(real_dir, exist_ok=True)
    os.symlink(real_dir, link_dir)

    tb_logger = TensorBoardLogger(root_dir=link_dir, name="")
    _ = tb_logger.version  # accessed only for its side effects; the value itself is not checked

    # `log` is the patched module-level logger of the tensorboard module.
    log.warning.assert_not_called()
231

232

233
def test_tensorboard_missing_folder_warning(tmp_path, caplog):
    """Check that resolving the version for a logger folder that was never created logs a warning."""
    tb_logger = TensorBoardLogger(root_dir=tmp_path, name="fake_dir")

    # Capture records at WARNING level while the version is being resolved.
    with caplog.at_level(logging.WARNING):
        resolved_version = tb_logger.version

    assert resolved_version == 0
    assert "Missing logger folder:" in caplog.text
242

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.