lmops
153 lines · 6.8 KB
1import torch2import tqdm3import numpy as np4
5from contextlib import nullcontext6from torch.nn import CrossEntropyLoss7from torch.utils.data import DataLoader8from typing import List9from datasets import Dataset10from transformers import AutoTokenizer, AutoModelForCausalLM11from transformers.models.gpt2 import GPT2LMHeadModel, GPT2TokenizerFast12from transformers.generation.utils import GreedySearchDecoderOnlyOutput13from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions14
15from utils import move_to_device16from logger_config import logger17from config import Arguments18from llms.base_llm import BaseLLM19from collators.gpt2_collator import ScoreCollator, DecodeCollator20
21
class GPT2(BaseLLM):
    """Causal-LM wrapper (GPT-2 family) exposing batch scoring and greedy decoding.

    Scoring returns the average per-token log-likelihood of each output text
    conditioned on its input text; decoding greedily generates a continuation
    for each input, optionally constrained by a prefix trie, stopping at a
    newline token.
    """

    def __init__(self, args: Arguments, model_name_or_path: str = 'gpt2-xl', **kwargs):
        super().__init__(model_name_or_path, **kwargs)
        self.args = args
        self.batch_size_per_device = args.llm_batch_size_per_device

        self.tokenizer: GPT2TokenizerFast = AutoTokenizer.from_pretrained(model_name_or_path)
        # GPT-2 ships without a pad token: alias it to EOS, and truncate from
        # the left so the most recent context is the part that survives.
        self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
        self.tokenizer.truncation_side = 'left'

        dtype = torch.float16 if args.fp16 else torch.float32
        self.model: GPT2LMHeadModel = AutoModelForCausalLM.from_pretrained(model_name_or_path, torch_dtype=dtype)
        self.model.config.pad_token_id = self.model.config.eos_token_id
        self.model.eval()

    def _new_loader(self, dataset: Dataset, collate_fn) -> DataLoader:
        """Build a deterministic (unshuffled) DataLoader over ``dataset``."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size_per_device,
            shuffle=False,
            num_workers=2,
            collate_fn=collate_fn,
            pin_memory=True,
        )

    def _autocast_ctx(self):
        """CUDA autocast when running fp16, otherwise a no-op context."""
        return torch.cuda.amp.autocast() if self.args.fp16 else nullcontext()

    def _strip_token_type_ids(self, batch: dict) -> None:
        """Hack: llama tokenizers emit token_type_ids the model cannot consume."""
        if 'llama' in self.model_name_or_path and 'token_type_ids' in batch:
            del batch['token_type_ids']

    @torch.no_grad()
    def batch_score(
            self, input_texts: List[str], output_texts: List[str],
            delimiter: str = '\n', **kwargs
    ) -> List[float]:
        """Return the length-normalized log-likelihood of each output given its input.

        :param input_texts: conditioning prompts, one per example
        :param output_texts: continuations to be scored (full answer strings,
            not single-letter option labels)
        :param delimiter: string joining input and output inside the collator
        :return: one average per-token log-probability per example
        """
        assert len(input_texts) == len(output_texts), '{} != {}'.format(len(input_texts), len(output_texts))
        assert not all(output in ['A', 'B', 'C', 'D'] for output in output_texts), 'output_texts should not be letters'

        dataset = Dataset.from_dict({
            'input_texts': input_texts,
            'output_texts': output_texts
        })
        loader = self._new_loader(dataset, ScoreCollator(
            tokenizer=self.tokenizer,
            max_length=self.args.llm_max_input_length,
            pad_to_multiple_of=8,
            delimiter=delimiter,
        ))

        scores: List[float] = []
        progress = tqdm.tqdm(loader, desc='batch score', mininterval=10, disable=len(dataset) < 1024)
        for batch in progress:
            self._strip_token_type_ids(batch)
            batch = move_to_device(batch, device=self.model.device)
            with self._autocast_ctx():
                outputs: CausalLMOutputWithCrossAttentions = self.model(
                    **batch, return_dict=True, use_cache=False
                )

            labels = batch['labels']
            # Teacher forcing: the logits at position t predict the token at t + 1.
            shifted_logits = outputs.logits[..., :-1, :].contiguous()
            shifted_labels = labels[..., 1:].contiguous()
            token_losses = CrossEntropyLoss(reduction='none')(
                shifted_logits.view(-1, shifted_logits.size(-1)), shifted_labels.view(-1)
            )
            sequence_losses = token_losses.view(batch['input_ids'].size(0), -1).sum(dim=1)
            # Ignored (-100) positions contribute zero loss, so summing then
            # dividing by the valid-label count yields the mean over scored
            # tokens. NOTE(review): the count is taken over the UNshifted
            # labels — correct only if the collator always masks position 0
            # with -100; confirm against ScoreCollator.
            valid_counts = torch.sum(labels != -100, dim=1).float()
            scores += (-sequence_losses / valid_counts).cpu().tolist()

            logger.debug('num_valid_labels: {}, loss: {}, per_token_loss: {}, avg_per_token_loss: {}'.format(
                valid_counts, outputs.loss, token_losses,
                token_losses.sum() / torch.sum(labels != -100).float())
            )

        return scores

    def batch_decode(self, input_texts: List[str], prefix_trie=None, **kwargs) -> List[str]:
        """Greedy-decode one continuation per input, stopping at a newline token.

        :param input_texts: prompts to continue
        :param prefix_trie: optional trie constraining generation; only token
            sequences present in the trie may be produced
        :return: decoded continuations with special tokens stripped
        """
        dataset: Dataset = Dataset.from_dict({'input_texts': input_texts})
        loader = self._new_loader(dataset, DecodeCollator(
            tokenizer=self.tokenizer,
            max_length=self.args.llm_max_input_length,
            pad_to_multiple_of=8
        ))

        # The newline token doubles as the stop symbol; it is also suppressed
        # at the first step so every example emits at least one real token.
        eos_token_id: int = self.tokenizer.encode('\n')[-1]
        decoded_texts: List[str] = []
        for batch in tqdm.tqdm(loader, mininterval=10, desc='batch decode'):
            self._strip_token_type_ids(batch)
            batch = move_to_device(batch, device=self.model.device)
            prompt_len: int = batch['input_ids'].shape[1]

            def _allowed_tokens(_, generated_ids):
                # Restrict the next token to continuations the trie contains
                # for the tokens generated so far (prompt part excluded).
                return prefix_trie.get(generated_ids.tolist()[prompt_len:])

            with self._autocast_ctx():
                outputs: GreedySearchDecoderOnlyOutput = self.model.generate(
                    **batch,
                    num_beams=1,
                    do_sample=False,
                    max_new_tokens=self.args.llm_max_decode_length,
                    begin_suppress_tokens=[eos_token_id],
                    eos_token_id=eos_token_id,
                    prefix_allowed_tokens_fn=_allowed_tokens if prefix_trie else None,
                    return_dict_in_generate=True,
                    output_scores=False,
                )
            new_token_ids = outputs.sequences[:, prompt_len:]
            logger.debug('generated_token_ids: {}'.format(new_token_ids.tolist()))

            # Debug-only path: with output_scores=False above, outputs.scores
            # is None and this block is skipped.
            if outputs.scores is not None:
                transition_scores = self.model.compute_transition_scores(
                    outputs.sequences, outputs.scores, normalize_logits=True
                )
                for tok, score in zip(new_token_ids[0].cpu(), transition_scores[0].cpu()):
                    if tok in self.tokenizer.all_special_ids:
                        continue
                    # | token | token string | logits | probability
                    logger.info(f"| {tok:5d} | {self.tokenizer.decode(tok):8s} | {score.numpy():.4f} "
                                f"| {np.exp(score.numpy()):.2%}")

            decoded_texts += self.tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)

        return decoded_texts