# LLM-FineTuning-Large-Language-Models
from collections import defaultdict

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score

import time

import torch
import wandb  # time and wandb are used in validate() below

from const import *
from train import *
from utils import *

import warnings
warnings.filterwarnings("ignore")

def inference(dl, model):
    """ Performs inference on a given dataset using the provided Named Entity Recognition (NER) model.

    This function processes the input text in batches and outputs the predicted named entity type for each token in the original text. It ensures that each word is assigned a predicted label only once, even when the word appears in multiple overlapping chunks.

    Args:
    dl (DataLoader): A DataLoader object that provides an iterable over the given dataset.
    model (nn.Module): A pre-trained NER model used for inference.

    Returns:
    final_predictions (List[List[str]]): A list of lists containing the predicted named entity types for each token in the original text. """

    predictions = defaultdict(list)
    seen_words_idx = defaultdict(list)

    for batch in dl:
        ids = batch["input_ids"].to(config['device'])
        mask = batch["attention_mask"].to(config['device'])
        outputs = model(ids, attention_mask=mask, return_dict=False)
        # print('batch ', batch)  # see the output format in the comment below this function

        del ids, mask

        batch_preds = torch.argmax(outputs[0], axis=-1).cpu().numpy()
        # outputs[0] contains the logits (raw output values) for each token, i.e. the unnormalized scores over all possible label classes.
        # torch.argmax() finds the index of the highest value (the most probable class) along the last dimension (axis=-1) of the logits tensor; this index is the predicted label for each token.
        # This step is needed because the model's output is in the form of logits, which are not directly interpretable. torch.argmax() yields predicted label indices that can be mapped back to their corresponding string labels for evaluation or further processing.
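        # Shape sketch (assumed values, for illustration): with num_labels = 3, if
        # outputs[0][0, 0] were [0.1, 2.3, -0.5], argmax along the last axis gives
        # label index 1 for that token, so batch_preds has shape (batch_size, seq_len).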

        # Go over each prediction, getting the text_id reference.
        # batch_preds contains the predicted label indices for each token in the batch.
        # batch['overflow_to_sample_mapping'].tolist() maps the predicted labels back to their original text_ids. This is necessary because the input text may have been split into multiple chunks (due to token limits), and we need to keep track of which part of the original text each prediction belongs to.
        # zip() combines batch_preds and batch['overflow_to_sample_mapping'].tolist() element-wise, creating pairs of (chunk_preds, text_id).
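        # For example (assumed values): overflow_to_sample_mapping = [0, 0, 1] would
        # mean the first two chunks in the batch came from text 0 and the third from text 1.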
        for k, (chunk_preds, text_id) in enumerate(zip(batch_preds, batch['overflow_to_sample_mapping'].tolist())):
            # print('chunk_preds ', chunk_preds)
            # it's just a flat list like [2 2 2 2 2 2 ... 2 2]
            # print('batch.keys() ', batch.keys())
            # => batch.keys()  dict_keys(['input_ids', 'attention_mask', 'overflow_to_sample_mapping', 'labels', 'wids'])

            # The word_ids are absolute references into the original text
            word_ids = batch['wids'][k].numpy()

            # Map from label ids to string labels
            chunk_preds = [IDS_TO_LABELS[i] for i in chunk_preds]
            # This mapping is needed because the model outputs label indices, which are not directly interpretable. Mapping them to their corresponding string labels makes the predicted named entity types readable, for evaluating the model or for further processing.
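            # For example (assumed label set): IDS_TO_LABELS might look like
            # {0: 'O', 1: 'B-Claim', 2: 'I-Claim', ...}, so [2, 2, 0] would become
            # ['I-Claim', 'I-Claim', 'O'].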

            for idx, word_idx in enumerate(word_ids):
                # If the word index is -1, do nothing and continue to the next word index: -1 represents a padding token or a special token that has no corresponding word in the original text.
                if word_idx == -1:
                    pass
                # Ensure that each word in the original text is assigned a predicted label only once, even when it appears in multiple overlapping chunks:
                elif word_idx not in seen_words_idx[text_id]:
                    # this checks that the current word index has not already been processed for the given text_id, so the same word isn't processed multiple times when it appears in overlapping chunks.
                    # Add the prediction if the word doesn't have a prediction from a previous chunk,
                    predictions[text_id].append(chunk_preds[idx])
                    # and also add the current word index to the seen_words_idx dictionary.
                    seen_words_idx[text_id].append(word_idx)

    # print('predictions ', predictions)
    # predictions  defaultdict(<class 'list'>, {0: ['I-Concluding Statement', 'I-Concluding Statement', 'I-Concluding Statement', ...]})

    # List comprehension that iterates through each sorted key and retrieves the corresponding value (the predicted labels) from the predictions dictionary. Sorting the keys in ascending order ensures that the final predictions are in the same order as the original texts.
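    # For example (assumed values): predictions = {1: ['O'], 0: ['I-Claim']} would
    # yield [['I-Claim'], ['O']] after sorting the keys.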
    final_predictions = [predictions[k] for k in sorted(predictions.keys())]
    return final_predictions


""" ### Output of `print('batch ', batch)` in the inference() method

batch  {'input_ids': tensor([[    0,  4129,  6909,     8, 19181,  6893,   198,    47,   328,   404,
             ...]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...]]), 'overflow_to_sample_mapping': tensor([0]), 'labels': tensor([[-100,    0,    0,    0,    0,    1,    2,  ... -100]]), 'wids': tensor([[ -1,   0,   0,   1,   2,   3,   4,   5,   5,   6,   7,   8,   9,  10,
          11,  12,  13,  14 ...  -1]])}
"""


def get_predictions(df, dl, model):
    """
    This function generates entity predictions for a given DataFrame using a specified model and dataloader.

    Args:
        df (pd.DataFrame): A DataFrame containing input data for which predictions will be generated.
        dl (torch.utils.data.DataLoader): A DataLoader object to batch and process the input data.
        model (torch.nn.Module): A PyTorch model for Named Entity Recognition (NER) to generate predictions.

    Returns:
        df_pred (pd.DataFrame): A DataFrame containing entity predictions for each input in the original DataFrame.
                                The columns in the output DataFrame are:
                                - 'id': The input ID from the original DataFrame.
                                - 'class': The predicted entity class.
                                - 'predictionstring': A string of space-separated word indices corresponding
                                                      to the predicted entity in the input.
    """
    predicted_labels = inference(dl, model)
    # Initialize an empty list final_preds that will store the final entity predictions for each input in the DataFrame.
    final_preds = []

    for i in range(len(df)):
        # Extract the ID of the current input from the DataFrame
        idx = df.id.values[i]
        # and get the predicted labels corresponding to the current input.
        pred = predicted_labels[i]
        j = 0

        while j < len(pred):
            # Assign the label at index j of the predicted labels to the variable cls.
            cls = pred[j]
            # Check whether the current label cls is 'O' (representing no entity). If it is, skip it.
            if cls == 'O': pass
            # The approach followed here: during inference the model makes a prediction for each subword token, and some single words consist of multiple subword tokens. The code uses a word's first subword-token prediction as the label for the entire word. Alternatives that could be tried include averaging all subword predictions, or taking 'B' labels before 'I' labels.

            # If the current label is not 'O', replace 'B' with 'I' so that 'B' and 'I' tags are treated the same, simplifying the task of extracting the entity: a whole entity can then be identified as a continuous sequence of the same 'I' tag, which makes it easy to find the start and end indices of the entities in the input text.
            else: cls = cls.replace('B','I')

            # The block below finds the end index of a continuous entity in the predicted labels.
            # Start looking from the next position in the predicted labels,
            end = j + 1
            # then iterate while (a) end is less than the length of pred, so we don't exceed the boundaries of the list, and (b) the label at index end equals the current label cls, i.e. the current entity is still continuing; end += 1 advances through the predicted labels until the end of the continuous entity, or the end of the list, is reached.
            # The significance of this block: it identifies the continuous run of the same entity label, which represents a single entity in the input text; its start and end indices let us extract the complete entity.
            while end < len(pred) and pred[end] == cls:
                end += 1
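
            # Worked example (assumed values): with pred = ['I-Claim', 'I-Claim', 'O']
            # and j = 0, cls = 'I-Claim'; end starts at 1 and advances to 2, so the
            # entity spans word indices 0..1 and the next iteration resumes at j = 2.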
            # `end - j > 0` is used here only for testing the whole notebook with just a few samples of the entire .txt data; otherwise pandas raises an axis-mismatch error. For the full training data the condition is `end - j > 7`.
            if cls != 'O' and cls != '' and end - j > 0:
                final_preds.append((idx, cls.replace('I-',''),
                                    ' '.join(map(str, list(range(j, end))))))
            j = end  # at last, update j by pushing it forward to the 'end' position
            """
            If the current label cls is not 'O' (representing no entity), not an empty string, and the continuous entity sequence is long enough (end - j > 7 on the full data), then append a tuple to the final_preds list with the following elements:

            idx: The input ID from the original DataFrame.

            cls.replace('I-', ''): The predicted entity class, with the 'I-' prefix removed. (Note: all 'B' tags were already replaced with 'I', so only 'I-' prefixes remain.)

            ' '.join(map(str, list(range(j, end)))): A string of space-separated indices corresponding to the predicted entity in the input, created by converting the list of indices from j to end (exclusive) to a string representation.

            The condition end - j > 7 filters out entities spanning 7 or fewer tokens in the input text. Depending on the problem at hand, this threshold may be adjusted or removed to include entities of any length.

            The significance of this block is to extract the entity information from the predicted labels, filter entities based on the length threshold, and store the final entity predictions in a structured format.
            """

    df_pred = pd.DataFrame(final_preds)
    df_pred.columns = ['id','class','predictionstring']
    # df_pred.head(2)  # (has no effect outside a notebook)
    return df_pred


"""
Explanation of the line

' '.join(map(str, list(range(j, end))))


range(j, end): creates a range of integers starting from j (inclusive) to end (exclusive). j is the starting index of the current entity, and end is the index immediately after the last token of the current entity.

list(range(j, end)): converts the range object into a list of integers. The list contains all the indices that correspond to the tokens of the predicted entity in the input text.

map(str, list(range(j, end))): applies the str function to each element of the list using map(). This converts the list of integers into a list of strings, which is necessary because we want to join these indices into a single string in the next step.

' '.join(map(str, list(range(j, end)))): finally, the join() method of the string ' ' (a single space) concatenates all the string elements in the list, separated by a space. This results in a single string of space-separated indices corresponding to the predicted entity in the input text.

This line is used in the get_predictions() method to create a compact and human-readable representation of the token indices of the predicted entity. The resulting string is added to the final_preds list as part of the tuple storing the entity prediction information.
"""

def validate(model, df_train_org, df_val, dataloader_val, epoch, unique_valid_id_list):
    """ Validates the performance of a given NER model on a validation dataset.

    This function computes the F1-score for each class and the overall F1-score for the model on the validation dataset, and prints the per-class and overall F1-scores.

    Args:
    model (nn.Module): A pre-trained NER model to validate.
    df_train_org (pd.DataFrame): The original training DataFrame, containing both training and validation samples.
    df_val (pd.DataFrame): The validation DataFrame, containing a subset of the columns of df_train_org.
    dataloader_val (DataLoader): A DataLoader object that provides an iterable over the validation dataset.
    epoch (int): The current epoch number in the training process.
    unique_valid_id_list (List[int]): A list of unique validation sample IDs. """

    time_start = time.time()

    # Put the model in eval mode
    model.eval()

    df_valid = df_train_org.loc[df_train_org['id'].isin(unique_valid_id_list)]

    out_of_fold = get_predictions(df_val, dataloader_val, model)

    f1s = []
    classes = out_of_fold['class'].unique()

    epoch_prefix = f"[Epoch {epoch+1:2d} / {config['epochs']:2d}]"
    print(f"{epoch_prefix} Validation F1 scores")

    f1s_log = {}
    for c in classes:
        # Create a new DataFrame pred_df by filtering out_of_fold (the model's predictions on the validation dataset) for rows where the 'class' column matches the current class c. The .copy() call creates a copy of the filtered data, so that changes to pred_df won't affect the original out_of_fold DataFrame. This lets the function focus on the model's performance for the specific class c when computing the F1-score.
        pred_df = out_of_fold.loc[out_of_fold['class']==c].copy()
        gt_df = df_valid.loc[df_valid['discourse_type']==c].copy()
        f1 = compute_macro_f1_score(pred_df, gt_df)
        print(f"{epoch_prefix}   * {c:<10}: {f1:4f}")
        f1s.append(f1)
        f1s_log[f'F1 {c}'] = f1

    elapsed = time.time() - time_start
    print(epoch_prefix)
    print(f'{epoch_prefix} Overall Validation F1: {np.mean(f1s):.4f} [{elapsed:.2f} secs]')
    print(epoch_prefix)
    f1s_log['Overall F1'] = np.mean(f1s)
    wandb.log(f1s_log)
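
# A minimal usage sketch (assumed names; the DataFrames, dataloader and id list
# are expected to be built by train.py / utils.py):
#
#   for epoch in range(config['epochs']):
#       ...  # train for one epoch
#       validate(model, df_train_org, df_val, dataloader_val, epoch, unique_valid_id_list)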
