# pytorch-lightning
# 75 lines · 2.6 KB
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import json
16from typing import Any, Dict
17
18
def _duplicate_checker(js):
    """Turn the key/value pairs of one JSON object into a dict, failing on repeated keys.

    Intended for use as the ``object_pairs_hook`` of :func:`json.loads` (as ``string2dict`` does).

    Args:
        js: Iterable of ``(key, value)`` pairs.

    Returns:
        A dict holding the pairs in the order they were given.

    Raises:
        ValueError: as soon as a key is seen for the second time.

    """
    unique = {}
    for pair_key, pair_value in js:
        if pair_key not in unique:
            unique[pair_key] = pair_value
            continue
        message = f"Unable to load JSON. A duplicate key {pair_key} was detected. JSON objects must have unique keys."
        raise ValueError(message)
    return unique
29
30
def string2dict(text):
    """Parse a JSON document, ensuring no object in it has duplicated keys.

    Despite the name, the result is whatever ``json.loads`` produces for the top-level value:
    a ``dict`` for a JSON object, but e.g. a ``list`` for a JSON array.

    Args:
        text: The JSON document as ``str``. Any other value is decoded with ``.decode("utf-8")``
            first, so ``bytes`` holding UTF-8 text are accepted as well.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if ``text`` is not valid JSON or a JSON object in it repeats a key; the
            underlying parsing error is chained as ``__cause__``.
        UnicodeDecodeError: if ``text`` is bytes that are not valid UTF-8 (a ``ValueError`` subclass,
            raised before parsing and not re-wrapped).

    """
    if not isinstance(text, str):
        text = text.decode("utf-8")
    try:
        return json.loads(text, object_pairs_hook=_duplicate_checker)
    except ValueError as ex:
        # Chain explicitly so the traceback reports the parse error as the direct cause,
        # rather than as an unexpected error raised while handling it.
        raise ValueError(f"Unable to load JSON: {str(ex)}.") from ex
40
41
def is_openapi(obj) -> bool:
    """is_openapi checks if an object was generated by OpenAPI.

    The check is duck-typed: any object (instance or class) that exposes a ``swagger_types``
    attribute is treated as an OpenAPI-generated model.

    Args:
        obj: The object to inspect.

    Returns:
        ``True`` if ``obj`` has a ``swagger_types`` attribute, ``False`` otherwise.

    """
    return hasattr(obj, "swagger_types")
45
46
def create_openapi_object(json_obj: Dict, target: Any):
    """Create the OpenAPI object from the given JSON dict and based on the target object.

    Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid
    object.

    Every key of ``json_obj`` must name an attribute of ``target``. When that attribute is itself an OpenAPI
    object (per ``is_openapi``), the corresponding value is converted recursively, using the attribute as the
    template. Otherwise the value is passed through unchanged. The result is built by calling ``target``'s class
    with only the collected values as keyword arguments. Attributes absent from ``json_obj`` are therefore not
    copied from ``target``; they are left to that class's constructor.

    Args:
        json_obj: Mapping of attribute name to value.
        target: An OpenAPI-generated object used as a template for field names and nested types.

    Returns:
        A new instance of ``target.__class__``.

    Raises:
        TypeError: if ``json_obj`` is not a ``dict`` (including a value destined for a nested OpenAPI
            field), or if ``target`` is not an OpenAPI object.
        ValueError: if a key of ``json_obj`` is not an attribute of ``target``.

    """
    if not isinstance(json_obj, dict):
        raise TypeError("json_obj must be a dictionary")
    if not is_openapi(target):
        raise TypeError("target must be an openapi object")

    target_attribs = {}
    for key, value in json_obj.items():
        try:
            sub_target = getattr(target, key)
        except AttributeError:
            # The user-provided key is not a valid key on the openapi object. Suppress the internal
            # AttributeError so the traceback shows only the user-facing error.
            raise ValueError(f"Field {key} not found in the target object") from None

        if is_openapi(sub_target):
            # Nested openapi object: rebuild it recursively, using the current attribute as the template.
            target_attribs[key] = create_openapi_object(value, sub_target)
        else:
            target_attribs[key] = value

    # TODO(sherin) - specifically process list and dict and do the validation. Also do the
    #  verification for enum types

    return target.__class__(**target_attribs)
76