colossalai
from typing import Any, Callable, Optional

import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer

from .base import Actor

try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )

def _prepare_logits_processor(
    top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
    """Build the chain of logits warpers applied before sampling.

    A warper is appended only when its argument is set to a value that
    actually changes the distribution.
    """
    processor_list = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processor_list.append(TopKLogitsWarper(top_k))
    if top_p is not None and top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list

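
# Illustrative check (not part of the original file): with all three options
# set, the warpers run in append order, i.e. temperature scaling first, then
# top-k, then top-p filtering.
#
#   processors = _prepare_logits_processor(top_k=50, top_p=0.9, temperature=0.7)
#   assert len(processors) == 3
#   # warped_logits = processors(input_ids, next_token_logits)
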
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    """Return True when every sequence has finished, on every DP rank."""
    if dist.is_initialized() and dist.get_world_size() > 1:
        # consider data parallelism: sum the flags across ranks so that all
        # ranks stop at the same step
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0

def _sample(
    model: Actor,
    input_ids: torch.Tensor,
    max_length: int,
    early_stopping: bool = False,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor:
    if input_ids.size(1) >= max_length:
        return input_ids

    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
    # 1 marks a sequence that is still generating; 0 marks a finished one
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    for _ in range(input_ids.size(1), max_length):
        model_inputs = (
            prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
        )
        outputs = model(**model_inputs)

        # NOTE: this is correct only in left padding mode
        next_token_logits = outputs["logits"][:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sentences should have their next token be a padding token
        if eos_token_id is not None:
            assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids, model inputs for next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # if eos_token was found in one sentence, set sentence to finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when each sentence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break

    return input_ids

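# Worked example of the padding mask above (illustrative values, not from the
# original file): with unfinished_sequences = [1, 0], next_tokens = [42, 77]
# and pad_token_id = 0,
#   next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
# evaluates to [42, 0], so the finished second sequence keeps emitting padding.
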
@torch.no_grad()
def generate(
    model: Actor,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor:
    """Generate a token sequence. The returned sequence is input_ids + generated_tokens.

    Args:
        model (Actor): the actor model used for generation
        input_ids (torch.Tensor): input sequence (left-padded)
        tokenizer (PreTrainedTokenizer): tokenizer providing the eos/pad token ids; must use left padding
        max_length (int): maximum length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to sample. Defaults to True.
        early_stopping (bool, optional): if True, the sequence length may be smaller than max_length because generation stops once every sequence has produced an eos token. Defaults to False.
        top_k (Optional[int], optional): the number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to None.
        top_p (Optional[float], optional): if set to a float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
        temperature (Optional[float], optional): the value used to modulate the next-token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): function to preprocess model inputs; its arguments should be input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): function to update model_kwargs based on outputs; its arguments should be outputs and model_kwargs. Defaults to None.
    """
    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
    is_greedy_gen_mode = (num_beams == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and do_sample is False
    if is_greedy_gen_mode:
        # run greedy search
        raise NotImplementedError
    elif is_sample_gen_mode:
        # run sample
        return _sample(
            model,
            input_ids,
            max_length,
            early_stopping=early_stopping,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            prepare_inputs_fn=prepare_inputs_fn,
            update_model_kwargs_fn=update_model_kwargs_fn,
            **model_kwargs,
        )
    elif is_beam_gen_mode:
        # run beam search
        raise NotImplementedError
    else:
        raise ValueError("Unsupported generation mode")
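

# A minimal usage sketch (illustrative, not part of the original file). It
# assumes `actor` is an Actor instance wrapping a causal LM whose forward
# returns a dict with a "logits" entry, as `_sample` expects; the checkpoint
# name and prompt are placeholders.
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   tokenizer.padding_side = "left"  # required by the assert in generate()
#   input_ids = tokenizer(["Hello, world"], return_tensors="pt").input_ids
#   sequences = generate(
#       actor, input_ids, tokenizer, max_length=64,
#       do_sample=True, top_k=50, top_p=0.9, temperature=0.7,
#       early_stopping=True,
#   )
#   print(tokenizer.batch_decode(sequences, skip_special_tokens=True))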