# h2o-llmstudio
# 142 lines · 5.2 KB
1import hashlib
2import os
3from typing import Any, Dict, List
4
5import pandas as pd
6
7from llm_studio.src.datasets.conversation_chain_handler import get_conversation_chains
8from llm_studio.src.datasets.text_utils import get_tokenizer
9from llm_studio.src.plots.text_causal_language_modeling_plots import (
10create_batch_prediction_df,
11plot_validation_predictions,
12)
13from llm_studio.src.utils.data_utils import read_dataframe_drop_missing_labels
14from llm_studio.src.utils.plot_utils import PlotData, format_for_markdown_visualization
15from llm_studio.src.utils.utils import PatchedAttribute
16
17
class Plots:
    """Plot helpers for datasets that carry a chosen and a rejected answer.

    Every public method is annotated to return a ``PlotData`` object.
    ``plot_batch`` and ``plot_data`` write a parquet file and return its path
    with ``encoding="df"``.
    """

    @classmethod
    def plot_batch(cls, batch, cfg) -> PlotData:
        """Write a dataframe view of one batch to ``batch_viz.parquet``.

        The dataframe is produced by ``create_batch_prediction_df`` from the
        ``chosen_input_ids`` / ``chosen_labels`` entries of the batch.

        Args:
            batch: Batch passed through to ``create_batch_prediction_df``.
            cfg: Experiment configuration; ``cfg.output_directory`` receives
                the parquet file.

        Returns:
            ``PlotData`` pointing at the written parquet file.
        """
        tokenizer = get_tokenizer(cfg)
        df = create_batch_prediction_df(
            batch,
            tokenizer,
            ids_for_tokenized_text="chosen_input_ids",
            labels_column="chosen_labels",
        )
        path = os.path.join(cfg.output_directory, "batch_viz.parquet")
        df.to_parquet(path)
        return PlotData(path, encoding="df")

    @classmethod
    def plot_data(cls, cfg) -> PlotData:
        """
        Plots the data in a scrollable table.

        Only a sample of the data is shown to avoid rendering issues in Wave:
        at most 5 conversations per conversation length, for lengths of 1 to 15
        prompt-answer rounds. As the data visualization is instantiated on
        every page load, we cache the data visualization in a parquet file
        stored next to the train dataframe. The file name contains a hash of
        the train dataframe path and the relevant column settings.

        An empty dataset (no conversations left after dropping missing labels)
        yields an empty table instead of raising.

        Args:
            cfg: Experiment configuration; ``cfg.dataset`` provides the train
                dataframe path and the column names.

        Returns:
            ``PlotData`` pointing at the (possibly cached) parquet file.
        """
        config_id = (
            str(cfg.dataset.train_dataframe)
            + str(cfg.dataset.system_column)
            + str(cfg.dataset.prompt_column)
            + str(cfg.dataset.answer_column)
            + str(cfg.dataset.rejected_answer_column)
            + str(cfg.dataset.parent_id_column)
        )
        # md5 only serves as a cache key here, not as a security measure.
        config_hash = hashlib.md5(config_id.encode()).hexdigest()
        path = os.path.join(
            os.path.dirname(cfg.dataset.train_dataframe),
            f"__meta_info__{config_hash}_data_viz.parquet",
        )
        if os.path.exists(path):
            return PlotData(path, encoding="df")

        df = read_dataframe_drop_missing_labels(cfg.dataset.train_dataframe, cfg)

        conversations_chosen = get_conversation_chains(
            df, cfg, limit_chained_samples=True
        )
        # Build the same chains again with the rejected answers standing in for
        # the answer column (PatchedAttribute is expected to undo the swap when
        # the with-block exits).
        with PatchedAttribute(
            cfg.dataset, "answer_column", cfg.dataset.rejected_answer_column
        ):
            conversations_rejected = get_conversation_chains(
                df, cfg, limit_chained_samples=True
            )

        conversation_pairs = cls._select_conversation_pairs(
            conversations_chosen, conversations_rejected
        )

        # Convert into a scrollable table: one row per field of each sample.
        # Rows are collected first and the dataframe is built once, which
        # avoids growing it row by row.
        df_transposed = pd.DataFrame(
            cls._build_table_rows(conversation_pairs),
            columns=["Sample Number", "Field", "Content"],
        )
        df_transposed["Content"] = df_transposed["Content"].apply(
            format_for_markdown_visualization
        )
        df_transposed.to_parquet(path)
        return PlotData(path, encoding="df")

    @staticmethod
    def _select_conversation_pairs(
        conversations_chosen: List[Dict[str, Any]],
        conversations_rejected: List[Dict[str, Any]],
        max_rounds: int = 15,
        samples_per_length: int = 5,
    ) -> List[Any]:
        """Pick the (chosen, rejected) conversation pairs that get displayed.

        Pairs are grouped by the number of prompts of the chosen conversation.
        For every length from 1 to ``max_rounds`` the first
        ``samples_per_length`` pairs (in input order) are kept; the result is
        ordered by ascending length. Empty input returns an empty list.
        """
        buckets: Dict[int, List[Any]] = {}
        for chosen, rejected in zip(conversations_chosen, conversations_rejected):
            n_rounds = len(chosen["prompts"])
            if not 1 <= n_rounds <= max_rounds:
                continue
            bucket = buckets.setdefault(n_rounds, [])
            if len(bucket) < samples_per_length:
                bucket.append((chosen, rejected))
        return [pair for n_rounds in sorted(buckets) for pair in buckets[n_rounds]]

    @staticmethod
    def _build_table_rows(conversation_pairs: List[Any]) -> List[List[Any]]:
        """Flatten conversation pairs into ``[sample number, field, content]`` rows.

        Per sample: an optional "System" row (skipped when the first system
        entry is an empty string), then for every round a "Prompt" row followed
        by a single "Answer" row when chosen and rejected answers are equal, or
        an "Answer Chosen" and an "Answer Rejected" row otherwise.
        """
        rows: List[List[Any]] = []
        for sample_number, (chosen, rejected) in enumerate(conversation_pairs):
            system = chosen["systems"][0]
            if system != "":
                rows.append([sample_number, "System", system])
            for prompt, answer_chosen, answer_rejected in zip(
                chosen["prompts"], chosen["answers"], rejected["answers"]
            ):
                rows.append([sample_number, "Prompt", prompt])
                if answer_chosen == answer_rejected:
                    rows.append([sample_number, "Answer", answer_chosen])
                else:
                    rows.append([sample_number, "Answer Chosen", answer_chosen])
                    rows.append([sample_number, "Answer Rejected", answer_rejected])
        return rows

    @classmethod
    def plot_validation_predictions(
        cls, val_outputs: Dict, cfg: Any, val_df: pd.DataFrame, mode: str
    ) -> PlotData:
        """Delegate to the causal language modeling validation plot.

        Inside the method body the name ``plot_validation_predictions``
        resolves to the module-level function imported from
        ``text_causal_language_modeling_plots``, not to this method.
        """
        return plot_validation_predictions(val_outputs, cfg, val_df, mode)
143