llama-index

Форк
0
131 строка · 4.7 Кб
1
"""Cross Encoder Finetuning Engine."""
2

3
from typing import Any, List, Optional, Union
4

5
from llama_index.legacy.finetuning.cross_encoders.dataset_gen import (
6
    CrossEncoderFinetuningDatasetSample,
7
)
8
from llama_index.legacy.finetuning.types import BaseCrossEncoderFinetuningEngine
9
from llama_index.legacy.postprocessor import SentenceTransformerRerank
10

11

12
class CrossEncoderFinetuneEngine(BaseCrossEncoderFinetuningEngine):
    """Cross-Encoders Finetune Engine.

    Fine-tunes a sentence-transformers ``CrossEncoder`` on a dataset of
    (query, context, score) samples and can expose the result as a
    ``SentenceTransformerRerank`` node postprocessor.
    """

    def __init__(
        self,
        dataset: List[CrossEncoderFinetuningDatasetSample],
        model_id: str = "cross-encoder/ms-marco-MiniLM-L-12-v2",
        model_output_path: str = "exp_finetune",
        batch_size: int = 10,
        val_dataset: Union[List[CrossEncoderFinetuningDatasetSample], None] = None,
        loss: Union[Any, None] = None,
        epochs: int = 2,
        show_progress_bar: bool = True,
        evaluation_steps: int = 50,
    ) -> None:
        """Init params.

        Args:
            dataset: Training samples (query / context / score triples).
            model_id: HuggingFace model id of the base cross-encoder.
            model_output_path: Local directory the finetuned model is saved to.
            batch_size: Batch size for the training DataLoader.
            val_dataset: Optional validation samples; when given, a
                ``CEBinaryClassificationEvaluator`` is built from them.
            loss: Optional loss object; stored on ``self.loss`` for downstream
                access (``finetune`` itself relies on CrossEncoder's default).
            epochs: Number of training epochs.
            show_progress_bar: Whether to show a progress bar during training.
            evaluation_steps: Run the evaluator every N training steps.

        Raises:
            ImportError: If sentence-transformers (and torch) is not installed.
        """
        try:
            from sentence_transformers import InputExample
            from sentence_transformers.cross_encoder import CrossEncoder
            from torch.utils.data import DataLoader
        except ImportError:
            # NOTE: single concatenated message — the original passed two
            # comma-separated strings, so ImportError rendered a tuple.
            raise ImportError(
                "Cannot import sentence-transformers package, "
                "please `pip install sentence-transformers`"
            )

        self.dataset = dataset

        self.model_id = model_id
        self.model_output_path = model_output_path
        # num_labels=1 -> the model predicts a single relevance score per pair.
        self.model = CrossEncoder(self.model_id, num_labels=1)

        # Convert dataset samples into sentence-transformers InputExamples.
        self.examples: Any = [
            InputExample(texts=[sample.query, sample.context], label=sample.score)
            for sample in dataset
        ]

        self.loader: DataLoader = DataLoader(self.examples, batch_size=batch_size)

        # define evaluator
        from sentence_transformers.cross_encoder.evaluation import (
            CEBinaryClassificationEvaluator,
        )

        # TODO: also add support for CERerankingEvaluator
        evaluator: Optional[CEBinaryClassificationEvaluator] = None

        if val_dataset is not None:
            dev_samples = [
                InputExample(
                    texts=[val_sample.query, val_sample.context],
                    label=val_sample.score,
                )
                for val_sample in val_dataset
            ]
            evaluator = CEBinaryClassificationEvaluator.from_input_examples(dev_samples)

        self.evaluator = evaluator

        # define loss
        self.loss = loss

        self.epochs = epochs
        self.show_progress_bar = show_progress_bar
        self.evaluation_steps = evaluation_steps
        # Warm up the learning rate over the first 10% of total training steps.
        self.warmup_steps = int(len(self.loader) * epochs * 0.1)

    def finetune(self, **train_kwargs: Any) -> None:
        """Finetune model.

        Args:
            **train_kwargs: Extra keyword arguments forwarded verbatim to
                ``CrossEncoder.fit`` (e.g. ``optimizer_params``). The previous
                implementation accepted but silently ignored these.
        """
        self.model.fit(
            train_dataloader=self.loader,
            epochs=self.epochs,
            warmup_steps=self.warmup_steps,
            output_path=self.model_output_path,
            show_progress_bar=self.show_progress_bar,
            evaluator=self.evaluator,
            evaluation_steps=self.evaluation_steps,
            **train_kwargs,
        )
        # CrossEncoder's fit() only saves the model when an evaluator is set,
        # so save explicitly otherwise.
        # https://github.com/UKPLab/sentence-transformers/issues/2324
        if self.evaluator is None:
            self.model.save(self.model_output_path)

    def push_to_hub(self, repo_id: Any = None) -> None:
        """
        Saves the model and tokenizer to HuggingFace hub.

        Args:
            repo_id: Target HuggingFace Hub repository id.

        Raises:
            ValueError: If ``repo_id`` is None, or if HuggingFace Hub login
                has not been completed.
        """
        # Guard clause: fail fast on a missing repo id.
        if repo_id is None:
            raise ValueError("No value provided for repo_id")
        try:
            self.model.model.push_to_hub(repo_id=repo_id)
            self.model.tokenizer.push_to_hub(repo_id=repo_id)
        except ValueError:
            # Message fixed: the original adjacent-string concatenation was
            # missing a space ("usinghuggingface_hub.login()").
            raise ValueError(
                "HuggingFace CLI/Hub login not "
                "completed provide token to login using "
                "huggingface_hub.login() see this "
                "https://huggingface.co/docs/transformers/model_sharing#share-a-model"
            )

    def get_finetuned_model(
        self, model_name: str, top_n: int = 3
    ) -> SentenceTransformerRerank:
        """
        Loads the model from huggingface hub as re-ranker.

        :param model_name: Huggingface Hub repo from where you want to load the model
        :param top_n: The value of nodes the re-ranker should filter
        """
        return SentenceTransformerRerank(model=model_name, top_n=top_n)
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.