pytorch-lightning

Форк
0
75 строк · 2.6 Кб
1
# Copyright The Lightning AI team.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import json
16
from typing import Any, Dict
17

18

19
def _duplicate_checker(js):
20
    """_duplicate_checker verifies that your JSON object doesn't contain duplicate keys."""
21
    result = {}
22
    for name, value in js:
23
        if name in result:
24
            raise ValueError(
25
                f"Unable to load JSON. A duplicate key {name} was detected. JSON objects must have unique keys."
26
            )
27
        result[name] = value
28
    return result
29

30

31
def string2dict(text):
32
    """String2dict parses a JSON string into a dictionary, ensuring no keys are duplicated by accident."""
33
    if not isinstance(text, str):
34
        text = text.decode("utf-8")
35
    try:
36
        js = json.loads(text, object_pairs_hook=_duplicate_checker)
37
        return js
38
    except ValueError as ex:
39
        raise ValueError(f"Unable to load JSON: {str(ex)}.")
40

41

42
def is_openapi(obj):
43
    """is_openopi checks if an object was generated by OpenAPI."""
44
    return hasattr(obj, "swagger_types")
45

46

47
def create_openapi_object(json_obj: Dict, target: Any):
48
    """Create the OpenAPI object from the given JSON dict and based on the target object.
49

50
    Lightning AI uses the target object to make new objects from the given JSON spec so the target must be a valid
51
    object.
52

53
    """
54
    if not isinstance(json_obj, dict):
55
        raise TypeError("json_obj must be a dictionary")
56
    if not is_openapi(target):
57
        raise TypeError("target must be an openapi object")
58

59
    target_attribs = {}
60
    for key, value in json_obj.items():
61
        try:
62
            # user provided key is not a valid key on openapi object
63
            sub_target = getattr(target, key)
64
        except AttributeError:
65
            raise ValueError(f"Field {key} not found in the target object")
66

67
        if is_openapi(sub_target):  # it's an openapi object
68
            target_attribs[key] = create_openapi_object(value, sub_target)
69
        else:
70
            target_attribs[key] = value
71

72
        # TODO(sherin) - specifically process list and dict and do the validation. Also do the
73
        #  verification for enum types
74

75
    return target.__class__(**target_attribs)
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.