from typing import Any, Callable, Optional

import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer

from .base import Actor

# Older transformers releases expose the logits warpers in
# `generation_logits_process`; newer releases moved them into
# the `transformers.generation` subpackage.
try:
    from transformers.generation_logits_process import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )
except ImportError:
    from transformers.generation import (
        LogitsProcessorList,
        TemperatureLogitsWarper,
        TopKLogitsWarper,
        TopPLogitsWarper,
    )


def _prepare_logits_processor(
    top_k: Optional[int] = None, top_p: Optional[float] = None, temperature: Optional[float] = None
) -> LogitsProcessorList:
    # Chain the warpers in the order temperature -> top-k -> top-p; each is
    # added only when its argument is set to a non-trivial value.
    processor_list = LogitsProcessorList()
    if temperature is not None and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if top_k is not None and top_k != 0:
        processor_list.append(TopKLogitsWarper(top_k))
    if top_p is not None and top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    return processor_list


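# Illustrative example (not part of the original file): the returned list is
# itself callable and maps raw next-token logits to filtered logits. Shapes
# below are assumptions following the Hugging Face convention:
#
#   processors = _prepare_logits_processor(top_k=50, top_p=0.9, temperature=0.7)
#   warped = processors(input_ids, next_token_logits)
#
# where `input_ids` is (batch_size, seq_len) and `next_token_logits` is
# (batch_size, vocab_size); tokens outside the top-k/top-p set get -inf.

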
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
    # Under data parallelism, generation is finished only when every rank's
    # sequences are finished, so sum the flags across the process group.
    if dist.is_initialized() and dist.get_world_size() > 1:
        unfinished_sequences = unfinished_sequences.clone()
        dist.all_reduce(unfinished_sequences)
    return unfinished_sequences.max() == 0


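# Illustrative example (single process, not part of the original file). Each
# entry is a 0/1 flag per sequence; after the all-reduce a flag stays nonzero
# if the sequence is unfinished on any rank:
#
#   _is_sequence_finished(torch.tensor([1, 0, 1]))  # falsy: two still running
#   _is_sequence_finished(torch.tensor([0, 0, 0]))  # truthy: all reached EOS

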
def _sample(
    model: Actor,
    input_ids: torch.Tensor,
    max_length: int,
    early_stopping: bool = False,
    eos_token_id: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor:
    """Sample tokens autoregressively until max_length is reached or,
    if early_stopping is set, until every sequence has produced eos."""
    if input_ids.size(1) >= max_length:
        return input_ids

    logits_processor = _prepare_logits_processor(top_k, top_p, temperature)
    # 1 marks a sequence that is still generating; 0 marks a finished one
    unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

    for _ in range(input_ids.size(1), max_length):
        model_inputs = (
            prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {"input_ids": input_ids}
        )
        outputs = model(**model_inputs)

        # NOTE: taking the last position is correct only with left padding,
        # where the newest token is always at the end of the sequence
        next_token_logits = outputs["logits"][:, -1, :]
        next_token_logits = logits_processor(input_ids, next_token_logits)
        # sample the next token from the warped distribution
        probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
        next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

        # finished sequences should have their next token be a padding token
        if eos_token_id is not None:
            assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
            next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

        # update generated ids and model inputs for the next step
        input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
        if update_model_kwargs_fn is not None:
            model_kwargs = update_model_kwargs_fn(outputs, model_kwargs)

        # if an eos_token was sampled in a sequence, mark that sequence finished
        if eos_token_id is not None:
            unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())

        # stop when every sequence is finished if early_stopping=True
        if early_stopping and _is_sequence_finished(unfinished_sequences):
            break

    return input_ids


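# Illustrative hook sketch (not part of the original file): for a causal LM
# with KV caching, the two callbacks might look like the following. The keys
# `past_key_values` and the dict-style output access follow the Hugging Face
# convention and are assumptions about the wrapped model:
#
#   def prepare_inputs_fn(input_ids, **model_kwargs):
#       if model_kwargs.get("past_key_values") is not None:
#           input_ids = input_ids[:, -1:]  # with a cache, only the newest token is needed
#       return {"input_ids": input_ids, **model_kwargs}
#
#   def update_model_kwargs_fn(outputs, model_kwargs):
#       model_kwargs["past_key_values"] = outputs.get("past_key_values")
#       return model_kwargs

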
@torch.no_grad()
def generate(
    model: Actor,
    input_ids: torch.Tensor,
    tokenizer: PreTrainedTokenizer,
    max_length: int,
    num_beams: int = 1,
    do_sample: bool = True,
    early_stopping: bool = False,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
    update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
    **model_kwargs,
) -> torch.Tensor:
    """Generate a token sequence. The returned sequence is input_ids + generated_tokens.

    Args:
        model (Actor): model used for generation
        input_ids (torch.Tensor): input sequence
        tokenizer (PreTrainedTokenizer): tokenizer; must use left padding
        max_length (int): max length of the returned sequence
        num_beams (int, optional): number of beams. Defaults to 1.
        do_sample (bool, optional): whether to use sampling. Defaults to True.
        early_stopping (bool, optional): if True, the returned sequence may be shorter than max_length because every sequence reached eos. Defaults to False.
        top_k (Optional[int], optional): number of highest-probability vocabulary tokens to keep for top-k filtering. Defaults to None.
        top_p (Optional[float], optional): if set to a float < 1, only the smallest set of most probable tokens whose probabilities add up to top_p or higher is kept for generation. Defaults to None.
        temperature (Optional[float], optional): value used to modulate the next-token probabilities. Defaults to None.
        prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): function to preprocess model inputs, taking input_ids and model_kwargs. Defaults to None.
        update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): function to update model_kwargs from outputs, taking outputs and model_kwargs. Defaults to None.
    """
    assert tokenizer.padding_side == "left", "Current generation only supports left padding."
    is_greedy_gen_mode = (num_beams == 1) and do_sample is False
    is_sample_gen_mode = (num_beams == 1) and do_sample is True
    is_beam_gen_mode = (num_beams > 1) and do_sample is False
    if is_greedy_gen_mode:
        # greedy search is not implemented yet
        raise NotImplementedError
    elif is_sample_gen_mode:
        # run multinomial sampling
        return _sample(
            model,
            input_ids,
            max_length,
            early_stopping=early_stopping,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            top_k=top_k,
            top_p=top_p,
            temperature=temperature,
            prepare_inputs_fn=prepare_inputs_fn,
            update_model_kwargs_fn=update_model_kwargs_fn,
            **model_kwargs,
        )
    elif is_beam_gen_mode:
        # beam search is not implemented yet
        raise NotImplementedError
    else:
        raise ValueError("Unsupported generation mode")
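

# Illustrative usage (not part of the original file; `actor` stands for a
# hypothetical Actor wrapping a causal LM, e.g. built from GPT-2 weights).
# The tokenizer must be configured for left padding, as asserted above:
#
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("gpt2")
#   tokenizer.pad_token = tokenizer.eos_token
#   tokenizer.padding_side = "left"
#   input_ids = tokenizer(["Hello, my name is"], return_tensors="pt")["input_ids"]
#   sequences = generate(
#       actor,
#       input_ids,
#       tokenizer,
#       max_length=64,
#       do_sample=True,
#       top_k=50,
#       top_p=0.9,
#       temperature=0.7,
#       early_stopping=True,
#   )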