paddlenlp

Форк
0
788 строк · 25.0 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14

15
import os
16

17
import numpy as np
18
import paddle
19
from ppfleetx.data.tokenizers import GPTTokenizer
20
from ppfleetx.utils.download import cached_path
21
from ppfleetx.utils.file import parse_csv, unzip
22

23
__all__ = ["CoLA", "SST2", "MNLI", "QNLI", "RTE", "WNLI", "MRPC", "QQP", "STSB"]

# GLUE tasks provided by this module, grouped by task family:
#
# Single-sentence tasks:
#   * CoLA
#   * SST-2
#
# Similarity and paraphrase tasks:
#   * MRPC
#   * STS-B
#   * QQP
#
# Inference tasks:
#   * MNLI
#   * QNLI
#   * RTE
#   * WNLI
43

44

45
class CoLA(paddle.io.Dataset):
    """The Corpus of Linguistic Acceptability consists of English
    acceptability judgments drawn from books and journal articles on
    linguistic theory. Each example is a sequence of words annotated
    with whether it is a grammatical English sentence.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"``. ``"test"`` reads
            ``raw/out_of_domain_dev.tsv``.
        max_length (int): Length every example is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/cola.html#CoLA

    URL = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip"
    MD5 = "9f6d88c3558ec424cd9d66ea03589aba"

    NUM_LINES = {
        "train": 8551,
        "dev": 527,
        "test": 516,
    }

    _PATH = "cola_public_1.1.zip"

    DATASET_NAME = "CoLA"

    _EXTRACTED_FILES = {
        "train": os.path.join("raw", "in_domain_train.tsv"),
        "dev": os.path.join("raw", "in_domain_dev.tsv"),
        "test": os.path.join("raw", "out_of_domain_dev.tsv"),
    }

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        def _filter_res(x):
            # Keep only rows with exactly four tab-separated fields.
            return len(x) == 4

        def _modify_res(x):
            # Field 3 is the sentence text, field 1 the integer label.
            return (x[3], int(x[1]))

        # NOTE(review): the raw CoLA .tsv files appear to have no header row;
        # confirm that skip_lines=1 does not drop the first example.
        self.samples = parse_csv(
            self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res, filter_funcs=_filter_res
        )

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``. Rows of
        the ``"test"`` split are parsed with a label too, but it is not returned.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            return input_ids, sample[1]
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (binary acceptability)."""
        return 2
123

124

125
class SST2(paddle.io.Dataset):
    """The Stanford Sentiment Treebank consists of sentences from movie reviews and
    human annotations of their sentiment. The task is to predict the sentiment of a
    given sentence. We use the two-way (positive/negative) class split, and use only
    sentence-level labels.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"`` (unlabeled).
        max_length (int): Length every example is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/sst2.html#SST2

    URL = "https://dl.fbaipublicfiles.com/glue/data/SST-2.zip"
    MD5 = "9f81648d4199384278b86e315dac217c"

    NUM_LINES = {
        "train": 67349,
        "dev": 872,
        "test": 1821,
    }

    _PATH = "SST-2.zip"

    DATASET_NAME = "SST2"

    _EXTRACTED_FILES = {
        "train": "train.tsv",
        "dev": "dev.tsv",
        "test": "test.tsv",
    }

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        # test split for SST2 doesn't have labels
        if split == "test":

            def _modify_test_res(t):
                # Field 1 holds the sentence; samples are 1-tuples.
                return (t[1].strip(),)

            self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_test_res)
        else:

            def _modify_res(t):
                # Field 0 holds the sentence, field 1 the integer label.
                return (t[0].strip(), int(t[1]))

            self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            return input_ids, sample[1]
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (positive/negative)."""
        return 2
207

208

