h2o-llmstudio

text_dpo_modeling_plots.py 
142 lines · 5.2 KB
import hashlib
import os
from typing import Any, Dict, List

import pandas as pd

from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
from llm_studio.src.datasets.text_utils import get_tokenizer
from llm_studio.src.plots.text_causal_language_modeling_plots import (
    create_batch_prediction_df,
    plot_validation_predictions,
)
from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
from llm_studio.src.utils.utils import PatchedAttribute


class Plots:
    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        # Build a dataframe view of the chosen input ids and labels of the batch
        # and cache it as a parquet file for visualization.
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch,
            tokenizer,
            ids_for_tokenized_text="chosen_input_ids",
            labels_column="chosen_labels",
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """
        Plots the data in a scrollable table.
        We limit the number of rows to max 600 to avoid rendering issues in Wave.
        As the data visualization is instantiated on every page load, we cache it
        in a parquet file.
        """
        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.rejected_answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")
55

56
        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)
57

58
        conversations_chosen = get_conversation_chains(
59
            df, cfg, limit_chained_samples=True
60
        )
61
        with PatchedAttribute(
62
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
63
        ):
64
            conversations_rejected = get_conversation_chains(
65
                df, cfg, limit_chained_samples=True
66
            )

        # Limit to max 15 prompt-answer rounds per conversation
        max_conversation_length = min(
            max(
                [len(conversation["prompts"]) for conversation in conversations_chosen]
            ),
            15,
        )

        conversations_to_display: List = []
        for conversation_length in range(1, max_conversation_length + 1):
            # Show at most 5 conversations per conversation length
            conversations_to_display += [
                (conversation_chosen, conversation_rejected)
                for conversation_chosen, conversation_rejected in zip(
                    conversations_chosen, conversations_rejected
                )
                if len(conversation_chosen["prompts"]) == conversation_length
            ][:5]
85

86
        # Convert into a scrollable table by transposing the dataframe
87
        df_transposed = pd.DataFrame(columns=["Sample Number", "Field", "Content"])
88

89
        i = 0
90
        for sample_number, (conversation_chosen, conversations_rejected) in enumerate(
91
            conversations_to_display
92
        ):
93
            if conversation_chosen["systems"][0] != "":
94
                df_transposed.loc[i] = [
95
                    sample_number,
96
                    "System",
97
                    conversation_chosen["systems"][0],
98
                ]
99
                i += 1
100
            for prompt, answer_chosen, answer_rejected in zip(
101
                conversation_chosen["prompts"],
102
                conversation_chosen["answers"],
103
                conversations_rejected["answers"],  # type: ignore
104
            ):
105
                df_transposed.loc[i] = [
106
                    sample_number,
107
                    "Prompt",
108
                    prompt,
109
                ]
110
                i += 1
111
                if answer_chosen == answer_rejected:
112
                    df_transposed.loc[i] = [
113
                        sample_number,
114
                        "Answer",
115
                        answer_chosen,
116
                    ]
117
                    i += 1
118
                else:
119
                    df_transposed.loc[i] = [
120
                        sample_number,
121
                        "Answer Chosen",
122
                        answer_chosen,
123
                    ]
124
                    i += 1
125
                    df_transposed.loc[i] = [
126
                        sample_number,
127
                        "Answer Rejected",
128
                        answer_rejected,
129
                    ]
130
                    i += 1
131

132
        df_transposed["Content"] = df_transposed["Content"].apply(
133
            format_for_markdown_visualization
134
        )
135
        df_transposed.to_parquet(path)
136
        return PlotData(path, encoding="df")
137

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)
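
For reference, below is a minimal sketch of the cache-path scheme that Plots.plot_data uses above; it can be handy for checking whether a cached visualization exists or for forcing a rebuild. The dataframe path and column names are hypothetical placeholders, not LLM Studio defaults.

# Minimal sketch; the path and column names below are hypothetical examples.
import hashlib
import os

train_dataframe = "/data/dpo/train.csv"  # hypothetical dataset location
dataset_columns = [
    "system",             # system_column
    "prompt",             # prompt_column
    "chosen_response",    # answer_column
    "rejected_response",  # rejected_answer_column
    "parent_id",          # parent_id_column
]

# Same key construction as in Plots.plot_data: dataframe path + column names.
config_id = str(train_dataframe) + "".join(str(column) for column in dataset_columns)
config_hash = hashlib.md5(config_id.encode()).hexdigest()
cache_path = os.path.join(
    os.path.dirname(train_dataframe),
    f"__meta_info__{config_hash}_data_viz.parquet",
)

# Deleting the cached file forces plot_data to regenerate the visualization.
if os.path.exists(cache_path):
    os.remove(cache_path)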