import torch
import tqdm
import numpy as np

from contextlib import nullcontext
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from typing import List
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2TokenizerFast
from transformers.generation.utils import GreedySearchDecoderOnlyOutput
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

from utils import move_to_device
from logger_config import logger
from config import Arguments
from llms.base_llm import BaseLLM
from collators.gpt2_collator import ScoreCollator, DecodeCollator


class GPT2(BaseLLM):
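    """Decoder-only GPT-2 wrapper built on BaseLLM.

    Exposes batch_score, which returns the length-normalized log-likelihood
    of each output text given its input text, and batch_decode, which
    greedily generates one line per input, optionally constrained by a
    token-prefix trie.
    """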

    def __init__(self, args: Arguments, model_name_or_path: str = 'gpt2-xl', **kwargs):
        super().__init__(model_name_or_path, **kwargs)
        self.args = args
        self.tokenizer: GPT2TokenizerFast = AutoTokenizer.from_pretrained(model_name_or_path)
        # GPT-2 has no dedicated pad token, so reuse EOS for padding, and
        # truncate from the left so the most recent context is kept.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.truncation_side = 'left'
        self.batch_size_per_device = args.llm_batch_size_per_device

        dtype = torch.float16 if args.fp16 else torch.float32
        self.model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype)
        self.model.config.pad_token_id = self.model.config.eos_token_id
        # Inference only: the model is never trained through this class.
        self.model.eval()

    @torch.no_grad()
    def batch_score(
            self, input_texts: List[str], output_texts: List[str],
            delimiter: str = '\n', **kwargs
    ) -> List[float]:
        assert len(input_texts) == len(output_texts), '{} != {}'.format(len(input_texts), len(output_texts))
        assert not all(output in ['A', 'B', 'C', 'D'] for output in output_texts), 'output_texts should not be letters'

        collator = ScoreCollator(
            tokenizer=self.tokenizer,
            max_length=self.args.llm_max_input_length,
            pad_to_multiple_of=8,
            delimiter=delimiter,
        )

        dataset = Dataset.from_dict({
            'input_texts': input_texts,
            'output_texts': output_texts
        })
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size_per_device,
            shuffle=False,
            num_workers=2,
            collate_fn=collator,
            pin_memory=True
        )

        avg_log_probs: List[float] = []
        for batch_dict in tqdm.tqdm(data_loader, desc='batch score', mininterval=10, disable=len(dataset) < 1024):
            # Hack: remove token_type_ids for llama models
            if 'llama' in self.model_name_or_path and 'token_type_ids' in batch_dict:
                del batch_dict['token_type_ids']

            batch_dict = move_to_device(batch_dict, device=self.model.device)
            with torch.cuda.amp.autocast() if self.args.fp16 else nullcontext():
                outputs: CausalLMOutputWithCrossAttentions = self.model(
                    **batch_dict, return_dict=True, use_cache=False
                )

                labels = batch_dict['labels']
                # Shift so that tokens < n predict n
                shift_logits = outputs.logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens; reduction='none' keeps a per-token loss so
                # each sequence can be normalized by its own label count.
                # Prompt positions are masked with -100 and contribute zero loss.
                loss_fct = CrossEntropyLoss(reduction='none')
                per_token_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
                per_sequence_loss = per_token_loss.view(batch_dict['input_ids'].size(0), -1).sum(dim=1)
                # Divide by the number of valid (non-masked) labels to obtain the
                # average log-probability per output token.
                num_valid_labels = torch.sum(labels != -100, dim=1).float()
                avg_log_probs += (-per_sequence_loss / num_valid_labels).cpu().tolist()

                logger.debug('num_valid_labels: {}, loss: {}, per_token_loss: {}, avg_per_token_loss: {}'.format(
                    num_valid_labels, outputs.loss, per_token_loss,
                    per_token_loss.sum() / torch.sum(labels != -100).float())
                )

        return avg_log_probs

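    # Illustration (a sketch, not part of the original file): batch_score
    # returns one length-normalized log-probability per (input, output) pair,
    # so ranking candidate answers for a single prompt reduces to an argmax:
    #
    #   scores = llm.batch_score(
    #       input_texts=['Q: What is the capital of France?\nA:'] * 2,
    #       output_texts=[' Paris', ' London'],
    #   )
    #   best = int(np.argmax(scores))  # expected to prefer ' Paris'
    #
    # Scores are negative; values closer to zero mean higher likelihood.
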
    def batch_decode(self, input_texts: List[str], prefix_trie=None, **kwargs) -> List[str]:
        collator = DecodeCollator(
            tokenizer=self.tokenizer,
            max_length=self.args.llm_max_input_length,
            pad_to_multiple_of=8
        )
        dataset: Dataset = Dataset.from_dict({'input_texts': input_texts})
        data_loader = DataLoader(
            dataset,
            batch_size=self.batch_size_per_device,
            shuffle=False,
            num_workers=2,
            collate_fn=collator,
            pin_memory=True
        )

        decoded_texts: List[str] = []
        # Treat the newline token as EOS so generation stops at the end of a line.
        eos_token_id: int = self.tokenizer.encode('\n')[-1]
        for batch_dict in tqdm.tqdm(data_loader, mininterval=10, desc='batch decode'):
            # Hack: remove token_type_ids for llama models
            if 'llama' in self.model_name_or_path and 'token_type_ids' in batch_dict:
                del batch_dict['token_type_ids']

            batch_dict = move_to_device(batch_dict, device=self.model.device)
            input_len: int = batch_dict['input_ids'].shape[1]

            def _prefix_allowed_tokens_fn(_, generated_ids):
                # Constrain decoding: only token ids that extend a valid path in
                # the prefix trie may follow the tokens generated so far.
                return prefix_trie.get(generated_ids.tolist()[input_len:])

            with torch.cuda.amp.autocast() if self.args.fp16 else nullcontext():
                outputs: GreedySearchDecoderOnlyOutput = self.model.generate(
                    **batch_dict,
                    num_beams=1,
                    do_sample=False,
                    max_new_tokens=self.args.llm_max_decode_length,
                    # Suppress the newline EOS as the very first token so every
                    # input yields a non-empty continuation.
                    begin_suppress_tokens=[eos_token_id],
                    eos_token_id=eos_token_id,
                    prefix_allowed_tokens_fn=_prefix_allowed_tokens_fn if prefix_trie else None,
                    return_dict_in_generate=True,
                    output_scores=False,
                )
                generated_token_ids = outputs.sequences[:, input_len:]
                logger.debug('generated_token_ids: {}'.format(generated_token_ids.tolist()))

                # Debug path: only taken if output_scores is switched on above.
                if outputs.scores is not None:
                    transition_scores = self.model.compute_transition_scores(
                        outputs.sequences, outputs.scores, normalize_logits=True
                    )
                    for tok, score in zip(generated_token_ids[0].cpu(), transition_scores[0].cpu()):
                        if tok in self.tokenizer.all_special_ids:
                            continue
                        # | token | token string | log-probability | probability
                        logger.info(f"| {tok:5d} | {self.tokenizer.decode(tok):8s} | {score.numpy():.4f} "
                                    f"| {np.exp(score.numpy()):.2%}")

            decoded_texts += self.tokenizer.batch_decode(generated_token_ids, skip_special_tokens=True)

        return decoded_texts

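
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file). batch_decode only
# assumes that `prefix_trie` exposes get(prefix_token_ids) -> list of token
# ids allowed next; the real trie implementation lives elsewhere in the
# repository. This minimal stand-in shows the expected interface.
class SimpleTokenPrefixTrie:

    def __init__(self, sequences: List[List[int]]):
        # Nested dicts keyed by token id; each allowed token sequence is one
        # root-to-leaf path.
        self.root = {}
        for seq in sequences:
            node = self.root
            for token_id in seq:
                node = node.setdefault(token_id, {})

    def get(self, prefix: List[int]) -> List[int]:
        # Walk the trie along the already-generated suffix and return the
        # token ids that may legally follow it (empty if the prefix has
        # left the trie).
        node = self.root
        for token_id in prefix:
            if token_id not in node:
                return []
            node = node[token_id]
        return list(node.keys())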
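
# A minimal usage sketch, assuming an Arguments object exposing the fields
# this class reads (fp16, llm_batch_size_per_device, llm_max_input_length,
# llm_max_decode_length). How Arguments is actually constructed is
# repository-specific; the keyword construction below is hypothetical.
if __name__ == '__main__':
    args = Arguments(
        fp16=False,
        llm_batch_size_per_device=4,
        llm_max_input_length=1024,
        llm_max_decode_length=64,
    )
    llm = GPT2(args=args, model_name_or_path='gpt2')  # small checkpoint for a smoke test
    print(llm.batch_score(
        input_texts=['The capital of France is'],
        output_texts=[' Paris'],
    ))
    print(llm.batch_decode(input_texts=['The capital of France is']))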