209
class MNLI(paddle.io.Dataset):
    """The Multi-Genre Natural Language Inference Corpus is a crowdsourced
    collection of sentence pairs with textual entailment annotations. Given a premise sentence
    and a hypothesis sentence, the task is to predict whether the premise entails the hypothesis
    (entailment), contradicts the hypothesis (contradiction), or neither (neutral). The premise sentences are
    gathered from ten different sources, including transcribed speech, fiction, and government reports.
    We use the standard test set, for which we obtained private labels from the authors, and evaluate
    on both the matched (in-domain) and mismatched (cross-domain) section. We also use and recommend
    the SNLI corpus as 550k examples of auxiliary training data.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev_matched"`` or ``"dev_mismatched"``.
        max_length (int): Length every (premise, hypothesis) pair is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/mnli.html#MNLI

    URL = "https://cims.nyu.edu/~sbowman/multinli/multinli_1.0.zip"
    MD5 = "0f70aaf66293b3c088a864891db51353"

    NUM_LINES = {
        "train": 392702,
        "dev_matched": 9815,
        "dev_mismatched": 9832,
    }

    _PATH = "multinli_1.0.zip"

    DATASET_NAME = "MNLI"

    _EXTRACTED_FILES = {
        "train": "multinli_1.0_train.txt",
        "dev_matched": "multinli_1.0_dev_matched.txt",
        "dev_mismatched": "multinli_1.0_dev_mismatched.txt",
    }

    LABEL_TO_INT = {"entailment": 0, "neutral": 1, "contradiction": 2}

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        def _filter_res(x):
            # Drop rows whose first field is not one of the three known labels.
            return x[0] in self.LABEL_TO_INT

        def _modify_res(x):
            # Fields 5/6 are the premise/hypothesis text, field 0 the label name.
            return (x[5], x[6], self.LABEL_TO_INT[x[0]])

        self.samples = parse_csv(
            self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res, filter_funcs=_filter_res
        )

    def __getitem__(self, idx):
        """Return ``(input_ids, label)`` for the tokenized (premise, hypothesis) pair.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        return input_ids, sample[2]

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (entailment / neutral / contradiction)."""
        return 3
291

292

293
class QNLI(paddle.io.Dataset):
    """The Stanford Question Answering Dataset is a question-answering
    dataset consisting of question-paragraph pairs, where one of the sentences in the paragraph (drawn
    from Wikipedia) contains the answer to the corresponding question (written by an annotator). We
    convert the task into sentence pair classification by forming a pair between each question and each
    sentence in the corresponding context, and filtering out pairs with low lexical overlap between the
    question and the context sentence. The task is to determine whether the context sentence contains
    the answer to the question. This modified version of the original task removes the requirement that
    the model select the exact answer, but also removes the simplifying assumptions that the answer
    is always present in the input and that lexical overlap is a reliable cue.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"`` (unlabeled).
        max_length (int): Length every (question, sentence) pair is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/qnli.html#QNLI

    URL = "https://dl.fbaipublicfiles.com/glue/data/QNLIv2.zip"
    MD5 = "b4efd6554440de1712e9b54e14760e82"

    NUM_LINES = {
        "train": 104743,
        "dev": 5463,
        "test": 5463,
    }

    _PATH = "QNLIv2.zip"

    DATASET_NAME = "QNLI"

    _EXTRACTED_FILES = {
        "train": "train.tsv",
        "dev": "dev.tsv",
        "test": "test.tsv",
    }

    MAP_LABELS = {"entailment": 0, "not_entailment": 1}

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        # Pick the row mapper once instead of re-checking the split on every row.
        if split == "test":
            # test split for QNLI doesn't have labels
            def _modify_res(x):
                return (x[1], x[2])

        else:

            def _modify_res(x):
                return (x[1], x[2], self.MAP_LABELS[x[3]])

        self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            return input_ids, sample[2]
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (entailment / not_entailment)."""
        return 2
378

379

380
class RTE(paddle.io.Dataset):
    """The Recognizing Textual Entailment (RTE) datasets come from a series of annual textual
    entailment challenges. We combine the data from RTE1 (Dagan et al., 2006), RTE2 (Bar Haim
    et al., 2006), RTE3 (Giampiccolo et al., 2007), and RTE5 (Bentivogli et al., 2009). Examples are
    constructed based on news and Wikipedia text. We convert all datasets to a two-class split, where
    for three-class datasets we collapse neutral and contradiction into not entailment, for consistency.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"`` (unlabeled).
        max_length (int): Length every sentence pair is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/rte.html#RTE

    URL = "https://dl.fbaipublicfiles.com/glue/data/RTE.zip"
    MD5 = "bef554d0cafd4ab6743488101c638539"

    # RTE split sizes (these were previously copied from SST2 by mistake).
    NUM_LINES = {
        "train": 2490,
        "dev": 277,
        "test": 3000,
    }

    _PATH = "RTE.zip"

    DATASET_NAME = "RTE"

    _EXTRACTED_FILES = {
        "train": "train.tsv",
        "dev": "dev.tsv",
        "test": "test.tsv",
    }

    MAP_LABELS = {"entailment": 0, "not_entailment": 1}

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        # Pick the row mapper once instead of re-checking the split on every row.
        if split == "test":
            # test split for RTE doesn't have labels
            def _modify_res(x):
                return (x[1], x[2])

        else:

            def _modify_res(x):
                return (x[1], x[2], self.MAP_LABELS[x[3]])

        self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            return input_ids, sample[2]
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (entailment / not_entailment)."""
        return 2
461

462

463
class WNLI(paddle.io.Dataset):
    """The Winograd Schema Challenge (Levesque et al., 2011) is a reading comprehension task
    in which a system must read a sentence with a pronoun and select the referent of that pronoun from
    a list of choices. The examples are manually constructed to foil simple statistical methods: Each
    one is contingent on contextual information provided by a single word or phrase in the sentence.
    To convert the problem into sentence pair classification, we construct sentence pairs by replacing
    the ambiguous pronoun with each possible referent. The task is to predict if the sentence with the
    pronoun substituted is entailed by the original sentence. We use a small evaluation set consisting of
    new examples derived from fiction books that was shared privately by the authors of the original
    corpus. While the included training set is balanced between two classes, the test set is imbalanced
    between them (65% not entailment). Also, due to a data quirk, the development set is adversarial:
    hypotheses are sometimes shared between training and development examples, so if a model memorizes the
    training examples, they will predict the wrong label on corresponding development set
    example. As with QNLI, each example is evaluated separately, so there is not a systematic correspondence
    between a model's score on this task and its score on the unconverted original task. We
    call converted dataset WNLI (Winograd NLI).

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"`` (unlabeled).
        max_length (int): Length every sentence pair is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/wnli.html#WNLI

    URL = "https://dl.fbaipublicfiles.com/glue/data/WNLI.zip"
    MD5 = "a1b4bd2861017d302d29e42139657a42"

    NUM_LINES = {
        "train": 635,
        "dev": 71,
        "test": 146,
    }

    _PATH = "WNLI.zip"

    DATASET_NAME = "WNLI"

    _EXTRACTED_FILES = {
        "train": "train.tsv",
        "dev": "dev.tsv",
        "test": "test.tsv",
    }

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        # Pick the row mapper once instead of re-checking the split on every row.
        if split == "test":
            # test split for WNLI doesn't have labels
            def _modify_res(x):
                return (x[1], x[2])

        else:

            def _modify_res(x):
                # The label column is already numeric in this file.
                return (x[1], x[2], int(x[3]))

        self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            return input_ids, sample[2]
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (binary entailment)."""
        return 2
552

553

554
class MRPC(paddle.io.Dataset):
    """The Microsoft Research Paraphrase Corpus (Dolan & Brockett, 2005) is a corpus of
    sentence pairs automatically extracted from online news sources, with human annotations
    for whether the sentences in the pair are semantically equivalent.

    Args:
        root (str): Directory the per-split text files are fetched into with
            ``cached_path``. If it already exists it must be a directory.
        split (str): One of ``"train"`` or ``"test"``; both are labeled.
        max_length (int): Length every sentence pair is padded/truncated to.
    """

    # ref https://pytorch.org/text/stable/_modules/torchtext/datasets/mrpc.html#MRPC

    URL = {
        "train": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_train.txt",
        "test": "https://dl.fbaipublicfiles.com/senteval/senteval_data/msr_paraphrase_test.txt",
    }

    MD5 = {
        "train": "793daf7b6224281e75fe61c1f80afe35",
        "test": "e437fdddb92535b820fe8852e2df8a49",
    }

    NUM_LINES = {
        "train": 4076,
        "test": 1725,
    }

    DATASET_NAME = "MRPC"

    _EXTRACTED_FILES = {
        "train": "msr_paraphrase_train.txt",
        "test": "msr_paraphrase_test.txt",
    }

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used to pick a URL; otherwise an invalid
        # split surfaces as a bare KeyError from ``self.URL[split]``.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        # Called unconditionally; presumably cached_path skips the download when
        # the file is already cached — TODO confirm.
        cached_path(self.URL[split], cache_dir=os.path.abspath(self.root))

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        def _modify_res(x):
            # Fields 3/4 are the two sentences, field 0 the integer label.
            return (x[3], x[4], int(x[0]))

        self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)`` for the tokenized sentence pair.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        return input_ids, sample[2]

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (paraphrase or not)."""
        return 2
