# pytorch-lightning
# 50 lines · 1.5 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, Optional


class AttributeDict(Dict):
    """Extended dictionary accessible with dot notation.

    Attribute reads, writes and deletions are forwarded to the dictionary
    entries, so ``ad.key`` behaves like ``ad["key"]``.

    >>> ad = AttributeDict({'key1': 1, 'key2': 'abc'})
    >>> ad.key1
    1
    >>> ad.update({'my-key': 3.14})
    >>> ad.update(new_key=42)
    >>> ad.key1 = 2
    >>> ad
    "key1":    2
    "key2":    abc
    "my-key":  3.14
    "new_key": 42
    >>> del ad.key2
    >>> 'key2' in ad
    False

    """

    def __getattr__(self, key: str) -> Optional[Any]:
        """Return the entry stored under ``key``.

        Python only calls ``__getattr__`` after normal attribute lookup fails,
        so real attributes and methods (``keys``, ``update``, ...) take
        precedence over dictionary entries with the same name.

        Raises:
            AttributeError: if ``key`` is not in the dictionary. Raising
                ``AttributeError`` rather than ``KeyError`` keeps ``hasattr`` and
                ``getattr(obj, name, default)`` working as expected.
        """
        try:
            return self[key]
        except KeyError as exp:
            raise AttributeError(f'Missing attribute "{key}"') from exp

    def __setattr__(self, key: str, val: Any) -> None:
        """Store ``val`` as the dictionary entry ``key`` (not as an instance attribute)."""
        self[key] = val

    def __delattr__(self, key: str) -> None:
        """Remove the dictionary entry ``key``, so ``del ad.key`` mirrors ``del ad["key"]``.

        Raises:
            AttributeError: if ``key`` is not in the dictionary, consistent with
                ``__getattr__``.
        """
        try:
            del self[key]
        except KeyError as exp:
            raise AttributeError(f'Missing attribute "{key}"') from exp

    def __repr__(self) -> str:
        """Render one ``"key": value`` row per entry, sorted by key, with values aligned.

        Returns an empty string for an empty dictionary.
        """
        if not self:
            return ""
        try:
            keys = sorted(self)
        except TypeError:
            # Mutually unorderable key types (e.g. ``int`` and ``str``) make plain
            # sorting raise; order by the printed form instead so repr never fails.
            keys = sorted(self, key=str)
        # +3 leaves room for the two quotes and the colon wrapped around each key.
        width = max(len(str(k)) for k in keys) + 3
        row_template = "{:" + str(width) + "s} {}"
        return "\n".join(row_template.format(f'"{k}":', self[k]) for k in keys)