llama-index

Форк
0
95 строк · 2.5 Кб
1
import time
2
import uuid
3
from typing import Any, Dict, Optional
4

5
import numpy as np
6

7

8
def parse_input(
9
    input_text: str, tokenizer: Any, end_id: int, remove_input_padding: bool
10
) -> Any:
11
    try:
12
        import torch
13
    except ImportError:
14
        raise ImportError("nvidia_tensorrt requires `pip install torch`.")
15

16
    input_tokens = []
17

18
    input_tokens.append(tokenizer.encode(input_text, add_special_tokens=False))
19

20
    input_lengths = torch.tensor(
21
        [len(x) for x in input_tokens], dtype=torch.int32, device="cuda"
22
    )
23
    if remove_input_padding:
24
        input_ids = np.concatenate(input_tokens)
25
        input_ids = torch.tensor(input_ids, dtype=torch.int32, device="cuda").unsqueeze(
26
            0
27
        )
28
    else:
29
        input_ids = torch.nested.to_padded_tensor(
30
            torch.nested.nested_tensor(input_tokens, dtype=torch.int32), end_id
31
        ).cuda()
32

33
    return input_ids, input_lengths
34

35

36
def remove_extra_eos_ids(outputs: Any) -> Any:
37
    outputs.reverse()
38
    while outputs and outputs[0] == 2:
39
        outputs.pop(0)
40
    outputs.reverse()
41
    outputs.append(2)
42
    return outputs
43

44

45
def get_output(
46
    output_ids: Any,
47
    input_lengths: Any,
48
    max_output_len: int,
49
    tokenizer: Any,
50
) -> Any:
51
    num_beams = output_ids.size(1)
52
    output_text = ""
53
    outputs = None
54
    for b in range(input_lengths.size(0)):
55
        for beam in range(num_beams):
56
            output_begin = input_lengths[b]
57
            output_end = input_lengths[b] + max_output_len
58
            outputs = output_ids[b][beam][output_begin:output_end].tolist()
59
            outputs = remove_extra_eos_ids(outputs)
60
            output_text = tokenizer.decode(outputs)
61

62
    return output_text, outputs
63

64

65
def generate_completion_dict(
66
    text_str: str, model: Any, model_path: Optional[str]
67
) -> Dict:
68
    """
69
    Generate a dictionary for text completion details.
70

71
    Returns:
72
    dict: A dictionary containing completion details.
73
    """
74
    completion_id: str = f"cmpl-{uuid.uuid4()!s}"
75
    created: int = int(time.time())
76
    model_name: str = model if model is not None else model_path
77
    return {
78
        "id": completion_id,
79
        "object": "text_completion",
80
        "created": created,
81
        "model": model_name,
82
        "choices": [
83
            {
84
                "text": text_str,
85
                "index": 0,
86
                "logprobs": None,
87
                "finish_reason": "stop",
88
            }
89
        ],
90
        "usage": {
91
            "prompt_tokens": None,
92
            "completion_tokens": None,
93
            "total_tokens": None,
94
        },
95
    }
96

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.