h2o-llmstudio
89 lines · 2.4 KB
1import logging
2from typing import Any, Dict
3
4from torch import nn
5from transformers import AutoModelForCausalLM
6
7from llm_studio.src.utils.data_utils import batch_padding
8from llm_studio.src.utils.modeling_utils import create_nlp_backbone, prepare_lora
9
10logger = logging.getLogger(__name__)
11
12
class Model(nn.Module):
    """Causal-LM backbone topped with a linear classification head.

    The hidden output at the final sequence position is projected to
    ``cfg.dataset.num_classes`` class logits, turning a causal language
    model into a sequence classifier.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """
        super().__init__()

        self.cfg = cfg

        # Backbone is any HF causal LM, optionally wrapped with LoRA adapters.
        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForCausalLM
        )
        if cfg.training.lora:
            self.backbone = prepare_lora(cfg, self.backbone)

        # Maps the backbone's last-position output (vocab_size wide) to
        # class logits; no bias term.
        self.classification_head = nn.Linear(
            self.backbone_config.vocab_size, cfg.dataset.num_classes, bias=False
        )

        # Loss implementation is resolved from the configured registry.
        loss_class = self.cfg.training.loss_class.get(
            self.cfg.training.loss_function
        )
        self.loss_fn = loss_class(self.cfg)

    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        """Run a forward pass; compute the loss when labels are present.

        Args:
            batch: dict holding "prompt_input_ids" and "prompt_attention_mask";
                may also carry "labels" / "class_label" for loss computation.
            padding: whether to re-pad the batch before the forward pass.

        Returns:
            dict with "logits" and, if "labels" was in the batch, "loss".
        """
        # KV caching is incompatible with gradient checkpointing — disable it
        # for the duration of this pass.
        uses_checkpointing = self.cfg.architecture.gradient_checkpointing
        if uses_checkpointing:
            self.backbone.config.use_cache = False

        if padding:
            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key="prompt_attention_mask",
                pad_keys=[
                    "prompt_input_ids",
                    "prompt_attention_mask",
                    "special_tokens_mask",
                    "labels",
                ],
                padding_side=self.cfg.tokenizer._padding_side,
            )

        backbone_outputs = self.backbone(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
        )

        # Project the final position's output to class logits; cast to float32
        # for numerically stable loss computation under mixed precision.
        backbone_outputs.logits = self.classification_head(
            backbone_outputs[0][:, -1].float()
        )

        outputs: Dict = {}
        if "labels" in batch:
            outputs["loss"] = self.loss_fn(
                backbone_outputs.logits, batch["class_label"].unsqueeze(1).float()
            )
        outputs["logits"] = backbone_outputs.logits

        # Re-enable caching so later generation calls stay fast.
        if uses_checkpointing:
            self.backbone.config.use_cache = True

        return outputs
90