CSS-LM

Форк
0
/
tokenization_transfo_xl.py 
795 строк · 29.0 Кб
1
# coding=utf-8
2
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
3
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
4
#
5
# Licensed under the Apache License, Version 2.0 (the "License");
6
# you may not use this file except in compliance with the License.
7
# You may obtain a copy of the License at
8
#
9
#     http://www.apache.org/licenses/LICENSE-2.0
10
#
11
# Unless required by applicable law or agreed to in writing, software
12
# distributed under the License is distributed on an "AS IS" BASIS,
13
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
# See the License for the specific language governing permissions and
15
# limitations under the License.
16
""" Tokenization classes for Transformer XL model.
17
    Adapted from https://github.com/kimiyoung/transformer-xl.
18
"""
19

20

21
import glob
22
import logging
23
import os
24
import pickle
25
import re
26
from collections import Counter, OrderedDict
27
from typing import Optional
28

29
import numpy as np
30
from tokenizers import Tokenizer
31
from tokenizers.implementations import BaseTokenizer
32
from tokenizers.models import WordLevel
33
from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalizer_from_str
34
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
35
from tokenizers.processors import BertProcessing
36

37
from .file_utils import cached_path, is_torch_available
38
from .tokenization_utils import PreTrainedTokenizer
39
from .tokenization_utils_fast import PreTrainedTokenizerFast
40

41

42
if is_torch_available():
    import torch


logger = logging.getLogger(__name__)

# File names used for the vocabulary inside a saved tokenizer directory. The keys match the tokenizer __init__
# argument names (`pretrained_vocab_file`, `vocab_file`).
VOCAB_FILES_NAMES = dict(pretrained_vocab_file="vocab.bin", vocab_file="vocab.txt")
VOCAB_FILES_NAMES_FAST = dict(pretrained_vocab_file="vocab.json", vocab_file="vocab.json")

# Download location of the vocabulary of each published checkpoint, for the slow (torch-pickled) and the fast
# (JSON) tokenizer formats respectively.
PRETRAINED_VOCAB_FILES_MAP = dict(
    pretrained_vocab_file={
        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
    },
)
PRETRAINED_VOCAB_FILES_MAP_FAST = dict(
    pretrained_vocab_file={
        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.json",
    },
)

# Per-checkpoint maximum input size; the wt103 checkpoint records no fixed value here.
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {"transfo-xl-wt103": None}

# Pre-processed corpora loadable through TransfoXLCorpus.from_pretrained, and the file name looked up inside a local
# corpus directory.
PRETRAINED_CORPUS_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
}
CORPUS_NAME = "corpus.bin"
71

72

