llm-app

from llm_app.model_wrappers.base import BaseModel


class HFPipelineTask(BaseModel):
    def __init__(self, model, device="cpu", **kwargs):
        """
        A wrapper class for Hugging Face's `Pipeline` class.

        The `pipeline` function from Hugging Face is a utility factory method that creates
        a Pipeline to handle different tasks.
        It supports tasks like text classification, translation, summarization, and many more.

        This wrapper class simplifies the process of initializing the pipeline and allows the user
        to easily change the underlying model used for computations.

        Parameters:
        -----------
        model : str, required
            The model identifier from Hugging Face's model hub.
        device : str, default='cpu'
            The device where the computations will be performed.
            Supports 'cpu' or 'cuda'. Default is 'cpu'.
        **kwargs : optional
            Additional arguments from HF.
            Please check out https://huggingface.co/docs/transformers/main/main_classes/pipelines
            for more information on the models and available arguments.

        Attributes:
        -----------
        pipeline : transformers.Pipeline
            The Hugging Face pipeline object.
        tokenizer : transformers.PreTrainedTokenizer
            The tokenizer associated with the pipeline.

        Example:
        --------
        >>> pipe = HFPipelineTask('gpt2')
        >>> result = pipe('Hello world')
        """
        from transformers import pipeline

        super().__init__(**kwargs)
        self.pipeline = pipeline(model=model, device=device)
        self.tokenizer = self.pipeline.tokenizer

    def crop_to_max_length(self, input_string, max_length=500):
        # Truncate the input to at most `max_length` tokens so it fits the model's context window.
        tokens = self.tokenizer.tokenize(input_string)
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
        return self.tokenizer.convert_tokens_to_string(tokens)
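
# Hedged usage sketch (not part of the original module; the 'gpt2' checkpoint
# and the exact truncated string are illustrative assumptions):
#
# >>> task = HFPipelineTask("gpt2")
# >>> task.crop_to_max_length("one two three four five", max_length=3)
# 'one two three'  # exact output depends on the tokenizer's vocabulary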


class HFFeatureExtractionTask(HFPipelineTask):
    def __init__(self, model, device="cpu", max_length=500, **kwargs):
        super().__init__(model, device=device, **kwargs)
        self.max_length = max_length

    def __call__(self, text, **kwargs):
        """
        This method computes feature embeddings for the given text.
        Hugging Face feature-extraction models return one embedding per token.
        To get the embedding vector of a text, we simply take the average.

        Args:
            text (str): The text for which we compute the embedding.
            **kwargs: Additional arguments to be passed to the pipeline.

        Returns:
            List[float]: The averaged feature embedding computed by the model.
        """

        text = self.crop_to_max_length(text, max_length=self.max_length)
        # The pipeline returns a list of lists: one embedding vector per token.
        embedding = self.pipeline(text, **kwargs)[0]

        # For simplicity, we average the token vectors to get a single sentence embedding.
        avg_embedding = [sum(col) / len(col) for col in zip(*embedding)]

        return avg_embedding
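
# Hedged sketch of using the averaged embedding, e.g. for cosine similarity
# (the sentence-transformers checkpoint named here is an assumption):
#
# >>> embedder = HFFeatureExtractionTask("sentence-transformers/all-MiniLM-L6-v2")
# >>> a, b = embedder("cats purr"), embedder("dogs bark")
# >>> dot = sum(x * y for x, y in zip(a, b))
# >>> norm = lambda v: sum(x * x for x in v) ** 0.5
# >>> dot / (norm(a) * norm(b))  # cosine similarity of the two texts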


class HFTextGenerationTask(HFPipelineTask):
    def __init__(
        self, model, device="cpu", max_prompt_length=500, max_new_tokens=500, **kwargs
    ):
        super().__init__(model, device=device, **kwargs)
        self.max_prompt_length = max_prompt_length
        self.max_new_tokens = max_new_tokens

    def __call__(self, text, **kwargs):
        """
        Run the model to complete the text.

        Args:
            text (str): prompt to complete.
            return_full_text (bool, optional, defaults to True):
                If True, returns the full text; if False, only the newly generated text is returned.
                Only significant if return_text is True.
            clean_up_tokenization_spaces (bool, optional, defaults to False):
                If True, removes extra spaces in the text output.
            prefix (str, optional): Prefix added to the prompt.
            handle_long_generation (str, optional): By default, long generation is not handled specially.
                Provides strategies to address this based on your use case:
                    None: does nothing special.
                    "hole": truncates the left side of the input, leaving a gap for generation.
                        This might truncate a lot of the prompt and is not suitable when
                        generation exceeds the model capacity.
            Other arguments from transformers.TextGenerationPipeline.__call__ are supported as well. Link:
            https://huggingface.co/docs/transformers/main/main_classes/pipelines#transformers.TextGenerationPipeline.__call__
        """
        text = self.crop_to_max_length(text, self.max_prompt_length)

        max_new_tokens = kwargs.pop("max_new_tokens", self.max_new_tokens)

        # Wrap the prompt in the model's chat format (requires a tokenizer with a chat template).
        messages = [
            {
                "role": "system",
                "content": "You are a helpful virtual assistant that only responds in English clearly and precisely.",
            },
            {"role": "user", "content": text},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        output = self.pipeline(prompt, max_new_tokens=max_new_tokens, **kwargs)
        return output[0]["generated_text"]
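
# A minimal end-to-end sketch, guarded so it only runs when the module is
# executed directly (the chat checkpoint named here is an assumption; any
# model whose tokenizer defines a chat template should work):
if __name__ == "__main__":
    generator = HFTextGenerationTask(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # assumed chat-capable checkpoint
        max_new_tokens=64,
    )
    print(generator("Briefly explain what a Hugging Face pipeline is."))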
