# simpletransformers
# 105 lines · 3.3 KB
import json
import logging
import os

from tqdm.auto import tqdm

from simpletransformers.question_answering import QuestionAnsweringModel
8
# Configure the root logger at INFO for this script, and hold the
# "transformers" logger at WARNING so only its warnings and errors show up.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
12
# Create dummy data to use for training.
# Each entry pairs a "context" passage with a list of questions ("qas").
# Every answer's "answer_start" is the character offset of its "text" inside
# that context. The long passage is assembled from adjacent string literals
# instead of a backslash continuation, so source indentation can never leak
# into the text and shift those offsets.
train_data = [
    {
        "context": "This is the first context",
        "qas": [
            {
                "id": "00001",
                "is_impossible": False,
                "question": "Which context is this?",
                "answers": [{"text": "the first", "answer_start": 8}],
            }
        ],
    },
    {
        "context": (
            "Other legislation followed, including the Migratory Bird Conservation Act of 1929, a 1937 treaty prohibiting the hunting of right and gray whales,"
            " and the Bald Eagle Protection Act of 1940. These later laws had a low cost to society—the species were relatively rare—and little opposition was raised"
        ),
        "qas": [
            {
                "id": "00002",
                "is_impossible": False,
                "question": "What was the cost to society?",
                "answers": [{"text": "low cost", "answer_start": 213}],
            },
            {
                "id": "00003",
                "is_impossible": False,
                "question": "What was the name of the 1937 treaty?",
                "answers": [{"text": "Bald Eagle Protection Act", "answer_start": 155}],
            },
            {
                # Unanswerable question: no answer spans are provided.
                "id": "00004",
                "is_impossible": True,
                "question": "How did Alexandar Hamilton die?",
                "answers": [],
            },
        ],
    },
]
51
# Inflate the dummy dataset. Each pass appends the list to itself, so the
# size doubles every round: 20 rounds turn the 2 seed entries into
# 2 * 2**20 (about 2.1 million) entries that all reference the same dicts.
# NOTE(review): presumably this is meant to produce a file large enough to
# exercise lazy loading — confirm the intended size.
for _ in range(20):
    train_data.extend(train_data)
54
# Save as a JSON file (the whole dataset as a single JSON array).
os.makedirs("data", exist_ok=True)
with open("data/train.json", "w", encoding="utf-8") as f:
    json.dump(train_data, f)

# Save as a JSONL file: one JSON document per line, written record by record
# with a progress bar.
with open("data/train.jsonl", "w", encoding="utf-8") as outfile:
    for entry in tqdm(train_data):
        outfile.write(json.dumps(entry) + "\n")
65
# Training options handed to QuestionAnsweringModel via `args=`.
# Every key appears exactly once. The earlier version listed "n_best_size"
# (3, then 10) and "no_save" twice, and a dict literal silently keeps only the
# last value, so the effective values (10 and True) are the ones kept here.
train_args = {
    "reprocess_input_data": True,
    "overwrite_output_dir": True,
    "evaluate_during_training": True,
    "evaluate_during_training_steps": 10000,
    "train_batch_size": 8,
    "num_train_epochs": 1,
    "n_best_size": 10,
    "fp16": False,
    "no_save": True,
    "manual_seed": 4,
    "max_seq_length": 512,
    "lazy_loading": True,
    # Optional settings, left disabled:
    # "wandb_project": "test-new-project",
    # "use_early_stopping": True,
    # "use_multiprocessing": False,
}
85
# Create the QuestionAnsweringModel: model type "bert" with the
# "bert-base-cased" weights, configured by train_args.
# NOTE(review): use_cuda=True with cuda_device=0 presumably requires a CUDA
# GPU at index 0 — confirm before running on a CPU-only machine.
model = QuestionAnsweringModel(
    "bert",
    "bert-base-cased",
    args=train_args,
    use_cuda=True,
    cuda_device=0,
)

# Train on the JSONL file (train_args enables "lazy_loading") and evaluate
# during training against the plain JSON file holding the same records.
model.train_model("data/train.jsonl", eval_data="data/train.json")
93
# Input for making predictions using the model: one context passage with a
# single question. The passage is assembled from adjacent string literals
# instead of a backslash continuation, so source indentation cannot end up
# inside the text.
to_predict = [
    {
        "context": (
            "Other legislation followed, including the Migratory Bird Conservation Act of 1929, a 1937 treaty prohibiting the hunting of right and gray whales,"
            " and the Bald Eagle Protection Act of 1940. These later laws had a low cost to society—the species were relatively rare—and little opposition was raised"
        ),
        "qas": [{"question": "What was the name of the 1937 treaty?", "id": "0"}],
    }
]
102
# Run the model on the prediction input and print whatever predict() returns.
# n_best_size=2 is passed for this call, independently of the value in the
# training args.
predictions = model.predict(to_predict, n_best_size=2)
print(predictions)

# flake8: noqa
106