pytorch-lightning

Форк
0
169 строк · 5.2 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from argparse import Namespace
16
from dataclasses import dataclass
17

18
import numpy as np
19
import torch
20
from lightning.fabric.utilities.logger import (
21
    _add_prefix,
22
    _convert_params,
23
    _flatten_dict,
24
    _sanitize_callable_params,
25
    _sanitize_params,
26
)
27

28

29
def test_convert_params():
    """Test conversion of params to a dict.

    Three inputs are covered: a plain dict (returned with the same contents), ``None`` (converted to an
    empty dict) and an ``argparse.Namespace`` (converted to a dict of its attributes).

    """
    # Test normal dict, make sure it is unchanged
    params = {"string": "string", "int": 1, "float": 0.1, "bool": True, "none": None}
    expected = params.copy()
    assert _convert_params(params) == expected

    # Test None conversion
    assert _convert_params(None) == {}

    # Test conversion of argparse Namespace.
    # The expected dict is spelled out instead of using ``vars(params)``: ``vars`` returns the namespace's
    # own ``__dict__``, so if the conversion handed back (or mutated) that same object the comparison
    # would be trivially true and could not detect a regression.
    params = Namespace(string="string", int=1, float=0.1, bool=True, none=None)
    expected = {"string": "string", "int": 1, "float": 0.1, "bool": True, "none": None}
    assert _convert_params(params) == expected
43

44

45
def test_flatten_dict():
    """Check that ``_flatten_dict`` collapses nested dicts, ``Namespace`` objects and dataclasses.

    Nested keys are joined with the delimiter (``"/"`` when none is passed), non-string keys are
    stringified, and the top-level key of a nested container does not survive in the flat result.

    """
    # A custom delimiter joins the parent and child keys
    flat = _flatten_dict({"a": {"b": "c"}}, "--")
    assert "a" not in flat
    assert flat["a--b"] == "c"

    # Deeper nesting with integer keys; leaf values such as lists are kept as they are
    nested = {
        "a": {5: {"foo": "bar"}},
        "b": 6,
        "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}},
    }
    flat = _flatten_dict(nested)
    assert "a" not in flat
    expected_items = {
        "a/5/foo": "bar",
        "b": 6,
        "c/7": [1, 2, 3, 4],
        "c/8": "foo",
        "c/9/10": "bar",
    }
    for key, value in expected_items.items():
        assert flat[key] == value

    # A Namespace wrapped in a dict is expanded attribute by attribute
    flat = _flatten_dict({"params": Namespace(a=1, b=2)})
    flat_type = type(flat)  # indirection avoids Ruff's `isinstance` suggestion
    assert flat_type is dict
    for key, value in (("params/a", 1), ("params/b", 2)):
        assert flat[key] == value
    for bare_key in ("a", "b"):
        assert bare_key not in flat

    # Dataclass instances, including nested ones, are expanded field by field
    @dataclass
    class A:
        c: int
        d: int

    @dataclass
    class B:
        a: A
        b: int

    flat = _flatten_dict({"params": B(a=A(c=1, d=2), b=3), "param": 4})
    assert flat == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2}
91

92

93
def test_sanitize_callable_params():
    """Callable params are not serializable, so they are sanitized before logging.

    Each callable is given a chance to return something loggable. A callable that returns a plain value
    is replaced by that value, while a callable whose return value is itself callable is replaced by its
    own ``__name__`` (``"<lambda>"`` for a lambda). Non-callable values are left untouched.

    """

    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    params = Namespace(
        foo="bar",
        something=return_something,
        wrapper_something_wo_name=(lambda: lambda: "1"),
        wrapper_something=wrapper_something,
    )

    params = _convert_params(params)
    params = _flatten_dict(params)
    params = _sanitize_callable_params(params)
    # non-callable value: untouched
    assert params["foo"] == "bar"
    # callable returning a plain value: replaced by that return value
    assert params["something"] == "something"
    # callables returning another callable: replaced by the outer callable's name
    assert params["wrapper_something"] == "wrapper_something"
    assert params["wrapper_something_wo_name"] == "<lambda>"
120

121

122
def test_sanitize_params():
    """Verify sanitize params converts various types to loggable values.

    Python primitives and tensors are kept as they are, numpy scalars are converted to the matching
    builtin Python scalar type, and other objects (lists, namespaces, classes) become their string form.

    """
    params = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "list": [1, 2, 3],
        "np_bool": np.bool_(False),
        "np_int": np.int_(5),
        "np_double": np.double(3.14159),
        "namespace": Namespace(foo=3),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.ones(3),
    }
    params = _sanitize_params(params)

    assert params["float"] == 0.3
    assert params["int"] == 1
    assert params["string"] == "abc"
    assert params["bool"] is True
    assert params["list"] == "[1, 2, 3]"

    # numpy scalars compare equal to Python numbers, so a value check alone would pass even if no
    # conversion happened; the exact types are checked as well. `isinstance` would not suffice here
    # because `np.float64` is a subclass of `float`.
    assert params["np_bool"] is False
    np_int_type = type(params["np_int"])
    assert np_int_type is int
    assert params["np_int"] == 5
    np_double_type = type(params["np_double"])
    assert np_double_type is float
    assert params["np_double"] == 3.14159

    assert params["namespace"] == "Namespace(foo=3)"
    assert params["layer"] == "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>"
    assert torch.equal(params["tensor"], torch.ones(3))
150

151

152
def test_add_prefix():
    """Check that ``_add_prefix`` prepends ``prefix + separator`` to every key and that calls stack."""
    names = ("metric1", "metric2")

    # First pass: every key gains the "prefix-" prefix and the bare keys disappear
    metrics = _add_prefix({"metric1": 1, "metric2": 2}, "prefix", "-")
    for name in names:
        assert f"prefix-{name}" in metrics
    for name in names:
        assert name not in metrics

    # Second pass: the new prefix is placed in front of the existing one, values are preserved
    metrics = _add_prefix(metrics, "prefix2", "_")
    for name in names:
        assert f"prefix2_prefix-{name}" in metrics
    for name in names:
        assert f"prefix-{name}" not in metrics
    for name, value in zip(names, (1, 2)):
        assert metrics[f"prefix2_prefix-{name}"] == value
170

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.