from llm_app.model_wrappers.base import BaseModel


class HFPipelineTask(BaseModel):
    def __init__(self, model, device="cpu", **kwargs):
        """
        A wrapper class for Hugging Face's `Pipeline` class.

        The `pipeline` function from Hugging Face is a utility factory method that creates
        a Pipeline to handle different tasks.
        It supports tasks like text classification, translation, summarization, and many more.

        This wrapper class simplifies the process of initializing the pipeline and allows the user
        to easily change the underlying model used for computations.

        Parameters:
        -----------
        model : str, required
            The model identifier from Hugging Face's model hub.
        device : str, default='cpu'
            The device where the computations will be performed.
            Supports 'cpu' or 'cuda' (e.g. 'cuda:0'). Default is 'cpu'.
        **kwargs : optional
            Additional arguments from HF.
            Please check out https://huggingface.co/docs/transformers/main/main_classes/pipelines
            for more information on the models and available arguments.

        Attributes:
        -----------
        pipeline : transformers.Pipeline
            The Hugging Face pipeline object.
        tokenizer : transformers.PreTrainedTokenizer
            The tokenizer associated with the pipeline.

        Example:
        --------
        >>> pipe = HFPipelineTask('gpt2')
        >>> result = pipe('Hello world')
        """
        from transformers import pipeline

        super().__init__(**kwargs)
        self.pipeline = pipeline(model=model, device=device)
        self.tokenizer = self.pipeline.tokenizer

    def crop_to_max_length(self, input_string, max_length=500):
        """Truncate `input_string` to at most `max_length` tokens of the pipeline's tokenizer."""
        tokens = self.tokenizer.tokenize(input_string)
        if len(tokens) > max_length:
            tokens = tokens[:max_length]
        return self.tokenizer.convert_tokens_to_string(tokens)
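
# A minimal usage sketch for crop_to_max_length, reusing the 'gpt2' checkpoint
# from the docstring example above (the exact truncation point depends on the
# model's tokenizer, so the output shown is indicative only):
#
# >>> task = HFPipelineTask("gpt2")
# >>> task.crop_to_max_length("one two three four five", max_length=3)
# 'one two three'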


class HFFeatureExtractionTask(HFPipelineTask):
    def __init__(self, model, device="cpu", max_length=500, **kwargs):
        super().__init__(model, device=device, **kwargs)
        self.max_length = max_length

    def __call__(self, text, **kwargs):
        """
        Compute a feature embedding for the given text.

        Hugging Face feature-extraction models return one embedding per token.
        To get a single embedding vector for the whole text, we simply take the
        average over the token embeddings.

        Args:
            text (str): The text for which we compute the embedding.
            **kwargs: Additional arguments to be passed to the pipeline.

        Returns:
            List[float]: The averaged feature embedding computed by the model.
        """
        text = self.crop_to_max_length(text, max_length=self.max_length)
        # The pipeline returns a list of lists (one vector per token in the text)
        embedding = self.pipeline(text, **kwargs)[0]

        # For simplicity, we just average all token vectors to get a sentence embedding
        avg_embedding = [sum(col) / len(col) for col in zip(*embedding)]

        return avg_embedding
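
# A minimal usage sketch (the checkpoint name is an assumption; any
# feature-extraction model from the Hugging Face hub should work):
#
# >>> embedder = HFFeatureExtractionTask("sentence-transformers/all-MiniLM-L6-v2")
# >>> vector = embedder("Hello world")
# >>> len(vector)  # dimensionality of the model's hidden states, e.g. 384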


class HFTextGenerationTask(HFPipelineTask):
    def __init__(
        self, model, device="cpu", max_prompt_length=500, max_new_tokens=500, **kwargs
    ):
        super().__init__(model, device=device, **kwargs)
        self.max_prompt_length = max_prompt_length
        self.max_new_tokens = max_new_tokens

    def __call__(self, text, **kwargs):
        """
        Run the model to complete the text.

        Args:
            text (str): The prompt to complete.
            return_full_text (bool, optional, defaults to True):
                If True, returns the full text; if False, only the newly generated
                text is returned. Only significant if return_text is True.
            clean_up_tokenization_spaces (bool, optional, defaults to False):
                If True, removes extra spaces in the text output.
            prefix (str, optional): A prefix added to the prompt.
            handle_long_generation (str, optional): By default, long generation is
                not handled specially. Available strategies:
                - None: does nothing special.
                - "hole": truncates the left side of the input, leaving room for
                  generation. This might truncate a lot of the prompt and is not
                  suitable when the requested generation exceeds the model capacity.

        Other arguments from transformers.TextGenerationPipeline.__call__ are
        supported as well. See:
        https://huggingface.co/docs/transformers/main/main_classes/pipelines#transformers.TextGenerationPipeline.__call__
        """
        text = self.crop_to_max_length(text, self.max_prompt_length)

        max_new_tokens = kwargs.pop("max_new_tokens", self.max_new_tokens)

        # Wrap the raw prompt in a chat conversation and render it with the
        # tokenizer's chat template.
        messages = [
            {
                "role": "system",
                "content": "You are a helpful virtual assistant that only responds in English clearly and precisely.",
            },
            {"role": "user", "content": text},
        ]
        prompt = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        output = self.pipeline(prompt, max_new_tokens=max_new_tokens, **kwargs)
        return output[0]["generated_text"]
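
# A minimal usage sketch, assuming a chat-tuned checkpoint whose tokenizer ships
# a chat template (e.g. "HuggingFaceH4/zephyr-7b-beta"); bare models like 'gpt2'
# may lack a chat template, in which case apply_chat_template can fail:
#
# >>> generator = HFTextGenerationTask("HuggingFaceH4/zephyr-7b-beta")
# >>> completion = generator("What is a Hugging Face pipeline?", max_new_tokens=64)
# >>> print(completion)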