paddlenlp
52 строки · 1.8 Кб
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
15import argparse
16import os
17
18import paddle
19
20from paddlenlp.transformers import AutoModelForSequenceClassification
21
def parse_args(argv=None):
    """Parse the command-line options of the export script.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            falls back to ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``params_path`` and
        ``output_path``.
    """
    parser = argparse.ArgumentParser()
    # Deliberately not marked ``required``: combining ``required=True`` with a
    # default makes the default unreachable. Omitting the flag now falls back
    # to the default checkpoint location.
    parser.add_argument(
        "--params_path",
        type=str,
        default="./checkpoint/model_900",
        help="The path to model parameters to be loaded.",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="./output",
        help="The path of model parameter in static graph to be saved.",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Load a trained sequence-classification model and export it as a static graph.

    The model is loaded from ``--params_path``, switched to evaluation mode,
    converted with ``paddle.jit.to_static`` and saved with ``paddle.jit.save``
    under ``<output_path>/inference``.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)

    # The number of labels should be in accordance with the training dataset.
    label_map = {0: "negative", 1: "positive"}
    model = AutoModelForSequenceClassification.from_pretrained(args.params_path, num_labels=len(label_map))

    # Put the model in evaluation mode before converting it.
    model.eval()

    # Convert to static graph with specific input description. Both inputs are
    # int64 tensors whose two dimensions are left unspecified (None).
    static_model = paddle.jit.to_static(
        model,
        input_spec=[
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # input_ids
            paddle.static.InputSpec(shape=[None, None], dtype="int64"),  # segment_ids
        ],
    )

    # Save in static graph model; "inference" is the file prefix inside the
    # output directory.
    save_path = os.path.join(args.output_path, "inference")
    paddle.jit.save(static_model, save_path)


if __name__ == "__main__":
    main()
53