<a href="https://www.youtube.com/watch?v=iCL1TmRQ0sk&list=PLxqBkZuBynVQEvXfJpq3smfuKq3AiNW-N&index=19"><h1 style="font-size:250%; font-family:cursive; color:#ff6666;"><b>Link YouTube Video - Adding a custom task-specific Layer to a HuggingFace Pretrained Model</b></h1></a>

[![Video thumbnail](https://imgur.com/ZFLzcw9.png)](https://www.youtube.com/watch?v=iCL1TmRQ0sk&list=PLxqBkZuBynVQEvXfJpq3smfuKq3AiNW-N&index=19 "")

---------------------------

## First, What Is BERT?

BERT stands for Bidirectional Encoder Representations from Transformers. The name itself gives us several clues about what BERT is all about.

The BERT architecture consists of several Transformer encoders stacked together. Each Transformer encoder encapsulates two sub-layers: a self-attention layer and a feed-forward layer.

### There are two different BERT models:

- BERT base: 12 Transformer encoder layers, 12 attention heads, a hidden size of 768, and 110M parameters.

- BERT large: 24 Transformer encoder layers, 16 attention heads, a hidden size of 1024, and 340M parameters.

### BERT Input and Output

A BERT model expects a sequence of tokens (words) as input. Each sequence contains two special tokens that BERT expects:

- [CLS]: the first token of every sequence. Its name stands for "classification".
- [SEP]: the separator token, which tells BERT which token belongs to which sequence. It matters mainly for next-sentence prediction and question answering. If we only have one sequence, this token is appended to the end of the sequence.

It is also important to note that a BERT model accepts at most 512 tokens. If a sequence has fewer than 512 tokens, we can use padding to fill the unused slots with the [PAD] token. If a sequence has more than 512 tokens, we need to truncate it.

And that’s all that BERT expects as input.

The BERT model then outputs an embedding vector for each token (of size 768 in BERT base). We can use these vectors as input for different kinds of NLP applications, such as text classification, next-sentence prediction, named-entity recognition (NER), or question answering.

------------

**For a text classification task**, we focus on the embedding vector output for the special [CLS] token. We use this size-768 vector as the input to our classifier. The classifier then outputs a vector whose size equals the number of classes in our classification task.

-----------------------

![Imgur](https://imgur.com/NpeB9vb.png)

# The Concept of Adding a Custom Task-Specific Layer to a HF Model

![](1.png)

When we switch from the pretraining task to a downstream task, we need to replace the last layer of the model with one that suits the task. This last layer is called the model head; it is the task-specific part.

The rest of the model is called the body. It includes the token embeddings and the Transformer layers, which are task-agnostic.

This structure is also reflected in the 🤗 Transformers code. The body of a model is implemented in a class such as `BertModel` or `GPT2Model`, which returns the hidden states of the last layer.

### To get them, we take `outputs[0]` (the last hidden state, also available as `outputs.last_hidden_state`)

The hidden states are passed as inputs to a model head to produce the final output. 🤗 Transformers provides a different model head for each task, as long as the model supports that task (e.g., you can’t use DistilBERT for a sequence-to-sequence task like translation).

There are also task-specific models such as `BertForMaskedLM` or `BertForSequenceClassification`. These use the base model and add the necessary head on top of the hidden states.

For example, **`DistilBertForSequenceClassification`** is a base DistilBERT model with a sequence classification head. This head is a linear layer on top of the **pooled** output. For DistilBERT, the pooled output is the hidden state of the first ([CLS]) token.

For a **question answering task**, you would use the **`DistilBertForQuestionAnswering`** model head. It is similar to the sequence classification head, except that it is a linear layer on top of the **hidden states** output. This layer computes span start and end logits.

### [Dataset link](https://www.kaggle.com/datasets/rmisra/news-headlines-dataset-for-sarcasm-detection)

```python
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput

# List the files available in the Kaggle input directory
for dirname, _, filenames in os.walk('../input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
```

Output:

```
../input/news-headlines-dataset-for-sarcasm-detection/Sarcasm_Headlines_Dataset_v2.json
../input/news-headlines-dataset-for-sarcasm-detection/Sarcasm_Headlines_Dataset.json
```

Note: You can use any model in this example, not necessarily one trained for classification. We only use that model’s body and discard its head.

```python
dataset_v2_path = "../input/news-headlines-dataset-for-sarcasm-detection/Sarcasm_Headlines_Dataset_v2.json"
```

```python
df = pd.read_json(dataset_v2_path, lines=True)
df.head()
```

Output:

| | is_sarcastic | headline | article_link |
|---|---|---|---|
| 0 | 1 | thirtysomething scientists unveil doomsday clo... | https://www.theonion.com/thirtysomething-scien... |
| 1 | 0 | dem rep. totally nails why congress is falling... | https://www.huffingtonpost.com/entry/donna-edw... |
| 2 | 0 | eat your veggies: 9 deliciously different recipes | https://www.huffingtonpost.com/entry/eat-your-... |
| 3 | 1 | inclement weather prevents liar from getting t... | https://local.theonion.com/inclement-weather-p... |
| 4 | 1 | mother comes pretty close to using word 'strea... | https://www.theonion.com/mother-comes-pretty-c... |

## Load Dataset with HF's load_dataset

```python
dataset_hf = load_dataset("json", data_files=dataset_v2_path)
```

```python
dataset_hf = dataset_hf.remove_columns(['article_link'])

# Rename the target column to "label"; the data collator later passes it to the model as "labels"
dataset_hf = dataset_hf.rename_column('is_sarcastic', 'label')

dataset_hf.set_format('pandas')

# Materialize the train split as a pandas DataFrame
dataset_hf = dataset_hf['train'][:]
```

```python
# Drop duplicate headlines
dataset_hf.drop_duplicates(subset=['headline'], inplace=True)

dataset_hf = dataset_hf.reset_index()[['headline', 'label']]

dataset_hf = Dataset.from_pandas(dataset_hf)


# Train/test/validation split: 80% train, 10% test, 10% validation
train_testvalid = dataset_hf.train_test_split(test_size=0.2, seed=15)

test_valid = train_testvalid['test'].train_test_split(test_size=0.5, seed=15)

dataset_hf = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})

dataset_hf
```

Output:

```
DatasetDict({
    train: Dataset({
        features: ['headline', 'label'],
        num_rows: 22802
    })
    test: Dataset({
        features: ['headline', 'label'],
        num_rows: 2851
    })
    valid: Dataset({
        features: ['headline', 'label'],
        num_rows: 2850
    })
})
```

```python
checkpoint = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# The attribute is model_max_length (model_max_len is silently ignored)
tokenizer.model_max_length = 512
```

## Vector size of "distilbert-base-uncased"

In the distilbert-base-uncased model, each token is embedded into a vector of size 768. The output of the base model therefore has the shape

### (batch_size, max_sequence_length, embedding_vector_size=768)

```python
def tokenize(batch):
    return tokenizer(batch["headline"], truncation=True, max_length=512)

tokenized_dataset = dataset_hf.map(tokenize, batched=True)
tokenized_dataset
```

Output:

```
DatasetDict({
    train: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 22802
    })
    test: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2851
    })
    valid: Dataset({
        features: ['headline', 'label', 'input_ids', 'attention_mask'],
        num_rows: 2850
    })
})
```

```python
tokenized_dataset.set_format('torch', columns=["input_ids", "attention_mask", "label"])

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
```

https://huggingface.co/docs/transformers/main_classes/data_collator

Data collators are objects that form a batch from a list of dataset elements. These elements have the same type as the elements of train_dataset or eval_dataset.

To build batches, data collators may apply some processing, such as padding. Some of them, like DataCollatorForLanguageModeling, also apply random data augmentation (for example random masking) to the formed batch.

`DataCollatorWithPadding` pads the model inputs in each batch to the length of the longest example in that batch. This avoids setting a global maximum sequence length. In practice it also speeds up training, because fewer redundant computations are spent on padded tokens and attention masks.


------------

## We construct the `MyTaskSpecificCustomModel` class, which inherits from `nn.Module`

```python
class MyTaskSpecificCustomModel(nn.Module):
    """
    A task-specific custom transformer model. This model loads a pre-trained transformer model and adds a new dropout
    and linear layer at the end for fine-tuning and prediction on specific tasks.
    """
    def __init__(self, checkpoint, num_labels):
        """
        Args:
            checkpoint (str): The name of the pre-trained model or path to the model weights.
            num_labels (int): The number of output labels in the final classification layer.
        """
        super().__init__()
        self.num_labels = num_labels

        # Body: the pretrained transformer without its original head
        self.model = AutoModel.from_pretrained(
            checkpoint,
            config=AutoConfig.from_pretrained(
                checkpoint,
                output_attentions=True,
                output_hidden_states=True,
            ),
        )
        hidden_size = self.model.config.hidden_size  # 768 for distilbert-base-uncased

        # New layers: the task-specific head
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """
        Forward pass for the model.

        Args:
            input_ids (torch.Tensor, optional): Tensor of input IDs. Defaults to None.
            attention_mask (torch.Tensor, optional): Tensor for attention masks. Defaults to None.
            labels (torch.Tensor, optional): Tensor for labels. Defaults to None.

        Returns:
            SequenceClassifierOutput: An output object with the following fields:
            - loss (torch.FloatTensor of shape (1,), optional, returned when labels is provided) – Classification loss.
            - logits (torch.FloatTensor of shape (batch_size, num_labels)) – Classification scores before SoftMax.
            - hidden_states (tuple(torch.FloatTensor), optional, returned when output_hidden_states=True is passed or when config.output_hidden_states=True) – Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, sequence_length, hidden_size).
            - attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) – Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
        """
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # outputs[0] is the last hidden state: (batch_size, sequence_length, hidden_size)
        last_hidden_state = outputs[0]

        sequence_outputs = self.dropout(last_hidden_state)

        # Classify using the embedding of the first ([CLS]) token
        logits = self.classifier(sequence_outputs[:, 0, :])

        loss = None
        if labels is not None:
            loss_func = nn.CrossEntropyLoss()
            loss = loss_func(logits.view(-1, self.num_labels), labels.view(-1))

        # Return in every case, so inference without labels also works
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```

## attention_mask

From the docs: this argument tells the model which tokens should be attended to and which should not.

### If the attention_mask value is 0, the token is ignored. For instance, if a sequence is padded to adjust its length, the padding tokens should be ignored, so their attention_mask values are 0.

--------------

### torch.nn.Linear(in_features, out_features, bias=True)

Parameters:

- in_features – size of each input sample
- out_features – size of each output sample
- bias – if set to False, the layer does not learn an additive bias (default: True)

## Making sense of `nn.Linear`

#### Suppose a neural network defines `self.hidden = nn.Linear(784, 256)`. This is a _hidden_ (meaning it sits between the input and output layers), _fully connected linear layer_. It takes an input `x` of shape `(batch_size, 784)`, where batch size is the number of inputs (each of size 784) passed to the network at once as a single tensor. It transforms `x` by the linear equation `y = x*W^T + b` into a tensor `y` of shape `(batch_size, 256)`. In our custom model, `nn.Linear(768, 2)` likewise maps each 768-dimensional [CLS] embedding to 2 logits.

------------------------

# Create PyTorch DataLoader

```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    tokenized_dataset['train'], shuffle=True, batch_size=32, collate_fn=data_collator
)

# Evaluation does not need shuffling; without batch_size, DataLoader would default to batches of 1
eval_dataloader = DataLoader(
    tokenized_dataset['valid'], batch_size=32, collate_fn=data_collator
)
```

```python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_task_specific = MyTaskSpecificCustomModel(checkpoint=checkpoint, num_labels=2).to(device)
```

```python
# transformers.AdamW is deprecated (and removed in recent versions); use the PyTorch implementation
from torch.optim import AdamW
from transformers import get_scheduler

optimizer = AdamW(model_task_specific.parameters(), lr=5e-5)

num_epoch = 3

num_training_steps = num_epoch * len(train_dataloader)

lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
```

```python
# datasets.load_metric is deprecated (and removed in datasets 3.x); the evaluate library replaces it
import evaluate

metric = evaluate.load("f1")
```

# Training

```python
from tqdm.auto import tqdm

progress_bar_train = tqdm(range(num_training_steps))
progress_bar_eval = tqdm(range(num_epoch * len(eval_dataloader)))


for epoch in range(num_epoch):
    # Training pass
    model_task_specific.train()
    for batch in train_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model_task_specific(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar_train.update(1)

    # Evaluation pass on the validation set
    model_task_specific.eval()
    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model_task_specific(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)
        metric.add_batch(predictions=predictions, references=batch['labels'])
        progress_bar_eval.update(1)

    print(metric.compute())
```

## Post Training Evaluation

```python
model_task_specific.eval()

test_dataloader = DataLoader(
    tokenized_dataset['test'], batch_size=32, collate_fn=data_collator
)


for batch in test_dataloader:
    batch = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        outputs = model_task_specific(**batch)

    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1)
    metric.add_batch(predictions=predictions, references=batch['labels'])

metric.compute()
```
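The BERT input format is easy to trace in plain Python. A sequence starts with [CLS], ends with [SEP], is truncated to at most 512 tokens, and is padded with [PAD]. This is a minimal illustration under simplifying assumptions. The token strings are hypothetical, real tokenizers work with word pieces and integer ids, and `frame_bert_input` is a made-up helper. It pads every sequence to a fixed length, whereas the notebook pads each batch dynamically.

```python
def frame_bert_input(tokens, max_length=512):
    """Hypothetical helper: return (framed_tokens, attention_mask) for one sequence."""
    # Reserve two slots for the special tokens [CLS] and [SEP]
    body = tokens[: max_length - 2]
    framed = ["[CLS]"] + body + ["[SEP]"]
    # 1 marks real tokens (including special tokens), 0 marks padding
    mask = [1] * len(framed)
    n_pad = max_length - len(framed)
    framed += ["[PAD]"] * n_pad
    mask += [0] * n_pad
    return framed, mask


# A short headline is padded up to max_length
short, short_mask = frame_bert_input(["eat", "your", "veggies"], max_length=8)
print(short)       # ['[CLS]', 'eat', 'your', 'veggies', '[SEP]', '[PAD]', '[PAD]', '[PAD]']
print(short_mask)  # [1, 1, 1, 1, 1, 0, 0, 0]

# A 600-token input is truncated so that [CLS] + 510 tokens + [SEP] fit into 512 slots
long_framed, long_mask = frame_bert_input(["tok"] * 600)
print(len(long_framed), long_framed[0], long_framed[-1])  # 512 [CLS] [SEP]
```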

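The notebook chains two `train_test_split` calls (`test_size=0.2`, then `test_size=0.5` on the held-out part) to get an 80/10/10 train/test/validation split. The sketch below checks the arithmetic. It assumes the test side is rounded up with `math.ceil`, which is the scikit-learn convention and is consistent with the row counts the notebook prints. `split_sizes` is a hypothetical helper, not part of 🤗 Datasets.

```python
import math


def split_sizes(n_samples, test_size):
    """Hypothetical helper: (n_train, n_test), assuming the test side is rounded up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


# Rows left after de-duplication (sum of the three splits printed by the notebook)
n_total = 22802 + 2851 + 2850

n_train, n_holdout = split_sizes(n_total, 0.2)
# The second split's "train" side becomes 'valid' and its "test" side becomes 'test'
n_valid, n_test = split_sizes(n_holdout, 0.5)

print(n_total, n_train, n_test, n_valid)  # 28503 22802 2851 2850
```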

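Dynamic padding, as done by `DataCollatorWithPadding`, and the attention mask it produces can be shown with plain lists. This is a simplified sketch. The ids 101, 102 and 0 are the [CLS], [SEP] and [PAD] ids in the `bert-base-uncased`/`distilbert-base-uncased` vocabulary; the other token ids are hypothetical. The real collator returns tensors. It also renames the `label` column to `labels`, which the hypothetical `pad_batch` helper imitates.

```python
def pad_batch(examples, pad_id=0):
    """Hypothetical helper: pad input_ids to the longest example in THIS batch and build attention masks."""
    longest = max(len(ex["input_ids"]) for ex in examples)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for ex in examples:
        n_pad = longest - len(ex["input_ids"])
        batch["input_ids"].append(ex["input_ids"] + [pad_id] * n_pad)
        # Real tokens get 1, padding gets 0, so the model ignores the padding
        batch["attention_mask"].append([1] * len(ex["input_ids"]) + [0] * n_pad)
        # The dataset's "label" column is passed on under the name "labels"
        batch["labels"].append(ex["label"])
    return batch


examples = [
    {"input_ids": [101, 4521, 2115, 102], "label": 0},  # hypothetical ids between [CLS] and [SEP]
    {"input_ids": [101, 102], "label": 1},
]
batch = pad_batch(examples)
print(batch["input_ids"])       # [[101, 4521, 2115, 102], [101, 102, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
```

Because the padding length depends only on the longest example in each batch, short headlines do not pay for a 512-token global maximum.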

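The forward pass of `MyTaskSpecificCustomModel` boils down to three steps:

1. Take the first ([CLS]) position of the last hidden state.
2. Apply the linear layer `y = x W^T + b`.
3. Score the logits with cross-entropy.

The toy sketch below replays those steps in plain Python with made-up numbers (hidden size 4 instead of 768, hypothetical weights), so the shapes and arithmetic are easy to follow. Dropout is left out because it acts as the identity in evaluation mode. The sketch assumes `nn.CrossEntropyLoss` uses its default mean reduction.

```python
import math


def linear(x, weight, bias):
    """y = x W^T + b for one input vector x; weight has shape (out_features, in_features)."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weight, bias)]


def cross_entropy(logits, label):
    """-log softmax(logits)[label], computed with the log-sum-exp trick for stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


# Toy last hidden state: batch_size=2, sequence_length=3, hidden_size=4 (768 in DistilBERT)
last_hidden_state = [
    [[1.0, 0.0, 2.0, -1.0], [0.5, 0.5, 0.5, 0.5], [9.0, 9.0, 9.0, 9.0]],
    [[0.0, 1.0, 0.0, 1.0], [2.0, 2.0, 2.0, 2.0], [7.0, 7.0, 7.0, 7.0]],
]
labels = [1, 0]

# Equivalent of last_hidden_state[:, 0, :]: keep only the first ([CLS]) position
cls_vectors = [sequence[0] for sequence in last_hidden_state]

# Hypothetical classifier weights: num_labels=2 rows, hidden_size=4 columns
weight = [[0.1, 0.0, 0.0, 0.0], [0.0, 0.0, 0.5, 0.0]]
bias = [0.0, 0.0]

logits = [linear(x, weight, bias) for x in cls_vectors]        # shape (batch_size, num_labels)
predictions = [row.index(max(row)) for row in logits]         # argmax over the label dimension
loss = sum(cross_entropy(z, y) for z, y in zip(logits, labels)) / len(labels)  # mean reduction

print(logits)       # [[0.1, 1.0], [0.0, 0.0]]
print(predictions)  # [1, 0]
```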

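With `num_warmup_steps=0`, the `'linear'` schedule from `get_scheduler` decays the learning rate from 5e-5 to 0 over `num_training_steps`. The sketch below recomputes the multiplier in plain Python. It assumes the schedule follows the same formula as `transformers.get_linear_schedule_with_warmup`: linear warmup, then linear decay. `linear_schedule_factor` is a hypothetical stand-in. The step count assumes 22,802 training rows, a batch size of 32, and the DataLoader default of keeping the last partial batch.

```python
import math


def linear_schedule_factor(step, num_warmup_steps, num_training_steps):
    """Hypothetical stand-in: multiplier on the base LR (linear warmup, then linear decay to 0)."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))


base_lr = 5e-5
num_epoch = 3
batches_per_epoch = math.ceil(22802 / 32)  # the last, partial batch is kept
num_training_steps = num_epoch * batches_per_epoch

print(batches_per_epoch, num_training_steps)  # 713 2139
for step in (0, num_training_steps // 2, num_training_steps):
    # Learning rate right after `step` scheduler steps
    print(step, base_lr * linear_schedule_factor(step, 0, num_training_steps))
```

With zero warmup the learning rate starts at its full value, is about half at the midpoint of training, and reaches 0 at the final step.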