625

626

627
class QQP(paddle.io.Dataset):
    """The Quora Question Pairs2 dataset is a collection of question pairs from the
    community question-answering website Quora. The task is to determine whether a
    pair of questions are semantically equivalent.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"`` (unlabeled).
        max_length (int): Length every question pair is padded/truncated to.
    """

    # ref https://huggingface.co/datasets/glue/blob/main/glue.py#L212-L239

    URL = "https://dl.fbaipublicfiles.com/glue/data/QQP-clean.zip"
    MD5 = "884bf26e39c783d757acc510a2a516ef"

    NUM_LINES = {
        "train": 363846,
        "dev": 40430,
        "test": 390961,
    }

    _PATH = "QQP-clean.zip"

    DATASET_NAME = "QQP"

    _EXTRACTED_FILES = {
        "train": "train.tsv",
        "dev": "dev.tsv",
        "test": "test.tsv",
    }

    MAP_LABELS = {"not_duplicate": 0, "duplicate": 1}

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        # Pick the row mapper once instead of re-checking the split on every row.
        # The labeled and unlabeled files use different column layouts.
        if split == "test":
            # test split for QQP doesn't have labels
            def _modify_res(x):
                return (x[1], x[2])

        else:

            def _modify_res(x):
                return (x[3], x[4], int(x[5]))

        self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, label)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            return input_ids, sample[2]
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of target classes (duplicate or not)."""
        return 2
706

707

708
class STSB(paddle.io.Dataset):
    """The Semantic Textual Similarity Benchmark (Cer et al., 2017) is a collection of
    sentence pairs drawn from news headlines, video and image captions, and natural
    language inference data. Each pair is human-annotated with a similarity score
    from 1 to 5.

    This is a regression task: labels are float scores, not class indices.

    Args:
        root (str): Directory containing the extracted archive. If it does not
            exist, the archive at ``URL`` is fetched with ``cached_path`` (using
            ``root`` as cache dir) and unzipped into the parent of ``root``.
        split (str): One of ``"train"``, ``"dev"`` or ``"test"`` (unlabeled).
        max_length (int): Length every sentence pair is padded/truncated to.
    """

    # ref https://huggingface.co/datasets/glue/blob/main/glue.py#L240-L267

    URL = "https://dl.fbaipublicfiles.com/glue/data/STS-B.zip"
    MD5 = "d573676be38f1a075a5702b90ceab3de"

    NUM_LINES = {
        "train": 5749,
        "dev": 1500,
        "test": 1379,
    }

    _PATH = "STS-B.zip"

    DATASET_NAME = "STSB"

    _EXTRACTED_FILES = {
        "train": "train.tsv",
        "dev": "dev.tsv",
        "test": "test.tsv",
    }

    def __init__(self, root, split, max_length=128):
        # Validate the split before it is used for a file lookup or triggers a
        # download; otherwise an invalid split surfaces as a bare KeyError.
        assert split in self._EXTRACTED_FILES, f"split should be one of {list(self._EXTRACTED_FILES)}, but got {split!r}"

        self.root = root
        self.split = split
        if os.path.exists(self.root):
            assert os.path.isdir(self.root)
        else:
            # The archive is unzipped into the *parent* of ``root``; presumably its
            # top-level folder must match ``root``'s name — TODO confirm.
            zip_path = cached_path(self.URL, cache_dir=os.path.abspath(self.root))
            unzip(zip_path, mode="r", out_dir=os.path.join(self.root, ".."), delete=True)

        self.path = os.path.join(self.root, self._EXTRACTED_FILES[split])
        assert os.path.exists(self.path), f"{self.path} is not exists!"
        self.max_length = max_length

        self.tokenizer = GPTTokenizer.from_pretrained("gpt2")

        # Pick the row mapper once instead of re-checking the split on every row.
        if split == "test":
            # test split for STSB doesn't have labels
            def _modify_res(x):
                return (x[7], x[8])

        else:

            def _modify_res(x):
                # Fields 7/8 are the sentences, field 9 the float similarity score.
                return (x[7], x[8], float(x[9]))

        self.samples = parse_csv(self.path, skip_lines=1, delimiter="\t", map_funcs=_modify_res)

    def __getitem__(self, idx):
        """Return ``(input_ids, score)``, or only ``input_ids`` for the ``"test"`` split.

        ``input_ids`` is a paddle tensor padded/truncated to ``max_length``;
        ``score`` is a float32 numpy array of shape ``[1]``.
        """
        sample = self.samples[idx]

        encoded_inputs = self.tokenizer(
            sample[0],
            text_pair=sample[1],
            padding="max_length",
            truncation="longest_first",
            max_length=self.max_length,
            return_token_type_ids=False,
        )
        input_ids = paddle.to_tensor(encoded_inputs["input_ids"])
        if self.split != "test":
            # Note(GuoxiaWang): We need return shape [1] value,
            # so that we can attain a batched label with shape [batchsize, 1].
            # Because the logits shape is [batchsize, 1], and feed into MSE loss.
            return input_ids, np.array([sample[2]], dtype=np.float32)
        return input_ids

    def __len__(self):
        return len(self.samples)

    @property
    def class_num(self):
        """Number of model outputs: STS-B is regression, so the head emits a single
        score (logits of shape ``[batchsize, 1]``, matching the label shape above)."""
        return 1
789

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.