73
class TransfoXLTokenizer(PreTrainedTokenizer):
    """
    Transformer-XL tokenizer adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

    This is a word-level tokenizer. Each line is split on ``delimiter``: runs of whitespace when it is ``None``,
    single characters when it is ``""``. Every symbol is then looked up in ``sym2idx``, and symbols missing from the
    vocabulary map to the unknown-token id.

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = []

    def __init__(
        self,
        special=None,
        min_freq=0,
        max_size=None,
        lower_case=False,
        delimiter=None,
        vocab_file=None,
        pretrained_vocab_file=None,
        never_split=None,
        unk_token="<unk>",
        eos_token="<eos>",
        additional_special_tokens=("<formula>",),
        **kwargs
    ):
        """
        Args:
            special: Symbols added first (through ``add_special``) when the vocabulary is built from counts.
            min_freq: When building from counts, symbols seen fewer times than this are dropped.
            max_size: When building from counts, keep at most this many of the most common symbols
                (``None``: no cap).
            lower_case: Lowercase every line before splitting it.
            delimiter: Separator passed to ``str.split``. ``None`` splits on whitespace; ``""`` splits into
                characters.
            vocab_file: Text file with one symbol per line (first whitespace-separated field). The vocabulary is
                built from it immediately.
            pretrained_vocab_file: File written by :meth:`save_vocabulary`. It is read with ``torch.load``, and
                entries not already set on the instance are copied onto it.
            never_split: Stored as ``self.never_split``. Defaults to all special tokens.
            unk_token, eos_token, additional_special_tokens: Forwarded to the base class. A private list copy of
                ``additional_special_tokens`` is passed so instances never share one mutable list.

        Raises:
            ValueError: If a vocabulary file cannot be loaded or parsed.
        """
        if additional_special_tokens is not None:
            # Fresh list per instance: a shared mutable default would leak edits between tokenizers.
            additional_special_tokens = list(additional_special_tokens)
        super().__init__(
            unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
        )

        if never_split is None:
            never_split = self.all_special_tokens
        if special is None:
            special = []
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file
        self.never_split = never_split
        self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
        self.punction_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
        self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()

        try:
            if pretrained_vocab_file is not None:
                # Hack because, honestly this tokenizer was not made to be used
                # in a library like ours, at all.
                # NOTE(security): torch.load unpickles arbitrary objects -- only load vocab files you trust.
                vocab_dict = torch.load(pretrained_vocab_file)
                for key, value in vocab_dict.items():
                    if key not in self.__dict__:
                        self.__dict__[key] = value

            # Built exactly once, inside the try, so that failures are reported as ValueError.
            if vocab_file is not None:
                self.build_vocab()
        except Exception as e:
            raise ValueError(
                "Unable to parse file {}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizerFast, "
                "please note they are not compatible.".format(
                    pretrained_vocab_file if pretrained_vocab_file is not None else vocab_file
                )
            ) from e

    def _compile_space_around_punctuation_pattern(self):
        """Return a regex matching the empty position before a punctuation symbol that is followed by a non-space."""
        look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
        look_ahead_to_match_all_except_space = r"(?=[^\s])"
        return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)

    def count_file(self, path, verbose=False, add_eos=False):
        """
        Tokenize every line of ``path``, add the symbols to ``self.counter`` and return them.

        Returns:
            A list with one symbol list per line of the file.
        """
        if verbose:
            logger.info("counting file {} ...".format(path))
        assert os.path.exists(path)

        sents = []
        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.info("    line {}".format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """
        Add already-tokenized sentences to ``self.counter``.

        Args:
            sents: A list of sentences, each a list of tokenized symbols.
        """
        if verbose:
            logger.info("counting {} sents ...".format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                logger.info("    line {}".format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        """
        Rebuild ``idx2sym``/``sym2idx`` from a vocabulary text file.

        The first whitespace-separated field of each non-blank line is the symbol; ids follow file order.

        Raises:
            ValueError: If neither ``<UNK>`` nor ``<unk>`` is present in the file.
        """
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing empty line) instead of failing with IndexError.
                    continue
                self.add_symbol(fields[0])
        if "<UNK>" in self.sym2idx:
            self.unk_idx = self.sym2idx["<UNK>"]
        elif "<unk>" in self.sym2idx:
            self.unk_idx = self.sym2idx["<unk>"]
        else:
            raise ValueError("No <unkown> token in vocabulary")

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary and special tokens file to a directory.

        The whole instance ``__dict__`` is written with ``torch.save``.

        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary, or a full file path.

        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """

        logger.warning(
            "Please note you will not be able to load the save vocabulary in"
            " Rust-based TransfoXLTokenizerFast as they don't share the same structure."
        )

        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
        else:
            vocab_file = vocab_path
        torch.save(self.__dict__, vocab_file)
        return (vocab_file,)

    def build_vocab(self):
        """
        Build the vocabulary from ``self.vocab_file`` when set, otherwise from ``self.counter``.

        When building from counts, the ``special`` symbols come first. They are followed by counted symbols in
        decreasing frequency, stopping at ``max_size`` entries or at the first symbol below ``min_freq``.
        """
        if self.vocab_file:
            logger.info("building vocab from {}".format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            logger.info("final vocab size {}".format(len(self)))
        else:
            logger.info("building vocab with min_freq={}, max_size={}".format(self.min_freq, self.max_size))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq:
                    break
                self.add_symbol(sym)

            logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
        """
        Tokenize and encode every line of ``path``.

        Returns:
            A list of one LongTensor per line, or a single concatenated LongTensor when ``ordered`` is True.
        """
        if verbose:
            logger.info("encoding file {} ...".format(path))
        assert os.path.exists(path)
        encoded = []
        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.info("    line {}".format(idx))
                symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        """Encode already-tokenized sentences; concatenated into one LongTensor when ``ordered`` is True."""
        if verbose:
            logger.info("encoding {} sents ...".format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                logger.info("    line {}".format(idx))
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        """Append ``sym`` if new, and expose its id as ``self.<name>_idx`` (e.g. ``<eos>`` -> ``eos_idx``)."""
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym])

    def add_symbol(self, sym):
        """Append ``sym`` to the vocabulary if it is not already present."""
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def move_added_token(self, token: str, target_idx: int):
        """
        Moves an added token to a specific position in the vocab.
        This method should be used when resizing an embedding layer other than the last one in the `AdaptiveEmbedding`
        in order to move the token in the tokenizer from the default position (at the very end) to the desired one.

        Args:
            token: The token to move to a specific position in the vocab.
            target_idx: The position where the token should be moved to.
        """
        assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token"
        assert token not in self.idx2sym, "Token which should be moved is already in vocab"

        # Insert sym into vocab
        self.idx2sym.insert(target_idx, token)
        self.sym2idx[token] = target_idx

        # Shift following indices in sym2idx
        for idx in range(target_idx + 1, len(self.idx2sym)):
            current_sym = self.idx2sym[idx]
            self.sym2idx[current_sym] = idx

        # Delete token from added_tokens
        old_index = self.added_tokens_encoder[token]
        del self.added_tokens_decoder[old_index]
        del self.added_tokens_encoder[token]

    def _convert_id_to_token(self, idx):
        """Converts an id to its word-level symbol using the vocab."""
        assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
        return self.idx2sym[idx]

    def _convert_token_to_id(self, sym):
        """
        Converts a token (str) to an id using the vocab; unknown symbols map to the unknown-token id.

        Raises:
            ValueError: If ``sym`` is unknown and the vocabulary has no unknown token.
        """
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        if hasattr(self, "unk_idx"):
            return self.unk_idx
        # Backward compatibility with pre-trained models that do not store unk_idx
        if "<unk>" in self.sym2idx:
            return self.sym2idx["<unk>"]
        if "<UNK>" in self.sym2idx:
            return self.sym2idx["<UNK>"]
        raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).strip()
        return out_string

    def convert_to_tensor(self, symbols):
        """Convert a list of symbols into a LongTensor of ids."""
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    @property
    def vocab_size(self):
        """Number of symbols in the base vocabulary (tokens added later are not counted)."""
        return len(self.idx2sym)

    def get_vocab(self):
        """Return the base vocabulary merged with the added tokens."""
        return dict(self.sym2idx, **self.added_tokens_encoder)

    def _tokenize(self, line, add_eos=False, add_double_eos=False):
        """
        Split one line into a list of symbols.

        Args:
            line: Raw text; surrounding whitespace is stripped and it is lowercased when ``lower_case`` is set.
            add_eos: Append ``"<eos>"``.
            add_double_eos: Wrap with ``"<S>"`` on both sides (lm1b convention); takes precedence over ``add_eos``.

        Returns:
            A list of symbols (always a list, also for character-level splitting).
        """
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        if self.delimiter == "":
            # Character-level split. Must be a list: a bare str breaks the list concatenations below, and
            # convert_tokens_to_ids would treat a str as one single token.
            symbols = list(line)
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos:  # lm1b
            return ["<S>"] + symbols + ["<S>"]
        elif add_eos:
            return symbols + ["<eos>"]
        else:
            return symbols

    def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
        """
        Optionally insert spaces before punctuation; otherwise warn when punctuation is glued to a word.

        The ``add_space_before_punct_symbol`` kwarg is consumed here; the remaining kwargs are returned.
        """
        # add spaces before punctuation symbols as should be done in transfo-xl
        add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False)
        if add_space_before_punct_symbol:
            text = self.punctuation_with_space_around_pattern.sub(r" ", text)
        elif self.punction_without_space_before_pattern.search(text):
            # searches until the first occurence of a punctuation symbol without surrounding spaces
            logger.warning(
                "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to the `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
            )

        return (text, kwargs)
