# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
from unittest import TestCase

import numpy as np
import paddle
from huggingface_hub import hf_hub_download
from parameterized import parameterized

from paddlenlp.utils import load_torch
from tests.testing_utils import require_package


class SerializationTest(TestCase):
    """Check that ``load_torch`` reads PyTorch ``pytorch_model.bin`` checkpoints faithfully.

    Each test obtains a checkpoint file, loads it once with ``load_torch`` and once with
    ``torch.load``, and compares the two results key by key.
    """

    @parameterized.expand(["float32", "float16", "bfloat16"])
    @require_package("torch")
    def test_simple_load(self, dtype: str):
        """Save a small hand-built state dict of floating-point ``dtype`` and load it back.

        Args:
            dtype: Name of the torch dtype used for every saved tensor; must be a key of
                ``dtype_mapping`` below.
        """
        # Imported lazily: this test is gated by ``require_package("torch")``, so torch may
        # not be installed when the module itself is imported.
        import torch

        # torch "normal_kernel_cpu" is not implemented for 'Char', 'Int', 'Long', so only
        # floating-point dtypes can be generated with torch.randn.
        dtype_mapping = {
            "float32": torch.float32,
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,  # test bfloat16
        }
        torch_dtype = dtype_mapping[dtype]

        with tempfile.TemporaryDirectory() as tempdir:
            weight_file_path = os.path.join(tempdir, "pytorch_model.bin")
            state_dict = {
                "a": torch.randn(2, 3, dtype=torch_dtype),
                "b": torch.randn(3, 4, dtype=torch_dtype),
                "a_parameter": torch.nn.Parameter(torch.randn(2, 3, dtype=torch_dtype)),  # test torch.nn.Parameter
                "b_parameter": torch.nn.Parameter(torch.randn(3, 4, dtype=torch_dtype)),
            }
            torch.save(state_dict, weight_file_path)

            numpy_data = load_torch(weight_file_path)
            torch_data = torch.load(weight_file_path)

            # Every saved entry must come back; otherwise the per-key loop below would
            # silently skip anything load_torch dropped.
            self.assertEqual(set(numpy_data), set(torch_data))
            for key, arr in numpy_data.items():
                # Compare in float32 on both sides: torch bfloat16 tensors cannot be turned
                # into numpy arrays directly, so everything is upcast first.
                np.testing.assert_allclose(
                    paddle.to_tensor(arr).cast("float32").cpu().numpy(),
                    torch_data[key].detach().cpu().to(torch.float32).numpy(),
                    rtol=1e-5,  # same tolerances as np.allclose's defaults
                    atol=1e-8,
                    err_msg=key,
                )

    @parameterized.expand(
        [
            "hf-internal-testing/tiny-random-codegen",
            "hf-internal-testing/tiny-random-Data2VecTextModel",
            "hf-internal-testing/tiny-random-SwinModel",
        ]
    )
    @require_package("torch")
    def test_load_bert_model(self, repo_id):
        """Compare ``load_torch`` with ``torch.load`` on a checkpoint fetched from the Hub.

        Despite the method name, the parametrized repositories are not BERT models.

        Args:
            repo_id: Hugging Face Hub repository that contains a ``pytorch_model.bin`` file.
        """
        # Imported lazily for the same reason as in test_simple_load.
        import torch

        with tempfile.TemporaryDirectory() as tempdir:
            # Download into the temporary directory so the test leaves no cache behind.
            weight_file = hf_hub_download(
                repo_id=repo_id,
                filename="pytorch_model.bin",
                cache_dir=tempdir,
                library_name="PaddleNLP",
            )
            torch_weight = torch.load(weight_file)
            paddle_weight = load_torch(weight_file)

            for key, arr in paddle_weight.items():
                np.testing.assert_allclose(
                    arr,
                    torch_weight[key].numpy(),
                    rtol=1e-5,  # same tolerances as np.allclose's defaults
                    atol=1e-8,
                    err_msg=key,
                )