llama-index
128 lines · 4.6 KB
1import random2from typing import Any, List, Optional, Tuple3
4from llama_index.legacy.bridge.pydantic import BaseModel5from llama_index.legacy.finetuning import EmbeddingQAFinetuneDataset6from llama_index.legacy.indices.query.embedding_utils import get_top_k_embeddings7
8
class CohereRerankerFinetuneDataset(BaseModel):
    """A single training/validation example for CohereAI Reranker finetuning.

    Holds one query, its relevant (positive) passages, and any
    hard-negative passages selected for it.
    """

    query: str
    relevant_passages: List[str]
    hard_negatives: Any

    def to_jsonl(self) -> str:
        """Serialize this example as a single JSON line, newline-terminated."""
        return f"{self.json()}\n"
20
def generate_embeddings(embed_model: Any, text: str) -> List[float]:
    """Return the embedding vector for one piece of text via *embed_model*."""
    return embed_model.get_text_embedding(text)
25
def generate_hard_negatives(
    queries: List[str],
    relevant_contexts: List[str],
    embed_model: Optional[Any],
    num_negatives: int = 5,
    method: str = "random",
) -> Any:
    """Pick hard-negative passages for every query.

    ``method`` is either ``"random"`` (sample uniformly from the other
    queries' contexts) or ``"cosine_similarity"`` (take the contexts whose
    embeddings rank closest to the query's embedding, excluding the
    query's own context). Any other value produces an empty result list.
    """
    # Embeddings are only needed by the similarity-based strategy,
    # so compute them up front only in that case.
    if method == "cosine_similarity":
        query_embeddings = [generate_embeddings(embed_model, q) for q in queries]
        context_embeddings = [
            generate_embeddings(embed_model, c) for c in relevant_contexts
        ]

    negatives: List[List[str]] = []
    for i, _ in enumerate(queries):
        if method == "random":
            # Every context except this query's own is a candidate.
            candidates = relevant_contexts[:i] + relevant_contexts[i + 1 :]
            sample_size = min(num_negatives, len(candidates))
            negatives.append(random.sample(candidates, sample_size))

        elif method == "cosine_similarity":
            # Rank all contexts by similarity to this query...
            _, ranked_indices = get_top_k_embeddings(
                query_embeddings[i],
                context_embeddings,
            )
            # ...drop the query's own context, keep the closest matches,
            # and map the surviving indices back to passages.
            kept = [idx for idx in ranked_indices if idx != i][:num_negatives]
            negatives.append([relevant_contexts[idx] for idx in kept])

    return negatives
77
def get_query_context_lists(
    query_context_pairs: EmbeddingQAFinetuneDataset,
) -> Tuple[List[str], List[str]]:
    """Flatten a QA finetuning dataset into parallel query/context lists.

    For each query, only its FIRST relevant document id is resolved against
    the corpus and used as the positive context.
    """
    queries: List[str] = []
    contexts: List[str] = []

    for query_id, query_text in query_context_pairs.queries.items():
        # Resolve the first relevant doc id to its corpus text.
        first_doc_id = query_context_pairs.relevant_docs[query_id][0]
        queries.append(query_text)
        contexts.append(query_context_pairs.corpus[first_doc_id])

    return queries, contexts
96
def generate_cohere_reranker_finetuning_dataset(
    query_context_pairs: EmbeddingQAFinetuneDataset,
    num_negatives: int = 0,
    top_k_dissimilar: int = 100,
    hard_negatives_gen_method: str = "random",
    finetune_dataset_file_name: str = "train.jsonl",
    embed_model: Optional[Any] = None,
) -> Any:
    """Write a Cohere reranker finetuning dataset to a JSONL file.

    Each output line pairs one query with its positive passage and — when
    ``num_negatives`` > 0 — a list of hard negatives chosen by
    ``hard_negatives_gen_method`` ("random" or "cosine_similarity").

    NOTE(review): ``top_k_dissimilar`` is accepted but never used in the
    visible code path — confirm whether it should be forwarded to the
    hard-negative generator.
    """
    queries, relevant_contexts = get_query_context_lists(query_context_pairs)

    if num_negatives:
        hard_negatives = generate_hard_negatives(
            queries,
            relevant_contexts,
            embed_model,
            num_negatives,
            hard_negatives_gen_method,
        )
    else:
        # No negatives requested: every query gets an empty negatives list.
        hard_negatives = [[] for _ in queries]

    with open(finetune_dataset_file_name, "w") as outfile:
        # Walk the three parallel lists and emit one JSONL record apiece.
        for query, positive, negatives in zip(
            queries, relevant_contexts, hard_negatives
        ):
            record = CohereRerankerFinetuneDataset(
                query=query,
                relevant_passages=[positive],
                hard_negatives=negatives,
            )
            outfile.write(record.to_jsonl())