LLM-FineTuning-Large-Language-Models
/
validation.py
234 lines · 15.8 KB
1from collections import defaultdict2
3import numpy as np4import pandas as pd5from sklearn.metrics import accuracy_score6
7import torch8
9from const import *10from train import *11from utils import *12
13import warnings14warnings.filterwarnings("ignore")15
def inference(dl, model):
    """Run NER inference over a DataLoader and map chunk-level predictions
    back to whole-document, word-level labels.

    Long documents may have been split into multiple (possibly overlapping)
    chunks by the tokenizer; ``overflow_to_sample_mapping`` ties each chunk
    back to its source text and ``wids`` gives the original word index of
    every token (-1 for special/padding tokens). Each word receives exactly
    one predicted label — the one from the first chunk/subword in which it
    appears.

    Args:
        dl (DataLoader): yields batches with 'input_ids', 'attention_mask',
            'overflow_to_sample_mapping' and 'wids'.
        model (nn.Module): pre-trained token-classification (NER) model.

    Returns:
        List[List[str]]: per-text lists of predicted label strings, ordered
        by ascending text id (i.e. the original text order).
    """
    predictions = defaultdict(list)
    # Track which word indices were already labeled per text; a set gives
    # O(1) membership tests (the original list-based lookup was O(n)).
    seen_words_idx = defaultdict(set)

    for batch in dl:
        ids = batch["input_ids"].to(config['device'])
        mask = batch["attention_mask"].to(config['device'])
        # Inference only: no_grad avoids building the autograd graph,
        # saving memory and time.
        with torch.no_grad():
            outputs = model(ids, attention_mask=mask, return_dict=False)
        del ids, mask

        # outputs[0] holds the per-token logits; argmax over the label axis
        # yields the most probable label index for every token.
        batch_preds = torch.argmax(outputs[0], axis=-1).cpu().numpy()

        # overflow_to_sample_mapping maps each chunk back to the original
        # text it was cut from, so predictions can be regrouped per text.
        for k, (chunk_preds, text_id) in enumerate(
                zip(batch_preds, batch['overflow_to_sample_mapping'].tolist())):
            # Absolute word indices in the original text for this chunk.
            word_ids = batch['wids'][k].numpy()

            # Map raw label indices to human-readable label strings.
            chunk_preds = [IDS_TO_LABELS[i] for i in chunk_preds]

            for idx, word_idx in enumerate(word_ids):
                if word_idx == -1:
                    # Special/padding token — no corresponding word.
                    continue
                if word_idx not in seen_words_idx[text_id]:
                    # First occurrence of this word wins; later chunks or
                    # subword tokens for the same word are ignored.
                    predictions[text_id].append(chunk_preds[idx])
                    seen_words_idx[text_id].add(word_idx)

    # Sort keys so the output lines up with the original text order.
    final_predictions = [predictions[k] for k in sorted(predictions.keys())]
    return final_predictions
80""" ### Output of `print('batch ', batch)` in `inference()` method
81
82batch {'input_ids': tensor([[ 0, 4129, 6909, 8, 19181, 6893, 198, 47, 328, 404,
83...]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
841, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...]]), 'overflow_to_sample_mapping': tensor([0]), 'labels': tensor([[-100, 0, 0, 0, 0, 1, 2, ... -100]]), 'wids': tensor([[ -1, 0, 0, 1, 2, 3, 4, 5, 5, 6, 7, 8, 9, 10,
8511, 12, 13, 14 ... -1]])}
86
87"""
88
89
def get_predictions(df, dl, model):
    """Generate entity span predictions for each document in *df*.

    Runs :func:`inference` to obtain word-level labels, then collapses each
    contiguous run of identical labels into one (id, class,
    predictionstring) row. 'B-' and 'I-' prefixes are treated as equivalent
    so a whole entity forms a single run of the same label.

    Args:
        df (pd.DataFrame): input data; must contain an 'id' column aligned
            with the sample order produced by *dl*.
        dl (torch.utils.data.DataLoader): batches the input for the model.
        model (torch.nn.Module): NER model used to generate predictions.

    Returns:
        pd.DataFrame: columns ['id', 'class', 'predictionstring'], where
        'predictionstring' is a space-separated list of the word indices
        covered by the predicted entity.
    """
    predicted_labels = inference(dl, model)
    final_preds = []

    for i in range(len(df)):
        idx = df.id.values[i]
        pred = predicted_labels[i]
        j = 0
        while j < len(pred):
            cls = pred[j]
            if cls != 'O':
                # Treat 'B-' and 'I-' tags as the same label so an entity
                # is one contiguous run; strictly a first-subword-wins
                # labeling scheme (alternatives: averaging subword
                # predictions, preferring 'B' over 'I', etc.).
                cls = cls.replace('B', 'I')
            # Advance `end` past the contiguous run of this label.
            end = j + 1
            while end < len(pred) and pred[end] == cls:
                end += 1
            # NOTE(review): `end - j > 0` is always true (end starts at
            # j + 1), i.e. no minimum-span filter. The original author used
            # `end - j > 7` on the full training data to drop short spans;
            # kept at > 0 so small test runs still produce rows.
            if cls != 'O' and cls != '' and end - j > 0:
                # predictionstring: space-separated word indices j..end-1.
                final_preds.append((idx, cls.replace('I-', ''),
                                    ' '.join(map(str, range(j, end)))))
            j = end  # jump to the position right after the current run

    # Passing columns explicitly keeps the frame well-formed even when
    # final_preds is empty (assigning .columns to a 0-column frame raises).
    df_pred = pd.DataFrame(final_preds,
                           columns=['id', 'class', 'predictionstring'])
    return df_pred
