paddlenlp

Форк
0
/
test_serialization.py 
94 строки · 3.2 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import os
16
import tempfile
17
from unittest import TestCase
18

19
import numpy as np
20
import paddle
21
from huggingface_hub import hf_hub_download
22
from parameterized import parameterized
23

24
from paddlenlp.utils import load_torch
25
from tests.testing_utils import require_package
26

27

28
class SerializationTest(TestCase):
29
    @parameterized.expand(
30
        [
31
            "float32",
32
            "float16",
33
            "bfloat16",
34
        ]
35
    )
36
    @require_package("torch")
37
    def test_simple_load(self, dtype: str):
38
        import torch
39

40
        # torch "normal_kernel_cpu" not implemented for 'Char', 'Int', 'Long', so only support float
41
        dtype_mapping = {
42
            "float32": torch.float32,
43
            "float16": torch.float16,
44
            "bfloat16": torch.bfloat16,  # test bfloat16
45
        }
46
        dtype = dtype_mapping[dtype]
47

48
        with tempfile.TemporaryDirectory() as tempdir:
49
            weight_file_path = os.path.join(tempdir, "pytorch_model.bin")
50
            torch.save(
51
                {
52
                    "a": torch.randn(2, 3, dtype=dtype),
53
                    "b": torch.randn(3, 4, dtype=dtype),
54
                    "a_parameter": torch.nn.Parameter(torch.randn(2, 3, dtype=dtype)),  # test torch.nn.Parameter
55
                    "b_parameter": torch.nn.Parameter(torch.randn(3, 4, dtype=dtype)),
56
                },
57
                weight_file_path,
58
            )
59
            numpy_data = load_torch(weight_file_path)
60
            torch_data = torch.load(weight_file_path)
61

62
            for key, arr in numpy_data.items():
63
                assert np.allclose(
64
                    paddle.to_tensor(arr).cast("float32").cpu().numpy(),
65
                    torch_data[key].detach().cpu().to(torch.float32).numpy(),
66
                )
67

68
    @parameterized.expand(
69
        [
70
            "hf-internal-testing/tiny-random-codegen",
71
            "hf-internal-testing/tiny-random-Data2VecTextModel",
72
            "hf-internal-testing/tiny-random-SwinModel",
73
        ]
74
    )
75
    @require_package("torch")
76
    def test_load_bert_model(self, repo_id):
77
        import torch
78

79
        with tempfile.TemporaryDirectory() as tempdir:
80
            weight_file = hf_hub_download(
81
                repo_id=repo_id,
82
                filename="pytorch_model.bin",
83
                cache_dir=tempdir,
84
                library_name="PaddleNLP",
85
            )
86
            torch_weight = torch.load(weight_file)
87
            torch_weight = {key: value for key, value in torch_weight.items()}
88
            paddle_weight = load_torch(weight_file)
89

90
            for key, arr in paddle_weight.items():
91
                assert np.allclose(
92
                    arr,
93
                    torch_weight[key].numpy(),
94
                )
95

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.