# coding=utf-8
# Copyright 2018 Google AI, Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization classes for Transformer XL model.
Adapted from https://github.com/kimiyoung/transformer-xl.
"""


import glob
import logging
import os
import pickle
import re
from collections import Counter, OrderedDict
from typing import Optional

import numpy as np
from tokenizers import Tokenizer
from tokenizers.implementations import BaseTokenizer
from tokenizers.models import WordLevel
from tokenizers.normalizers import Lowercase, Sequence, Strip, unicode_normalizer_from_str
from tokenizers.pre_tokenizers import CharDelimiterSplit, WhitespaceSplit
from tokenizers.processors import BertProcessing

from .file_utils import cached_path, is_torch_available
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_utils_fast import PreTrainedTokenizerFast


if is_torch_available():
    import torch


logger = logging.getLogger(__name__)

VOCAB_FILES_NAMES = {"pretrained_vocab_file": "vocab.bin", "vocab_file": "vocab.txt"}
VOCAB_FILES_NAMES_FAST = {"pretrained_vocab_file": "vocab.json", "vocab_file": "vocab.json"}

PRETRAINED_VOCAB_FILES_MAP = {
    "pretrained_vocab_file": {
        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
    }
}

PRETRAINED_VOCAB_FILES_MAP_FAST = {
    "pretrained_vocab_file": {
        "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.json",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "transfo-xl-wt103": None,
}

PRETRAINED_CORPUS_ARCHIVE_MAP = {
    "transfo-xl-wt103": "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
}
CORPUS_NAME = "corpus.bin"


class TransfoXLTokenizer(PreTrainedTokenizer):
    """
    Transformer-XL tokenizer adapted from the Vocab class in https://github.com/kimiyoung/transformer-xl

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
    should refer to the superclass for more information regarding methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = []

    def __init__(
        self,
        special=None,
        min_freq=0,
        max_size=None,
        lower_case=False,
        delimiter=None,
        vocab_file=None,
        pretrained_vocab_file=None,
        never_split=None,
        unk_token="<unk>",
        eos_token="<eos>",
        additional_special_tokens=["<formula>"],
        **kwargs
    ):
        super().__init__(
            unk_token=unk_token, eos_token=eos_token, additional_special_tokens=additional_special_tokens, **kwargs
        )

        if never_split is None:
            never_split = self.all_special_tokens
        if special is None:
            special = []
        self.counter = Counter()
        self.special = special
        self.min_freq = min_freq
        self.max_size = max_size
        self.lower_case = lower_case
        self.delimiter = delimiter
        self.vocab_file = vocab_file
        self.never_split = never_split
        self.punctuation_symbols = '!"#$%&()*+,-./\\:;<=>?@[\\]^_`{|}~'
        self.punctuation_without_space_before_pattern = re.compile(r"[^\s][{}]".format(self.punctuation_symbols))
        self.punctuation_with_space_around_pattern = self._compile_space_around_punctuation_pattern()

        try:
            if pretrained_vocab_file is not None:
                # Hack because, honestly, this tokenizer was not made to be used
                # in a library like ours, at all.
                vocab_dict = torch.load(pretrained_vocab_file)
                for key, value in vocab_dict.items():
                    if key not in self.__dict__:
                        self.__dict__[key] = value

            if vocab_file is not None:
                self.build_vocab()
        except Exception:
            raise ValueError(
                "Unable to parse file {}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizerFast, "
                "please note they are not compatible.".format(pretrained_vocab_file)
            )

    def _compile_space_around_punctuation_pattern(self):
        look_ahead_for_special_token = "(?=[{}])".format(self.punctuation_symbols)
        look_ahead_to_match_all_except_space = r"(?=[^\s])"
        return re.compile(r"" + look_ahead_for_special_token + look_ahead_to_match_all_except_space)

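    # A small illustration (not from the original source): the compiled pattern
    # is zero-width, so substituting it with " " inserts a space before every
    # punctuation symbol. Under these assumptions:
    #
    #     tok = TransfoXLTokenizer()
    #     tok.punctuation_with_space_around_pattern.sub(" ", "Hello, world!")
    #     # -> "Hello , world !"
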
    def count_file(self, path, verbose=False, add_eos=False):
        if verbose:
            logger.info("counting file {} ...".format(path))
        assert os.path.exists(path)

        sents = []
        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.info("    line {}".format(idx))
                symbols = self.tokenize(line, add_eos=add_eos)
                self.counter.update(symbols)
                sents.append(symbols)

        return sents

    def count_sents(self, sents, verbose=False):
        """
        sents : a list of sentences, each a list of tokenized symbols
        """
        if verbose:
            logger.info("counting {} sents ...".format(len(sents)))
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                logger.info("    line {}".format(idx))
            self.counter.update(symbols)

    def _build_from_file(self, vocab_file):
        self.idx2sym = []
        self.sym2idx = OrderedDict()

        with open(vocab_file, "r", encoding="utf-8") as f:
            for line in f:
                symb = line.strip().split()[0]
                self.add_symbol(symb)
        if "<UNK>" in self.sym2idx:
            self.unk_idx = self.sym2idx["<UNK>"]
        elif "<unk>" in self.sym2idx:
            self.unk_idx = self.sym2idx["<unk>"]
        else:
            raise ValueError("No <unk> token in vocabulary")

    def save_vocabulary(self, vocab_path):
        """
        Save the vocabulary and special tokens file to a directory.

        Args:
            vocab_path (:obj:`str`):
                The directory in which to save the vocabulary.

        Returns:
            :obj:`Tuple(str)`: Paths to the files saved.
        """

        logger.warning(
            "Please note you will not be able to load the saved vocabulary in"
            " Rust-based TransfoXLTokenizerFast as they don't share the same structure."
        )

        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
        else:
            vocab_file = vocab_path
        torch.save(self.__dict__, vocab_file)
        return (vocab_file,)

    def build_vocab(self):
        if self.vocab_file:
            logger.info("building vocab from {}".format(self.vocab_file))
            self._build_from_file(self.vocab_file)
            logger.info("final vocab size {}".format(len(self)))
        else:
            logger.info("building vocab with min_freq={}, max_size={}".format(self.min_freq, self.max_size))
            self.idx2sym = []
            self.sym2idx = OrderedDict()

            for sym in self.special:
                self.add_special(sym)

            for sym, cnt in self.counter.most_common(self.max_size):
                if cnt < self.min_freq:
                    break
                self.add_symbol(sym)

            logger.info("final vocab size {} from {} unique tokens".format(len(self), len(self.counter)))

    def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False):
        if verbose:
            logger.info("encoding file {} ...".format(path))
        assert os.path.exists(path)
        encoded = []
        with open(path, "r", encoding="utf-8") as f:
            for idx, line in enumerate(f):
                if verbose and idx > 0 and idx % 500000 == 0:
                    logger.info("    line {}".format(idx))
                symbols = self.tokenize(line, add_eos=add_eos, add_double_eos=add_double_eos)
                encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def encode_sents(self, sents, ordered=False, verbose=False):
        if verbose:
            logger.info("encoding {} sents ...".format(len(sents)))
        encoded = []
        for idx, symbols in enumerate(sents):
            if verbose and idx > 0 and idx % 500000 == 0:
                logger.info("    line {}".format(idx))
            encoded.append(self.convert_to_tensor(symbols))

        if ordered:
            encoded = torch.cat(encoded)

        return encoded

    def add_special(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1
            setattr(self, "{}_idx".format(sym.strip("<>")), self.sym2idx[sym])

    def add_symbol(self, sym):
        if sym not in self.sym2idx:
            self.idx2sym.append(sym)
            self.sym2idx[sym] = len(self.idx2sym) - 1

    def move_added_token(self, token: str, target_idx: int):
        """
        Moves an added token to a specific position in the vocab.
        This method should be used when resizing an embedding layer other than the last one in the `AdaptiveEmbedding`
        in order to move the token in the tokenizer from the default position (at the very end) to the desired one.

        Args:
            token: The token to move to a specific position in the vocab.
            target_idx: The position where the token should be moved to.
        """
        assert token in self.added_tokens_encoder, "Token which should be moved has to be an added token"
        assert token not in self.idx2sym, "Token which should be moved is already in vocab"

        # Insert sym into vocab
        self.idx2sym.insert(target_idx, token)
        self.sym2idx[token] = target_idx

        # Shift following indices in sym2idx
        for idx in range(target_idx + 1, len(self.idx2sym)):
            current_sym = self.idx2sym[idx]
            self.sym2idx[current_sym] = idx

        # Delete token from added_tokens
        old_index = self.added_tokens_encoder[token]
        del self.added_tokens_decoder[old_index]
        del self.added_tokens_encoder[token]

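    # Hypothetical illustration (not in the original source): if "<new>" was
    # previously added via `add_tokens` and therefore sits at the very end of
    # the vocab, then
    #
    #     tokenizer.move_added_token("<new>", 1)
    #
    # re-inserts it at index 1 of the core vocab, shifts every later symbol's
    # id up by one, and removes it from the added-tokens maps, keeping
    # `idx2sym` and `sym2idx` consistent.
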
    def _convert_id_to_token(self, idx):
        """Converts an id to a token (str) using the vocab."""
        assert 0 <= idx < len(self), "Index {} out of vocabulary range".format(idx)
        return self.idx2sym[idx]

    def _convert_token_to_id(self, sym):
        """Converts a token (str) to an id using the vocab."""
        if sym in self.sym2idx:
            return self.sym2idx[sym]
        else:
            # logger.info('encounter unk {}'.format(sym))
            # assert '<eos>' not in sym
            if hasattr(self, "unk_idx"):
                return self.sym2idx.get(sym, self.unk_idx)
            # Backward compatibility with pre-trained models
            elif "<unk>" in self.sym2idx:
                return self.sym2idx["<unk>"]
            elif "<UNK>" in self.sym2idx:
                return self.sym2idx["<UNK>"]
            else:
                raise ValueError("Token not in vocabulary and no <unk> token in vocabulary for replacement")

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (strings) into a single string."""
        out_string = " ".join(tokens).strip()
        return out_string

    def convert_to_tensor(self, symbols):
        return torch.LongTensor(self.convert_tokens_to_ids(symbols))

    @property
    def vocab_size(self):
        return len(self.idx2sym)

    def get_vocab(self):
        return dict(self.sym2idx, **self.added_tokens_encoder)

    def _tokenize(self, line, add_eos=False, add_double_eos=False):
        line = line.strip()
        # convert to lower case
        if self.lower_case:
            line = line.lower()

        # empty delimiter '' will evaluate False
        if self.delimiter == "":
            symbols = line
        else:
            symbols = line.split(self.delimiter)

        if add_double_eos:  # lm1b
            return ["<S>"] + symbols + ["<S>"]
        elif add_eos:
            return symbols + ["<eos>"]
        else:
            return symbols

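    # A quick sketch of the expected behavior (illustrative, not a doctest):
    # with the default whitespace delimiter,
    #
    #     tokenizer._tokenize("Hello world", add_eos=True)
    #     # -> ["Hello", "world", "<eos>"]
    #
    # while add_double_eos=True wraps the sentence in "<S>" markers instead,
    # as used for the lm1b corpus.
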
    def prepare_for_tokenization(self, text, is_pretokenized=False, **kwargs):
        # add spaces before punctuation symbols, as should be done in transfo-xl
        add_space_before_punct_symbol = kwargs.pop("add_space_before_punct_symbol", False)
        if add_space_before_punct_symbol:
            text = self.punctuation_with_space_around_pattern.sub(r" ", text)
        elif self.punctuation_without_space_before_pattern.search(text):
            # searches until the first occurrence of a punctuation symbol without surrounding spaces
            logger.warning(
                "You might want to consider setting `add_space_before_punct_symbol=True` as an argument to"
                " `tokenizer.encode()` to avoid tokenizing words with punctuation symbols to the `<unk>` token"
            )

        return (text, kwargs)

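# A minimal usage sketch (illustrative; assumes the public "transfo-xl-wt103"
# checkpoint is downloadable, which is not verified here):
#
#     tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
#     ids = tokenizer.encode("Hello, how are you?", add_space_before_punct_symbol=True)
#     tokenizer.convert_ids_to_tokens(ids)
#     # -> something like ["Hello", ",", "how", "are", "you", "?"]

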
class _TransfoXLDelimiterLookupTokenizer(BaseTokenizer):
    def __init__(
        self,
        vocab_file,
        delimiter,
        lowercase,
        unk_token,
        eos_token,
        add_eos=False,
        add_double_eos=False,
        normalization: Optional[str] = None,
    ):

        try:
            tokenizer = WordLevel(vocab_file, unk_token=unk_token)
            tokenizer = Tokenizer(tokenizer)
        except Exception:
            raise ValueError(
                "Unable to parse file {}. Unknown format. "
                "If you tried to load a model saved through TransfoXLTokenizer, "
                "please note they are not compatible.".format(vocab_file)
            )

        # Create the correct normalization path
        normalizer = []

        # Include unicode normalization
        if normalization:
            normalizer += [unicode_normalizer_from_str(normalization)]

        # Include case normalization
        if lowercase:
            normalizer += [Lowercase()]

        # Strip normalizer at the end
        normalizer += [Strip(left=True, right=True)]

        if len(normalizer) > 0:
            tokenizer.normalizer = Sequence(normalizer) if len(normalizer) > 1 else normalizer[0]

        # Setup the splitter
        tokenizer.pre_tokenizer = CharDelimiterSplit(delimiter) if delimiter else WhitespaceSplit()

        if add_double_eos:
            tokenizer.post_processor = BertProcessing(
                (eos_token, tokenizer.token_to_id(eos_token)), (eos_token, tokenizer.token_to_id(eos_token))
            )

        parameters = {
            "model": "TransfoXLModel",
            "add_eos": add_eos,
            "add_double_eos": add_double_eos,
            "unk_token": unk_token,
            "eos_token": eos_token,
            "delimiter": delimiter,
            "lowercase": lowercase,
        }

        super().__init__(tokenizer, parameters)


class TransfoXLTokenizerFast(PreTrainedTokenizerFast):
    """
    Construct a "Fast" Transformer-XL tokenizer (backed by HuggingFace's `tokenizers` library).

    The Transformer-XL tokenizer is a word-level tokenizer (no sub-word tokenization).

    Adapted from the Vocab class in https://github.com/kimiyoung/transformer-xl

    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizerFast` which contains most of the methods.
    Users should refer to the superclass for more information regarding methods.
    """

    vocab_files_names = VOCAB_FILES_NAMES_FAST
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP_FAST
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    model_input_names = []

    def __init__(
        self,
        special=None,
        min_freq=0,
        max_size=None,
        lower_case=False,
        delimiter=None,
        vocab_file=None,
        pretrained_vocab_file=None,
        never_split=None,
        unk_token="<unk>",
        eos_token="<eos>",
        additional_special_tokens=["<formula>"],
        add_eos=False,
        add_double_eos=False,
        normalization=None,
        **kwargs
    ):

        super().__init__(
            _TransfoXLDelimiterLookupTokenizer(
                vocab_file=vocab_file or pretrained_vocab_file,
                delimiter=delimiter,
                lowercase=lower_case,
                unk_token=unk_token,
                eos_token=eos_token,
                add_eos=add_eos,
                add_double_eos=add_double_eos,
                normalization=normalization,
            ),
            unk_token=unk_token,
            eos_token=eos_token,
            additional_special_tokens=additional_special_tokens,
            **kwargs,
        )

    def save_pretrained(self, save_directory):
        logger.warning(
            "Please note you will not be able to load the vocabulary in"
            " Python-based TransfoXLTokenizer as they don't share the same structure."
        )

        return super().save_pretrained(save_directory)


class LMOrderedIterator(object):
    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None):
        """
        data -- LongTensor -- the LongTensor is strictly ordered
        """
        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device

        # Work out how cleanly we can divide the dataset into bsz parts.
        self.n_step = data.size(0) // bsz

        # Trim off any extra elements that wouldn't cleanly fit (remainders).
        data = data.narrow(0, 0, self.n_step * bsz)

        # Evenly divide the data across the bsz batches.
        self.data = data.view(bsz, -1).t().contiguous().to(device)

        # Number of mini-batches
        self.n_batch = (self.n_step + self.bptt - 1) // self.bptt

    def get_batch(self, i, bptt=None):
        if bptt is None:
            bptt = self.bptt
        seq_len = min(bptt, self.data.size(0) - 1 - i)

        end_idx = i + seq_len
        beg_idx = max(0, i - self.ext_len)

        data = self.data[beg_idx:end_idx]
        target = self.data[i + 1 : i + 1 + seq_len]

        data_out = data.transpose(0, 1).contiguous().to(self.device)
        target_out = target.transpose(0, 1).contiguous().to(self.device)

        return data_out, target_out, seq_len

    def get_fixlen_iter(self, start=0):
        for i in range(start, self.data.size(0) - 1, self.bptt):
            yield self.get_batch(i)

    def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3):
        max_len = self.bptt + max_deviation * std
        i = start
        while True:
            bptt = self.bptt if np.random.random() < 0.95 else self.bptt / 2.0
            bptt = min(max_len, max(min_len, int(np.random.normal(bptt, std))))
            data, target, seq_len = self.get_batch(i, bptt)
            i += seq_len
            yield data, target, seq_len
            if i >= self.data.size(0) - 2:
                break

    def __iter__(self):
        return self.get_fixlen_iter()

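# Shape sketch for LMOrderedIterator (hypothetical numbers, for orientation
# only): given a 1-D LongTensor of 1000 token ids with bsz=4 and bptt=32, the
# constructor keeps 1000 // 4 * 4 = 1000 tokens and stores self.data with
# shape [250, 4]; get_batch(0) then returns data and target of shape [4, 32],
# where target is the same stream shifted one token ahead.

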
class LMShuffledIterator(object):
    def __init__(self, data, bsz, bptt, device="cpu", ext_len=None, shuffle=False):
        """
        data -- list[LongTensor] -- there is no order among the LongTensors
        """
        self.data = data

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self):
        # index iterator
        epoch_indices = np.random.permutation(len(self.data)) if self.shuffle else np.array(range(len(self.data)))

        # sentence iterator
        for idx in epoch_indices:
            yield self.data[idx]

    def stream_iterator(self, sent_stream):
        # streams for each data in the batch
        streams = [None] * self.bsz

        data = torch.LongTensor(self.bptt, self.bsz)
        target = torch.LongTensor(self.bptt, self.bsz)

        n_retain = 0

        while True:
            # data   : [n_retain+bptt x bsz]
            # target : [bptt x bsz]
            data[n_retain:].fill_(-1)
            target.fill_(-1)

            valid_batch = True

            for i in range(self.bsz):
                n_filled = 0
                try:
                    while n_filled < self.bptt:
                        if streams[i] is None or len(streams[i]) <= 1:
                            streams[i] = next(sent_stream)
                        # number of new tokens to fill in
                        n_new = min(len(streams[i]) - 1, self.bptt - n_filled)
                        # first n_retain tokens are retained from last batch
                        data[n_retain + n_filled : n_retain + n_filled + n_new, i] = streams[i][:n_new]
                        target[n_filled : n_filled + n_new, i] = streams[i][1 : n_new + 1]
                        streams[i] = streams[i][n_new:]
                        n_filled += n_new
                except StopIteration:
                    valid_batch = False
                    break

            if not valid_batch:
                return

            data_out = data.transpose(0, 1).contiguous().to(self.device)
            target_out = target.transpose(0, 1).contiguous().to(self.device)

            yield data_out, target_out, self.bptt

            n_retain = min(data.size(0), self.ext_len)
            if n_retain > 0:
                data[:n_retain] = data[-n_retain:]
            data.resize_(n_retain + self.bptt, data.size(1))

    def __iter__(self):
        # sent_stream is an iterator
        sent_stream = self.get_sent_stream()

        for batch in self.stream_iterator(sent_stream):
            yield batch


class LMMultiFileIterator(LMShuffledIterator):
    def __init__(self, paths, vocab, bsz, bptt, device="cpu", ext_len=None, shuffle=False):

        self.paths = paths
        self.vocab = vocab

        self.bsz = bsz
        self.bptt = bptt
        self.ext_len = ext_len if ext_len is not None else 0

        self.device = device
        self.shuffle = shuffle

    def get_sent_stream(self, path):
        sents = self.vocab.encode_file(path, add_double_eos=True)
        if self.shuffle:
            np.random.shuffle(sents)
        sent_stream = iter(sents)

        return sent_stream

    def __iter__(self):
        if self.shuffle:
            np.random.shuffle(self.paths)

        for path in self.paths:
            # sent_stream is an iterator
            sent_stream = self.get_sent_stream(path)
            for batch in self.stream_iterator(sent_stream):
                yield batch


class TransfoXLCorpus(object):
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        """
        Instantiate a pre-processed corpus.
        """
        vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
        if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
            corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
        # redirect to the cache, if necessary
        try:
            resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error(
                "Corpus '{}' was not found in corpus list ({}). "
                "We assumed '{}' was a path or url but couldn't find files {} "
                "at this path or url.".format(
                    pretrained_model_name_or_path,
                    ", ".join(PRETRAINED_CORPUS_ARCHIVE_MAP.keys()),
                    pretrained_model_name_or_path,
                    corpus_file,
                )
            )
            return None
        if resolved_corpus_file == corpus_file:
            logger.info("loading corpus file {}".format(corpus_file))
        else:
            logger.info("loading corpus file {} from cache at {}".format(corpus_file, resolved_corpus_file))

        # Instantiate the corpus and load the pre-processed fields.
        corpus = cls(*inputs, **kwargs)
        corpus_dict = torch.load(resolved_corpus_file)
        for key, value in corpus_dict.items():
            corpus.__dict__[key] = value
        corpus.vocab = vocab
        if corpus.train is not None:
            corpus.train = torch.tensor(corpus.train, dtype=torch.long)
        if corpus.valid is not None:
            corpus.valid = torch.tensor(corpus.valid, dtype=torch.long)
        if corpus.test is not None:
            corpus.test = torch.tensor(corpus.test, dtype=torch.long)
        return corpus

    def __init__(self, *args, **kwargs):
        self.vocab = TransfoXLTokenizer(*args, **kwargs)
        self.dataset = None
        self.train = None
        self.valid = None
        self.test = None

    def build_corpus(self, path, dataset):
        self.dataset = dataset

        if self.dataset in ["ptb", "wt2", "enwik8", "text8"]:
            self.vocab.count_file(os.path.join(path, "train.txt"))
            self.vocab.count_file(os.path.join(path, "valid.txt"))
            self.vocab.count_file(os.path.join(path, "test.txt"))
        elif self.dataset == "wt103":
            self.vocab.count_file(os.path.join(path, "train.txt"))
        elif self.dataset == "lm1b":
            train_path_pattern = os.path.join(
                path,
                "1-billion-word-language-modeling-benchmark-r13output",
                "training-monolingual.tokenized.shuffled",
                "news.en-*",
            )
            train_paths = glob.glob(train_path_pattern)
            # the vocab will load from file when build_vocab() is called

        self.vocab.build_vocab()

        if self.dataset in ["ptb", "wt2", "wt103"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True)
        elif self.dataset in ["enwik8", "text8"]:
            self.train = self.vocab.encode_file(os.path.join(path, "train.txt"), ordered=True, add_eos=False)
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=True, add_eos=False)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=True, add_eos=False)
        elif self.dataset == "lm1b":
            self.train = train_paths
            self.valid = self.vocab.encode_file(os.path.join(path, "valid.txt"), ordered=False, add_double_eos=True)
            self.test = self.vocab.encode_file(os.path.join(path, "test.txt"), ordered=False, add_double_eos=True)

    def get_iterator(self, split, *args, **kwargs):
        if split == "train":
            if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
                data_iter = LMOrderedIterator(self.train, *args, **kwargs)
            elif self.dataset == "lm1b":
                kwargs["shuffle"] = True
                data_iter = LMMultiFileIterator(self.train, self.vocab, *args, **kwargs)
        elif split in ["valid", "test"]:
            data = self.valid if split == "valid" else self.test
            if self.dataset in ["ptb", "wt2", "wt103", "enwik8", "text8"]:
                data_iter = LMOrderedIterator(data, *args, **kwargs)
            elif self.dataset == "lm1b":
                data_iter = LMShuffledIterator(data, *args, **kwargs)

        return data_iter

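# Hedged usage sketch (dataset name and batch dimensions are illustrative, and
# the pre-processed corpus download is assumed to be available):
#
#     corpus = TransfoXLCorpus.from_pretrained("transfo-xl-wt103")
#     train_iter = corpus.get_iterator("train", 32, 128, device="cpu")
#     for data, target, seq_len in train_iter:
#         pass  # data/target are [bsz, seq_len] LongTensors

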
def get_lm_corpus(datadir, dataset):
    fn = os.path.join(datadir, "cache.pt")
    fn_pickle = os.path.join(datadir, "cache.pkl")
    if os.path.exists(fn):
        logger.info("Loading cached dataset...")
        corpus = torch.load(fn)
    elif os.path.exists(fn_pickle):
        logger.info("Loading cached dataset from pickle...")
        with open(fn_pickle, "rb") as fp:
            corpus = pickle.load(fp)
    else:
        logger.info("Producing dataset {}...".format(dataset))
        kwargs = {}
        if dataset in ["wt103", "wt2"]:
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = False
        elif dataset == "ptb":
            kwargs["special"] = ["<eos>"]
            kwargs["lower_case"] = True
        elif dataset == "lm1b":
            kwargs["special"] = []
            kwargs["lower_case"] = False
            kwargs["vocab_file"] = os.path.join(datadir, "1b_word_vocab.txt")
        elif dataset in ["enwik8", "text8"]:
            pass

        # Build the corpus from the raw files before caching it.
        corpus = TransfoXLCorpus(**kwargs)
        corpus.build_corpus(datadir, dataset)
        torch.save(corpus, fn)

    return corpus

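
# Illustrative local build (paths are assumptions; expects WikiText-103 style
# train.txt/valid.txt/test.txt files under the data directory):
#
#     corpus = get_lm_corpus("./data/wikitext-103", "wt103")
#     va_iter = corpus.get_iterator("valid", 10, 64)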