168
169"""
170Explanations of the below line
171
172' '.join(map(str, list(range(j, end))))
173
174
175range(j, end): This part of the code creates a range of integers starting from j (inclusive) to end (exclusive). j is the starting index of the current entity, and end is the index immediately after the last token of the current entity.
176
177
178list(range(j, end)): This part converts the range object into a list of integers. The list contains all the indices that correspond to the tokens of the predicted entity in the input text.
179
180
181map(str, list(range(j, end))): This part applies the str function to each element of the list using the map() function. The purpose of this step is to convert the list of integers into a list of strings. This is necessary because we want to join these indices into a single string in the next step.
182
183' '.join(map(str, list(range(j, end)))): Finally, this part uses the join() method of the string ' ' (a single space) to concatenate all the string elements in the list, separated by a space. This results in a single string with space-separated indices corresponding to the predicted entity in the input text.
184
185
186The reason for using this line in the get_predictions() method is to create a compact and human-readable representation of the token indices for the predicted entity in the input text. This string representation is then added to the final_preds list as part of the tuple storing the entity prediction information,
187
188
189"""
190
def validate(model, df_train_org, df_val, dataloader_val, epoch, unique_valid_id_list):
    """Evaluate the NER model on the validation split and report F1 scores.

    For every class present in the model's out-of-fold predictions, computes
    a macro F1 against the ground-truth discourse types, prints the
    per-class scores and the overall mean, and logs everything to Weights &
    Biases.

    Args:
        model (nn.Module): pre-trained NER model to evaluate.
        df_train_org (pd.DataFrame): original frame with both train and
            validation rows (provides the ground truth).
        df_val (pd.DataFrame): validation subset fed to prediction.
        dataloader_val (DataLoader): iterable over the validation dataset.
        epoch (int): current epoch number (0-based; printed 1-based).
        unique_valid_id_list (List[int]): ids belonging to the validation split.
    """
    time_start = time.time()

    # Evaluation mode: disables dropout / freezes batch-norm statistics.
    model.eval()

    # Ground-truth rows: restrict the original frame to validation ids.
    df_valid = df_train_org.loc[df_train_org['id'].isin(unique_valid_id_list)]

    out_of_fold = get_predictions(df_val, dataloader_val, model)

    epoch_prefix = f"[Epoch {epoch+1:2d} / {config['epochs']:2d}]"
    print(f"{epoch_prefix} Validation F1 scores")

    f1s = []
    f1s_log = {}
    for c in out_of_fold['class'].unique():
        # .copy() isolates the per-class slices so downstream mutation in
        # the scoring helper cannot touch the source frames.
        pred_df = out_of_fold.loc[out_of_fold['class'] == c].copy()
        gt_df = df_valid.loc[df_valid['discourse_type'] == c].copy()
        f1 = compute_macro_f1_score(pred_df, gt_df)
        print(f"{epoch_prefix} * {c:<10}: {f1:4f}")
        f1s.append(f1)
        f1s_log[f'F1 {c}'] = f1

    elapsed = time.time() - time_start
    print(epoch_prefix)
    print(f'{epoch_prefix} Overall Validation F1: {np.mean(f1s):.4f} [{elapsed:.2f} secs]')
    print(epoch_prefix)
    f1s_log['Overall F1'] = np.mean(f1s)
    wandb.log(f1s_log)