370

371

372
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
    """
    Word-level ``tokenizers`` pipeline backing :class:`TransfoXLTokenizerFast`.

    The pipeline works in four steps:
    - Normalize the text: optional unicode normalization, optional lowercasing, then stripping.
    - Split it on ``delimiter``, or on whitespace when no delimiter is given.
    - Look each piece up in a ``WordLevel`` vocabulary, falling back to ``unk_token``.
    - With ``add_double_eos``, install a ``BertProcessing`` post-processor that uses ``eos_token`` for both
      special positions.

    ``add_eos`` is only recorded in the parameters dict; no post-processor is installed for it here.
    """

    def __init__(
        self,
        vocab_file,
        delimiter,
        lowercase,
        unk_token,
        eos_token,
        add_eos=False,
        add_double_eos=False,
        normalization: Optional[str] = None,
    ):
        """
        Args:
            vocab_file: Vocabulary file readable by ``tokenizers.models.WordLevel``.
            delimiter: Single-character split delimiter; falsy means split on whitespace.
            lowercase: Add a ``Lowercase`` normalizer.
            unk_token: Token used for out-of-vocabulary words.
            eos_token: End-of-sequence token (used by the ``add_double_eos`` post-processor).
            add_eos: Recorded in the parameters only.
            add_double_eos: Install the eos post-processor.
            normalization: Optional unicode normalization name for ``unicode_normalizer_from_str``.

        Raises:
            ValueError: If ``vocab_file`` cannot be loaded as a WordLevel vocabulary.
        """
        try:
            tokenizer = Tokenizer(WordLevel(vocab_file, unk_token=unk_token))
        except Exception as e:
            raise ValueError(
                "Unable to parse file {}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizer, "
                "please note they are not compatible.".format(vocab_file)
            ) from e

        # Normalization pipeline: unicode normalization -> lowercasing -> stripping.
        normalizers = []
        if normalization:
            normalizers.append(unicode_normalizer_from_str(normalization))
        if lowercase:
            normalizers.append(Lowercase())
        normalizers.append(Strip(left=True, right=True))
        # Strip is always present, so the list is never empty; a lone normalizer is not wrapped in a Sequence.
        tokenizer.normalizer = Sequence(normalizers) if len(normalizers) > 1 else normalizers[0]

        # Setup the splitter
        tokenizer.pre_tokenizer = CharDelimiterSplit(delimiter) if delimiter else WhitespaceSplit()

        if add_double_eos:
            eos = (eos_token, tokenizer.token_to_id(eos_token))
            tokenizer.post_processor = BertProcessing(eos, eos)

        parameters = {
            "model": "TransfoXLModel",
            "add_eos": add_eos,
            "add_double_eos": add_double_eos,
            "unk_token": unk_token,
            "eos_token": eos_token,
            "delimiter": delimiter,
            "lowercase": lowercase,
        }

        super().__init__(tokenizer, parameters)
