llama-index

Форк
0
128 строк · 4.6 Кб
1
import random
2
from typing import Any, List, Optional, Tuple
3

4
from llama_index.legacy.bridge.pydantic import BaseModel
5
from llama_index.legacy.finetuning import EmbeddingQAFinetuneDataset
6
from llama_index.legacy.indices.query.embedding_utils import get_top_k_embeddings
7

8

9
class CohereRerankerFinetuneDataset(BaseModel):
    """One training/validation example for CohereAI reranker fine-tuning."""

    # The search query of this example.
    query: str
    # Passages that are relevant to the query (the positives).
    relevant_passages: List[str]
    # Hard-negative passages; shape depends on the generation method used.
    hard_negatives: Any

    def to_jsonl(self) -> str:
        """Serialize this example as one JSON line (newline-terminated)."""
        return f"{self.json()}\n"
19

20

21
def generate_embeddings(embed_model: Any, text: str) -> List[float]:
    """Embed a single piece of text with the given embedding model."""
    embedding = embed_model.get_text_embedding(text)
    return embedding
24

25

26
def generate_hard_negatives(
    queries: List[str],
    relevant_contexts: List[str],
    embed_model: Optional[Any],
    num_negatives: int = 5,
    method: str = "random",
) -> Any:
    """Generate hard-negative passages for each query.

    Args:
        queries: One query per training example.
        relevant_contexts: The positive context for each query
            (parallel to ``queries``).
        embed_model: Embedding model; required when ``method`` is
            ``"cosine_similarity"``, ignored for ``"random"``.
        num_negatives: Maximum number of hard negatives per query.
        method: ``"random"`` samples negatives uniformly from the other
            queries' contexts; ``"cosine_similarity"`` picks the contexts
            most similar to the query, excluding its own positive.

    Returns:
        A list, parallel to ``queries``, of lists of negative passages.

    Raises:
        ValueError: If ``method`` is unsupported, or ``embed_model`` is
            ``None`` when the cosine-similarity method is requested.
    """
    if method not in ("random", "cosine_similarity"):
        # Previously an unknown method silently returned an empty list,
        # which made the caller's zip() drop every training example.
        raise ValueError(
            f"Unsupported hard-negative generation method: {method!r}. "
            "Expected 'random' or 'cosine_similarity'."
        )

    if method == "cosine_similarity":
        if embed_model is None:
            raise ValueError(
                "embed_model is required when method='cosine_similarity'."
            )
        # Embed everything up front so the per-query loop is pure lookup.
        query_embeddings = [
            generate_embeddings(embed_model, query) for query in queries
        ]
        relevant_contexts_embeddings = [
            generate_embeddings(embed_model, context)
            for context in relevant_contexts
        ]

    hard_negatives = []
    for query_index, _ in enumerate(queries):
        if method == "random":
            # Exclude the query's own positive context.
            potential_negatives = (
                relevant_contexts[:query_index]
                + relevant_contexts[query_index + 1 :]
            )
            # Sample without replacement, capped at what is available.
            hard_negatives.append(
                random.sample(
                    potential_negatives,
                    min(num_negatives, len(potential_negatives)),
                )
            )
        else:  # method == "cosine_similarity"
            query_embedding = query_embeddings[query_index]
            # Rank all contexts by similarity to the query; the most similar
            # non-positive contexts make the hardest negatives.
            _, relevant_contexts_indices = get_top_k_embeddings(
                query_embedding,
                relevant_contexts_embeddings,
            )
            # Filter out the query's own positive, keep the top candidates.
            hard_negative_indices = [
                idx for idx in relevant_contexts_indices if idx != query_index
            ][:num_negatives]
            # Map indices back to the actual context strings.
            hard_negatives.append(
                [relevant_contexts[idx] for idx in hard_negative_indices]
            )
    return hard_negatives
76

77

78
def get_query_context_lists(
79
    query_context_pairs: EmbeddingQAFinetuneDataset,
80
) -> Tuple[List[str], List[str]]:
81
    queries = []
82
    relevant_contexts = []
83

84
    # 'query_context_pairs' is an object with 'queries', 'corpus', and 'relevant_docs' attributes
85
    for query_id, query in query_context_pairs.queries.items():
86
        # Get the first relevant document ID for the current query
87
        relevant_doc_id = query_context_pairs.relevant_docs[query_id][0]
88
        # Get the relevant context using the relevant document ID
89
        relevant_context = query_context_pairs.corpus[relevant_doc_id]
90
        # Append the query and the relevant context to their respective lists
91
        queries.append(query)
92
        relevant_contexts.append(relevant_context)
93

94
    return queries, relevant_contexts
95

96

97
def generate_cohere_reranker_finetuning_dataset(
    query_context_pairs: EmbeddingQAFinetuneDataset,
    num_negatives: int = 0,
    top_k_dissimilar: int = 100,
    hard_negatives_gen_method: str = "random",
    finetune_dataset_file_name: str = "train.jsonl",
    embed_model: Optional[Any] = None,
) -> Any:
    """Write a Cohere reranker fine-tuning dataset to a JSONL file.

    Args:
        query_context_pairs: QA dataset with queries, corpus and
            relevant-document mappings.
        num_negatives: Number of hard negatives per query; 0 disables
            hard-negative generation.
        top_k_dissimilar: NOTE(review): currently unused by this function;
            kept in the signature for backward compatibility.
        hard_negatives_gen_method: "random" or "cosine_similarity"
            (see ``generate_hard_negatives``).
        finetune_dataset_file_name: Output path for the JSONL file.
        embed_model: Embedding model, required only for the
            cosine-similarity method.
    """
    queries, relevant_contexts = get_query_context_lists(query_context_pairs)

    if num_negatives:
        hard_negatives = generate_hard_negatives(
            queries,
            relevant_contexts,
            embed_model,
            num_negatives,
            hard_negatives_gen_method,
        )
    else:
        # No negatives requested: keep the lists parallel with empty slots.
        hard_negatives = [[] for _ in queries]

    # Explicit utf-8: the platform default encoding can corrupt non-ASCII
    # queries/passages (e.g. on Windows).
    with open(finetune_dataset_file_name, "w", encoding="utf-8") as outfile:
        # Iterate the three parallel lists in lockstep.
        for query, context, hard_negative in zip(
            queries, relevant_contexts, hard_negatives
        ):
            entry = CohereRerankerFinetuneDataset(
                query=query, relevant_passages=[context], hard_negatives=hard_negative
            )
            # Each entry becomes one newline-terminated JSON line.
            outfile.write(entry.to_jsonl())
129

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.