CSS-LM

Форк
0
/
language_modeling.py 
101 строка · 3.8 Кб
1
import logging
2
import os
3
import pickle
4
import time
5

6
import torch
7
from filelock import FileLock
8
from torch.utils.data.dataset import Dataset
9

10
from ...tokenization_utils import PreTrainedTokenizer
11

12

13
logger = logging.getLogger(__name__)
14

15

16
class TextDataset(Dataset):
17
    """
18
    This will be superseded by a framework-agnostic approach
19
    soon.
20
    """
21

22
    def __init__(
23
        self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int, overwrite_cache=False,
24
    ):
25
        assert os.path.isfile(file_path)
26

27
        block_size = block_size - tokenizer.num_special_tokens_to_add(pair=False)
28

29
        directory, filename = os.path.split(file_path)
30
        cached_features_file = os.path.join(
31
            directory, "cached_lm_{}_{}_{}".format(tokenizer.__class__.__name__, str(block_size), filename,),
32
        )
33

34
        # Make sure only the first process in distributed training processes the dataset,
35
        # and the others will use the cache.
36
        lock_path = cached_features_file + ".lock"
37
        with FileLock(lock_path):
38

39
            if os.path.exists(cached_features_file) and not overwrite_cache:
40
                start = time.time()
41
                with open(cached_features_file, "rb") as handle:
42
                    self.examples = pickle.load(handle)
43
                logger.info(
44
                    f"Loading features from cached file {cached_features_file} [took %.3f s]", time.time() - start
45
                )
46

47
            else:
48
                logger.info(f"Creating features from dataset file at {directory}")
49

50
                self.examples = []
51
                with open(file_path, encoding="utf-8") as f:
52
                    text = f.read()
53

54
                tokenized_text = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
55

56
                for i in range(0, len(tokenized_text) - block_size + 1, block_size):  # Truncate in block of block_size
57
                    self.examples.append(
58
                        tokenizer.build_inputs_with_special_tokens(tokenized_text[i : i + block_size])
59
                    )
60
                # Note that we are losing the last truncated example here for the sake of simplicity (no padding)
61
                # If your dataset is small, first you should loook for a bigger one :-) and second you
62
                # can change this behavior by adding (model specific) padding.
63

64
                start = time.time()
65
                with open(cached_features_file, "wb") as handle:
66
                    pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
67
                logger.info(
68
                    "Saving features into cached file %s [took %.3f s]", cached_features_file, time.time() - start
69
                )
70

71
    def __len__(self):
72
        return len(self.examples)
73

74
    def __getitem__(self, i) -> torch.Tensor:
75
        return torch.tensor(self.examples[i], dtype=torch.long)
76

77

78
class LineByLineTextDataset(Dataset):
79
    """
80
    This will be superseded by a framework-agnostic approach
81
    soon.
82
    """
83

84
    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
85
        assert os.path.isfile(file_path)
86
        # Here, we do not cache the features, operating under the assumption
87
        # that we will soon use fast multithreaded tokenizers from the
88
        # `tokenizers` repo everywhere =)
89
        logger.info("Creating features from dataset file at %s", file_path)
90

91
        with open(file_path, encoding="utf-8") as f:
92
            lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]
93

94
        batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)
95
        self.examples = batch_encoding["input_ids"]
96

97
    def __len__(self):
98
        return len(self.examples)
99

100
    def __getitem__(self, i) -> torch.Tensor:
101
        return torch.tensor(self.examples[i], dtype=torch.long)
102

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.