pytorch-lightning
169 строк · 5.2 Кб
1# Copyright The Lightning AI team.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from argparse import Namespace
16from dataclasses import dataclass
17
18import numpy as np
19import torch
20from lightning.fabric.utilities.logger import (
21_add_prefix,
22_convert_params,
23_flatten_dict,
24_sanitize_callable_params,
25_sanitize_params,
26)
27
28
def test_convert_params():
    """Check that ``_convert_params`` yields the expected dictionary for each supported input."""
    # A regular dict is expected to compare equal to a copy taken before the call.
    plain = {"string": "string", "int": 1, "float": 0.1, "bool": True, "none": None}
    snapshot = dict(plain)
    assert _convert_params(plain) == snapshot

    # ``None`` is expected to become an empty dict.
    assert _convert_params(None) == {}

    # An argparse Namespace is expected to become its attribute dictionary.
    namespace = Namespace(string="string", int=1, float=0.1, bool=True, none=None)
    assert _convert_params(namespace) == vars(namespace)
43
44
def test_flatten_dict():
    """Exercise ``_flatten_dict`` on nested dicts, an argparse Namespace and dataclass instances."""
    # One level of nesting, joined with an explicitly passed delimiter.
    flat = _flatten_dict({"a": {"b": "c"}}, "--")

    assert "a" not in flat
    assert flat["a--b"] == "c"

    # Deeper nesting with integer keys; no delimiter is passed and "/" is expected in the keys.
    nested = {"a": {5: {"foo": "bar"}}, "b": 6, "c": {7: [1, 2, 3, 4], 8: "foo", 9: {10: "bar"}}}
    flat = _flatten_dict(nested)

    assert "a" not in flat
    expected_entries = (
        ("a/5/foo", "bar"),
        ("b", 6),
        ("c/7", [1, 2, 3, 4]),
        ("c/8", "foo"),
        ("c/9/10", "bar"),
    )
    for key, value in expected_entries:
        assert flat[key] == value

    # A Namespace stored as a dict value is expanded under its parent key.
    flat = _flatten_dict({"params": Namespace(a=1, b=2)})

    flat_type = type(flat)  # bound to a name first to sidestep Ruff's `isinstance` suggestion
    assert flat_type is dict
    assert flat["params/a"] == 1
    assert flat["params/b"] == 2
    for bare_key in ("a", "b"):
        assert bare_key not in flat

    # Nested dataclass instances are expanded field by field.
    @dataclass
    class Inner:
        c: int
        d: int

    @dataclass
    class Outer:
        a: Inner
        b: int

    flat = _flatten_dict({"params": Outer(a=Inner(c=1, d=2), b=3), "param": 4})
    assert flat == {"param": 4, "params/b": 3, "params/a/c": 1, "params/a/d": 2}
91
92
def test_sanitize_callable_params():
    """Check how callable hyperparameters are replaced by loggable values.

    The assertions below expect a callable that returns a plain value to be replaced by that value, and a
    callable that returns another callable to be replaced by a name string (``"<lambda>"`` for a lambda).

    """

    def return_something():
        return "something"

    def wrapper_something():
        return return_something

    namespace = Namespace(
        foo="bar",
        something=return_something,
        wrapper_something_wo_name=(lambda: lambda: "1"),
        wrapper_something=wrapper_something,
    )

    # Same pipeline as before, written as one nested expression: convert, flatten, then sanitize.
    sanitized = _sanitize_callable_params(_flatten_dict(_convert_params(namespace)))

    expected_entries = (
        ("foo", "bar"),
        ("something", "something"),
        ("wrapper_something", "wrapper_something"),
        ("wrapper_something_wo_name", "<lambda>"),
    )
    for key, value in expected_entries:
        assert sanitized[key] == value
120
121
def test_sanitize_params():
    """Check the values ``_sanitize_params`` produces for a mix of Python, NumPy and torch inputs."""
    raw = {
        "float": 0.3,
        "int": 1,
        "string": "abc",
        "bool": True,
        "list": [1, 2, 3],
        "np_bool": np.bool_(False),
        "np_int": np.int_(5),
        "np_double": np.double(3.14159),
        "namespace": Namespace(foo=3),
        "layer": torch.nn.BatchNorm1d,
        "tensor": torch.ones(3),
    }
    sanitized = _sanitize_params(raw)

    # Plain scalars and strings are expected to compare equal to their inputs.
    for key, value in (("float", 0.3), ("int", 1), ("string", "abc")):
        assert sanitized[key] == value
    assert sanitized["bool"] is True

    # A list is expected as its string representation.
    assert sanitized["list"] == "[1, 2, 3]"

    # The NumPy bool must be the ``False`` singleton itself, hence the identity check.
    assert sanitized["np_bool"] is False
    assert sanitized["np_int"] == 5
    assert sanitized["np_double"] == 3.14159

    # A Namespace and a class object are expected as their string representations.
    assert sanitized["namespace"] == "Namespace(foo=3)"
    assert sanitized["layer"] == "<class 'torch.nn.modules.batchnorm.BatchNorm1d'>"

    # The tensor entry is compared element-wise against a fresh tensor of ones.
    assert torch.equal(sanitized["tensor"], torch.ones(3))
150
151
def test_add_prefix():
    """Check that ``_add_prefix`` renames every key and that prefixes stack on repeated calls."""
    once = _add_prefix({"metric1": 1, "metric2": 2}, "prefix", "-")

    # After the first call only the prefixed keys are expected.
    for key in ("prefix-metric1", "prefix-metric2"):
        assert key in once
    for key in ("metric1", "metric2"):
        assert key not in once

    # A second call with a different prefix and separator is applied on top of the first.
    twice = _add_prefix(once, "prefix2", "_")

    for key in ("prefix2_prefix-metric1", "prefix2_prefix-metric2"):
        assert key in twice
    for key in ("prefix-metric1", "prefix-metric2"):
        assert key not in twice

    # Values are carried over to the renamed keys.
    assert twice["prefix2_prefix-metric1"] == 1
    assert twice["prefix2_prefix-metric2"] == 2
170