# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import paddle
from model import PointwiseMatching

from paddlenlp.transformers import AutoModel

# Backbone used when --model_name_or_path is not given. The checkpoint's
# parameters are loaded into a PointwiseMatching built on this backbone, so it
# should be the backbone the checkpoint was trained with.
DEFAULT_MODEL_NAME = "ernie-3.0-medium-zh"
# File looked up inside --params_path when that option names a directory.
PARAMS_FILE_NAME = "model_state.pdparams"


def parse_args(argv=None):
    """Parse the command-line options of the export script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``params_path``, ``output_path`` and
        ``model_name_or_path``.
    """
    parser = argparse.ArgumentParser()
    # No ``default=``: the option is required, so a default could never be used.
    parser.add_argument(
        "--params_path",
        type=str,
        required=True,
        help="The path to model parameters to be loaded.",
    )
    parser.add_argument(
        "--output_path", type=str, default="./output", help="The path of model parameter in static graph to be saved."
    )
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        default=DEFAULT_MODEL_NAME,
        help="Name or path of the pretrained backbone the checkpoint was trained from.",
    )
    return parser.parse_args(argv)


def resolve_params_file(params_path):
    """Return the parameter file to load for ``params_path``.

    ``params_path`` may name the parameter file itself, or a directory that
    contains ``model_state.pdparams``.

    Raises:
        FileNotFoundError: if no parameter file exists at the resolved path.
            Failing here prevents silently exporting a model that lacks the
            fine-tuned weights.
    """
    if os.path.isdir(params_path):
        params_file = os.path.join(params_path, PARAMS_FILE_NAME)
    else:
        params_file = params_path
    if not os.path.isfile(params_file):
        raise FileNotFoundError("No model parameters found at %s" % params_file)
    return params_file


def main():
    """Load fine-tuned PointwiseMatching weights and save a static-graph model."""
    args = parse_args()

    # Resolve the checkpoint first so a bad path fails before the backbone is loaded.
    params_file = resolve_params_file(args.params_path)

    pretrained_model = AutoModel.from_pretrained(args.model_name_or_path)
    model = PointwiseMatching(pretrained_model)

    state_dict = paddle.load(params_file)
    model.set_dict(state_dict)
    print("Loaded parameters from %s" % params_file)
    model.eval()

    # Convert to static graph. Both inputs are 2-D int64 tensors whose
    # dimensions are left dynamic (None).
    model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # segment_ids
        ],
    )
    # paddle.jit.save treats save_path as a path prefix for the files it writes.
    save_path = os.path.join(args.output_path, "inference")
    paddle.jit.save(model, save_path)


if __name__ == "__main__":
    main()
