h2o-llmstudio

Форк
0
/
text_sequence_to_sequence_modeling_model.py 
117 строк · 3.2 Кб
1
import logging
2
from typing import Any, Dict
3

4
import torch.nn as nn
5
from transformers import AutoModelForSeq2SeqLM
6

7
from llm_studio.src.metrics.text_causal_language_modeling_metrics import Perplexity
8
from llm_studio.src.utils.data_utils import batch_padding
9
from llm_studio.src.utils.modeling_utils import (
10
    create_nlp_backbone,
11
    generate,
12
    prepare_lora,
13
)
14

15
logger = logging.getLogger(__name__)
16

17

18
class Model(nn.Module):
    """
    Model for sequence to sequence modeling problem type.

    Wraps an encoder-decoder backbone (loaded via ``AutoModelForSeq2SeqLM``)
    with optional LoRA adapters, a configurable loss, and an optional
    Perplexity metric used at validation time.
    """

    def __init__(self, cfg: Any):
        """
        Args:
            cfg: config with all the hyperparameters
        """

        super().__init__()

        self.cfg = cfg
        # Encoder-decoder backbone; backbone_config is the resolved HF config.
        self.backbone, self.backbone_config = create_nlp_backbone(
            cfg, model_class=AutoModelForSeq2SeqLM
        )

        if cfg.training.lora:
            self.backbone = prepare_lora(cfg, self.backbone)

        # Loss is looked up by name from the registry declared in the config.
        self.loss_fn = self.cfg.training.loss_class.get(
            self.cfg.training.loss_function
        )(self.cfg)

        if self.cfg.prediction.metric == "Perplexity":
            # reduce=False: keep per-sample perplexities for downstream
            # aggregation.
            self.perplexity = Perplexity(self.cfg, reduce=False)

    def generate(self, batch: Dict, cfg: Any, streamer=None):
        """Generate output sequences for ``batch``.

        ``remove_prompt=False`` because a seq2seq decoder output does not
        contain the encoder prompt, so there is nothing to strip.
        """
        return generate(
            backbone=self.backbone,
            batch=batch,
            cfg=cfg,
            streamer=streamer,
            remove_prompt=False,
        )

    def forward(
        self,
        batch: Dict,
        padding: bool = True,
    ) -> Dict:
        """Run a forward pass and compute the loss.

        Args:
            batch: dict with ``prompt_input_ids``, ``prompt_attention_mask``,
                ``answer_input_ids`` and ``answer_attention_mask`` tensors.
            padding: if True, dynamically pad/trim prompt and answer tensors
                before the backbone call.

        Returns:
            Dict with ``loss`` and, at eval time when configured,
            per-sample ``perplexity``.
        """
        # disable cache if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = False

        outputs: Dict = {}
        kwargs: Dict = {}

        if padding:
            # Prompt side: pad according to the tokenizer's configured side.
            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key="prompt_attention_mask",
                pad_keys=[
                    "prompt_input_ids",
                    "prompt_attention_mask",
                ],
                padding_side=self.cfg.tokenizer._padding_side,
            )

            # Answer side: always right-padded, as required for decoder
            # labels.
            batch = batch_padding(
                self.cfg,
                batch,
                self.training,
                mask_key="answer_attention_mask",
                pad_keys=[
                    "answer_input_ids",
                    "answer_attention_mask",
                ],
                padding_side="right",
            )

        # Clone before masking: the original code mutated
        # batch["answer_input_ids"] in place, leaking -100 sentinels back
        # into the caller's batch.
        labels = batch["answer_input_ids"].clone()
        # -100 is ignored by the HF cross-entropy loss.
        labels[batch["answer_attention_mask"] == 0] = -100

        output = self.backbone(
            input_ids=batch["prompt_input_ids"],
            attention_mask=batch["prompt_attention_mask"],
            labels=labels,
            **kwargs,
        )

        outputs["loss"] = output.loss

        if not self.training and self.cfg.prediction.metric == "Perplexity":
            outputs["perplexity"] = self.perplexity(output.logits, labels)

        # enable cache again if gradient checkpointing is enabled
        if self.cfg.architecture.gradient_checkpointing:
            self.backbone.config.use_cache = True

        return outputs
118

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.