h2o-llmstudio
import logging
from typing import Any, Dict

import torch.nn as nn
from transformers import AutoModelForSeq2SeqLM

from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
from llm_studio.src.utils.data_utils import batch_padding
from llm_studio.src.utils.modeling_utils import (
    create_nlp_backbone,
    generate,
    prepare_lora,
)

logger = logging.getLogger(__name__)
class Model(nn.Module):
    """
    Model for sequence to sequence modeling problem type.
    """
    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """

        super(Model, self).__init__()
        self.cfg = cfg
        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForSeq2SeqLM
        )

        if cfg.training.lora:
            self.backbone = prepare_lora(cfg, self.backbone)
        self.loss_fn = self.cfg.training.loss_class.get(
            self.cfg.training.loss_function
        )(self.cfg)

        if self.cfg.prediction.metric == "Perplexity":
            # reduce=False keeps per-sample values for the validation metric
            self.perplexity = Perplexity(self.cfg, reduce=False)
    def generate(self, batch: Dict, cfg: Any, streamer=None):
        return generate(
            backbone=self.backbone,
            batch=batch,
            cfg=cfg,
            streamer=streamer,
            remove_prompt=False,
        )
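    # forward() expects batches that contain "prompt_input_ids",
    # "prompt_attention_mask", "answer_input_ids" and "answer_attention_mask"
    # (presumably produced by the corresponding LLM Studio dataset)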
    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        # disable cache if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = False

        outputs: Dict = {}
        kwargs: Dict = {}
        if padding:
            # pad the prompt tensors, following the tokenizer's padding side
            mask_key = "prompt_attention_mask"
            pad_keys = [
                "prompt_input_ids",
                "prompt_attention_mask",
            ]

            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key=mask_key,
                pad_keys=pad_keys,
                padding_side=self.cfg.tokenizer._padding_side,
            )

            # answer (label) tensors are always padded on the right
            mask_key = "answer_attention_mask"
            pad_keys = [
                "answer_input_ids",
                "answer_attention_mask",
            ]

            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key=mask_key,
                pad_keys=pad_keys,
                padding_side="right",
            )
        labels = batch["answer_input_ids"]
        labels[batch["answer_attention_mask"] == 0] = -100
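        # the padded positions masked with -100 above are ignored by the
        # backbone's built-in loss (CrossEntropyLoss with ignore_index=-100)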
        output = self.backbone(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            labels=labels,
            **kwargs,
        )

        outputs["loss"] = output.loss

        if not self.training and self.cfg.prediction.metric == "Perplexity":
            outputs["perplexity"] = self.perplexity(output.logits, labels)
        # enable cache again if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = True

        return outputs
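The forward pass relies on the Hugging Face convention that label positions set to -100 are excluded from the loss. Below is a minimal, self-contained sketch of that convention; the tensors and sizes are made up for illustration, and it assumes nothing beyond PyTorch itself (whose cross_entropy uses ignore_index=-100 by default).

import torch
import torch.nn.functional as F

vocab_size = 8

# fake decoder logits for a batch of one sequence with four positions
logits = torch.randn(1, 4, vocab_size)

# answer ids where the last position is padding (attention mask 0)
answer_input_ids = torch.tensor([[3, 1, 5, 0]])
answer_attention_mask = torch.tensor([[1, 1, 1, 0]])

# same masking as in forward(): padded positions become -100
labels = answer_input_ids.clone()
labels[answer_attention_mask == 0] = -100

# cross_entropy skips targets equal to its default ignore_index (-100),
# so the padded position contributes nothing to the loss
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

# identical loss computed over the real tokens only
keep = answer_attention_mask.view(-1) == 1
loss_real = F.cross_entropy(
    logits.view(-1, vocab_size)[keep], labels.view(-1)[keep]
)
assert torch.allclose(loss, loss_real)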