431

432

433
class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "Fast" Transformer-XL tokenizer (backed by HuggingFace's `tokenizers` library).

    The Transformer-XL tokenizer is a word-level tokenizer (no sub-word tokenization).

    Adapted from Vocab class in https://github.com/kimiyoung/transformer-xl

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES_FAST
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = []

    def __init__(
        self,
        special=None,
        min_freq=0,
        max_size=None,
        lower_case=False,
        delimiter=None,
        vocab_file=None,
        pretrained_vocab_file=None,
        never_split=None,
        unk_token="<unk>",
        eos_token="<eos>",
        additional_special_tokens=("<formula>",),
        add_eos=False,
        add_double_eos=False,
        normalization=None,
        **kwargs
    ):
        """
        Args:
            special, min_freq, max_size, never_split: Accepted for signature compatibility with the slow
                tokenizer; not used by this class.
            lower_case: Lowercase the input (``Lowercase`` normalizer).
            delimiter: Split delimiter; whitespace splitting when falsy.
            vocab_file / pretrained_vocab_file: WordLevel vocabulary file (``vocab_file`` takes precedence).
            unk_token, eos_token: Unknown and end-of-sequence tokens.
            additional_special_tokens: Extra special tokens. A private list copy is passed to the base class so
                instances never share one mutable list.
            add_eos, add_double_eos, normalization: Forwarded to the backend pipeline.
        """
        if additional_special_tokens is not None:
            # Fresh list per instance: a shared mutable default would leak edits between tokenizers.
            additional_special_tokens = list(additional_special_tokens)

        super().__init__(
            _TransfoXLDelimiterLookupTokenizer(
                vocab_file=vocab_file or pretrained_vocab_file,
                delimiter=delimiter,
                lowercase=lower_case,
                unk_token=unk_token,
                eos_token=eos_token,
                add_eos=add_eos,
                add_double_eos=add_double_eos,
                normalization=normalization,
            ),
            unk_token=unk_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    def save_pretrained(self, save_directory):
        """Save through the base class after warning that the output is not loadable by the slow tokenizer."""
        logger.warning(
            "Please note you will not be able to load the vocabulary in"
            " Python-based TransfoXLTokenizer as they don't share the same structure."
        )

        return super().save_pretrained(save_directory)
493

494

495
class LMOrderedIterator(object):
    """
    Cut one strictly ordered token stream into ``bsz`` parallel streams and iterate over them in ``bptt`` steps.

    Every batch is a ``(inputs, targets, seq_len)`` triple. ``inputs`` and ``targets`` are laid out as
    ``[bsz x len]``, and ``targets`` is ``inputs`` shifted one step ahead.
    """

    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
        """
        Args:
            data: LongTensor holding the corpus in strict order.
            bsz: Number of parallel streams.
            bptt: Default number of steps per batch.
            device: Where the stream matrix and the batches are placed.
            ext_len: Extra past steps prepended to each input window (0 when ``None``).
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = 0 if ext_len is None else ext_len
        self.device = device

        # Steps per stream; the trailing tokens that do not fill a whole column are discarded.
        self.n_step = data.size(0) // bsz
        usable = data.narrow(0, 0, self.n_step * bsz)

        # Stream s gets the s-th contiguous chunk; row t of self.data holds step t of every stream.
        streams = usable.view(bsz, -1)
        self.data = streams.t().contiguous().to(device)

        # Mini-batches needed to cover n_step steps (ceiling division).
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        """Return the ``(inputs, targets, seq_len)`` window starting at step ``i``."""
        window = self.bptt if bptt is None else bptt
        seq_len = min(window, self.data.size(0) - 1 - i)

        first = max(0, i - self.ext_len)
        last = i + seq_len
        inputs = self.data[first:last]
        targets = self.data[i + 1 : i + 1 + seq_len]

        return (
            inputs.transpose(0, 1).contiguous().to(self.device),
            targets.transpose(0, 1).contiguous().to(self.device),
            seq_len,
        )

    def get_fixlen_iter(self, start=0):
        """Yield consecutive batches of ``self.bptt`` steps beginning at step ``start``."""
        stop = self.data.size(0) - 1
        for pos in range(start, stop, self.bptt):
            yield self.get_batch(pos)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        """
        Yield batches whose lengths are drawn around ``self.bptt`` (halved with probability 0.05).

        Lengths are normally distributed with standard deviation ``std``, clipped to
        ``[min_len, bptt + max_deviation * std]``.
        """
        max_len = self.bptt + max_deviation * std
        stop = self.data.size(0) - 2
        pos = start
        while True:
            mean = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
            length = min(max_len, max(min_len, int(np.random.normal(mean, std))))
            batch = self.get_batch(pos, length)
            pos += batch[2]
            yield batch
            if pos >= stop:
                break

    def __iter__(self):
        return self.get_fixlen_iter()
552

553

554
class LMShuffledIterator(object):
    """
    Pack an unordered collection of sentences into fixed-size language-modelling batches.

    Each of the ``bsz`` rows is an independent stream that concatenates sentences taken from the sentence stream.
    Targets are the inputs shifted by one token. Every yielded batch is a ``(data, target, bptt)`` triple of freshly
    allocated tensors.
    """

    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        """
        Args:
            data: list[LongTensor]; there is no order among the LongTensors.
            bsz: Number of parallel streams per batch.
            bptt: Number of target steps per batch.
            device: Device the yielded batches are placed on.
            ext_len: Number of trailing input rows carried over into the next batch (0 when ``None``).
            shuffle: Visit the sentences in random order.
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        """Yield the sentences of ``self.data``, in random order when ``self.shuffle`` is set."""
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    def stream_iterator(self, sent_stream):
        """
        Yield ``(data, target, bptt)`` batches built from ``sent_stream`` until it is exhausted.

        A batch is only yielded when every row could be filled completely; the partial final batch is dropped.
        """
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new]
                        target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    valid_batch = False
                    break

            if not valid_batch:
                return

            # `data` and `target` are buffers reused across iterations. With bsz == 1 the transpose is already
            # contiguous, so .contiguous() returns a view of the buffer; copy=True guarantees independent outputs
            # that later iterations cannot overwrite.
            data_out = data.transpose(0, 1).contiguous().to(self.device, copy=True)
            target_out = target.transpose(0, 1).contiguous().to(self.device, copy=True)

            yield data_out, target_out, self.bptt

            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                # Source and destination slices overlap whenever n_retain exceeds half of data's length;
                # clone the source so the in-place copy is well defined.
                data[:n_retain] = data[-n_retain:].clone()
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch
629

630

631
class LMMultiFileIterator(LMShuffledIterator):
    """
    Stream language-modelling batches from several text files, one file at a time.

    Each file is encoded with ``vocab.encode_file(path, add_double_eos=True)``. Its sentences are packed into batches
    by the inherited ``stream_iterator``.
    """

    def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        """
        Args:
            paths: Text files to read. The caller's sequence is never reordered.
            vocab: Object providing ``encode_file`` (e.g. a TransfoXLTokenizer).
            bsz, bptt, device, ext_len: See ``LMShuffledIterator``.
            shuffle: Shuffle the file order and the sentences within each file.
        """
        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        """Return an iterator over the encoded sentences of ``path`` (shuffled when ``self.shuffle`` is set)."""
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        return iter(sents)

    def __iter__(self):
        # Work on a private copy: shuffling self.paths in place would silently reorder the caller's list
        # (e.g. TransfoXLCorpus.train) on every epoch.
        paths = list(self.paths)
        if self.shuffle:
            np.random.shuffle(paths)

        for path in paths:
            # sent_stream is an iterator
            yield from self.stream_iterator(self.get_sent_stream(path))
661

662

663
class TransfoXLCorpus(object):
    """
    Vocabulary plus encoded train/valid/test splits for the language-modelling datasets used with Transformer-XL.
    """

    # Datasets whose splits are single, strictly ordered token streams (iterated with LMOrderedIterator).
    _ORDERED_DATASETS = ("ptb", "wt2", "wt103", "enwik8", "text8")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.

        Args:
            pretrained_model_name_or_path: A key of PRETRAINED_CORPUS_ARCHIVE_MAP, or a directory containing
                CORPUS_NAME.
            cache_dir: Optional download cache, used for both the vocabulary and the corpus file.
            *inputs, **kwargs: Forwarded to ``TransfoXLTokenizer.from_pretrained`` and to the corpus constructor.

        Returns:
            The loaded corpus, or ``None`` when the corpus file cannot be resolved.
        """
        vocab = TransfoXLTokenizer.from_pretrained(
            pretrained_model_name_or_path, *inputs, cache_dir=cache_dir, **kwargs
        )
        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Corpus '{}' was not found in corpus list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ", ".join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    corpus_file,
                )
            )
            return None
        if resolved_corpus_file == corpus_file:
            logger.info("loading corpus file {}".format(corpus_file))
        else:
            logger.info("loading corpus file {} from cache at {}".format(corpus_file, resolved_corpus_file))

        # Instantiate tokenizer.
        corpus = cls(*inputs, **kwargs)
        # NOTE(security): torch.load unpickles arbitrary objects -- only load corpus files from trusted sources.
        corpus_dict = torch.load(resolved_corpus_file)
        for key, value in corpus_dict.items():
            corpus.__dict__[key] = value
        corpus.vocab = vocab
        # lm1b splits are a list of file paths (train) and lists of per-sentence tensors (valid/test); they cannot
        # be packed into one tensor, so only the other datasets are converted.
        if corpus.dataset != "lm1b":
            for split in ("train", "valid", "test"):
                split_data = getattr(corpus, split)
                if split_data is not None:
                    setattr(corpus, split, torch.tensor(split_data, dtype=torch.long))
        return corpus

    def __init__(self, *args, **kwargs):
        """Create an empty corpus; all arguments are forwarded to ``TransfoXLTokenizer``."""
        self.vocab = TransfoXLTokenizer(*args, **kwargs)
        self.dataset = None
        self.train = None
        self.valid = None
        self.test = None

    def build_corpus(self, path, dataset):
        """
        Count symbols, build the vocabulary and encode the splits of ``dataset`` found under ``path``.

        ptb/wt2/wt103/enwik8/text8 splits become single concatenated LongTensors. For lm1b, ``train`` is the list of
        training-shard paths and valid/test are lists of per-sentence tensors.
        """
        self.dataset = dataset

        if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
            self.vocab.count_file(os.path.join(path, "train.txt"))
            self.vocab.count_file(os.path.join(path, "valid.txt"))
            self.vocab.count_file(os.path.join(path, "test.txt"))
        elif self.dataset == "wt103":
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path,
                "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled",
                "news.en-*",
            )
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)
        elif self.dataset in ["enwik8", "text8"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False)
        elif self.dataset == "lm1b":
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        """
        Return a batch iterator over ``split`` ("train", "valid" or "test").

        Extra arguments are forwarded to the iterator constructor. lm1b training data is always shuffled.

        Raises:
            ValueError: For an unknown split or a dataset without an iterator.
        """
        if split == "train":
            data = self.train
        elif split in ["valid", "test"]:
            data = self.valid if split == "valid" else self.test
        else:
            raise ValueError("Unknown split {!r}; expected 'train', 'valid' or 'test'.".format(split))

        if self.dataset in self._ORDERED_DATASETS:
            return LMOrderedIterator(data, *args, **kwargs)
        if self.dataset == "lm1b":
            if split == "train":
                kwargs["shuffle"] = True
                return LMMultiFileIterator(data, self.vocab, *args, **kwargs)
            return LMShuffledIterator(data, *args, **kwargs)
        raise ValueError("Unknown dataset {!r}; cannot build an iterator.".format(self.dataset))
764

765

766
def get_lm_corpus(datadir, dataset):
    """
    Load the cached corpus for ``dataset`` from ``datadir``, or build it and cache it.

    The lookup order is ``cache.pt`` (torch format), then ``cache.pkl`` (pickle). If neither exists, the corpus is
    built from the raw files with ``TransfoXLCorpus.build_corpus`` and saved to ``cache.pt``.

    Args:
        datadir: Directory holding the raw dataset files and the cache files.
        dataset: One of "wt103", "wt2", "ptb", "lm1b", "enwik8", "text8".

    Returns:
        The loaded or newly built corpus.
    """
    fn = os.path.join(datadir, "cache.pt")
    fn_pickle = os.path.join(datadir, "cache.pkl")
    # NOTE(security): both torch.load and pickle.load execute code from the file -- only use trusted caches.
    if os.path.exists(fn):
        logger.info("Loading cached dataset...")
        corpus = torch.load(fn)
    elif os.path.exists(fn_pickle):
        logger.info("Loading cached dataset from pickle...")
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        logger.info("Producing dataset {}...".format(dataset))
        kwargs = {}
        if dataset in ["wt103", "wt2"]:
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = False
        elif dataset == "ptb":
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = True
        elif dataset == "lm1b":
            kwargs["special"] = []
            kwargs["lower_case"] = False
            kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
        elif dataset in ["enwik8", "text8"]:
            pass

        # TransfoXLCorpus.__init__ forwards its arguments to the tokenizer, so the data directory and dataset name
        # must go to build_corpus rather than to the constructor.
        corpus = TransfoXLCorpus(**kwargs)
        corpus.build_corpus(datadir, dataset)
        torch.save(corpus, fn)

    return corpus
796

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.