llama-index

Форк
0
182 строки · 6.4 Кб
1
# Validates training data and estimates token usage
2
# Copied from https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
3
# Usage:
4
#  python validate_json.py <path_to_jsonl_file>
5

6

7
# We start by importing the required packages
8

9
import json
10
import os
11
import sys
12
from collections import defaultdict
13
from typing import Dict, List
14

15
import numpy as np
16
import tiktoken
17

18

19
# Token limit per training example: used both for the "too long" warning and
# to cap the billed tokens of each example.
_MAX_TOKENS_PER_EXAMPLE = 4096

# Heuristic inputs for choosing a default number of epochs from dataset size.
_MIN_TARGET_EXAMPLES = 100
_MAX_TARGET_EXAMPLES = 25000
_TARGET_EPOCHS = 3
_MIN_EPOCHS = 1
_MAX_EPOCHS = 25

# Message keys and roles accepted by the format check.
_ALLOWED_MESSAGE_KEYS = ("role", "content", "name")
_ALLOWED_ROLES = ("system", "user", "assistant")


def _load_jsonl(data_path: str) -> list:
    """Parse one JSON value per non-blank line of ``data_path``.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a line is not valid JSON; the message names the line.
    """
    dataset = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(data_path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines (e.g. trailing newlines)
            try:
                dataset.append(json.loads(line))
            except json.JSONDecodeError as err:
                raise ValueError(
                    f"{data_path}:{lineno}: invalid JSON ({err})"
                ) from err
    return dataset


def _count_format_errors(dataset: list) -> Dict[str, int]:
    """Return a mapping of format-error category to number of occurrences.

    The checks follow the chat-completions message structure. An empty
    mapping means every example passed.
    """
    format_errors: Dict[str, int] = defaultdict(int)

    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages or not isinstance(messages, list):
            format_errors["missing_messages_list"] += 1
            continue

        for message in messages:
            # A non-dict entry cannot be inspected further and would crash
            # the key/role checks below.
            if not isinstance(message, dict):
                format_errors["message_data_type"] += 1
                continue

            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in _ALLOWED_MESSAGE_KEYS for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in _ALLOWED_ROLES:
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1

        if not any(
            isinstance(message, dict) and message.get("role", None) == "assistant"
            for message in messages
        ):
            format_errors["example_missing_assistant_message"] += 1

    return format_errors


def _as_text(value: object) -> str:
    """Return ``value`` as text for token counting (``None`` -> ``""``)."""
    if value is None:
        return ""
    # Non-string values (e.g. a ``function_call`` dict) are counted via str().
    return value if isinstance(value, str) else str(value)


def _encode_len(encoding: "tiktoken.Encoding", text: str) -> int:
    """Return the number of tokens ``encoding`` produces for ``text``."""
    # disallowed_special=() counts special-token text such as "<|endoftext|>"
    # as ordinary text instead of raising ValueError on user data.
    return len(encoding.encode(text, disallowed_special=()))


def _num_tokens_from_messages(
    encoding: "tiktoken.Encoding",
    messages: List[dict],
    tokens_per_message: int = 3,
    tokens_per_name: int = 1,
) -> int:
    """Approximate the total token count of ``messages`` (not exact!).

    Simplified from
    https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    """
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            # NOTE: non-string values such as function calls are counted via
            # their string form (not in cookbook).
            num_tokens += _encode_len(encoding, _as_text(value))
            if key == "name":
                num_tokens += tokens_per_name
    num_tokens += 3  # fixed per-conversation overhead used by the heuristic
    return num_tokens


def _num_assistant_tokens_from_messages(
    encoding: "tiktoken.Encoding", messages: List[dict]
) -> int:
    """Count tokens in the ``content`` of assistant messages only."""
    return sum(
        _encode_len(encoding, _as_text(message.get("content")))
        for message in messages
        if message.get("role") == "assistant"
    )


def _print_distribution(values: list, name: str) -> None:
    """Print min/max, mean/median and 5th/95th percentiles of ``values``."""
    print(f"\n#### Distribution of {name}:")
    if not values:
        print("no values")  # min()/max() would raise on an empty sequence
        return
    print(f"min / max: {min(values)}, {max(values)}")
    print(f"mean / median: {np.mean(values)}, {np.median(values)}")
    # The label promises p5/p95, so use the 0.05 / 0.95 quantiles.
    print(f"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")


def _estimate_n_epochs(n_train_examples: int) -> int:
    """Return the default epoch count for ``n_train_examples`` (must be > 0).

    Start from ``_TARGET_EPOCHS``. Rescale when the total number of
    example-passes would fall outside
    [``_MIN_TARGET_EXAMPLES``, ``_MAX_TARGET_EXAMPLES``], bounded by
    ``_MAX_EPOCHS`` / ``_MIN_EPOCHS`` respectively.
    """
    n_epochs = _TARGET_EPOCHS
    if n_train_examples * _TARGET_EPOCHS < _MIN_TARGET_EXAMPLES:
        n_epochs = min(_MAX_EPOCHS, _MIN_TARGET_EXAMPLES // n_train_examples)
    elif n_train_examples * _TARGET_EPOCHS > _MAX_TARGET_EXAMPLES:
        n_epochs = max(_MIN_EPOCHS, _MAX_TARGET_EXAMPLES // n_train_examples)
    return n_epochs


def validate_json(data_path: str) -> None:
    """Validate a chat fine-tuning JSONL file and print a report to stdout.

    Adapted from
    https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset.
    The report contains:

    - the number of examples and the first example;
    - format errors in the chat-message structure;
    - approximate token-count distributions;
    - a rough training-cost estimate.

    Args:
        data_path: Path to a JSONL file with one example per line. Each
            example is expected to be a dict with a ``"messages"`` list of
            ``{"role": ..., "content": ...}`` dicts.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if a non-blank line is not valid JSON.
    """
    dataset = _load_jsonl(data_path)

    # Initial dataset stats
    print("Num examples:", len(dataset))
    if not dataset:
        # Nothing to inspect; the statistics below would also divide by zero.
        print("No examples found")
        return

    print("First example:")
    first = dataset[0]
    first_messages = first.get("messages") if isinstance(first, dict) else None
    if isinstance(first_messages, list):
        for message in first_messages:
            print(message)
    else:
        print(first)  # malformed; the format check below reports why

    # Format error checks against the chat-completions message structure
    format_errors = _count_format_errors(dataset)
    if format_errors:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found")

    # Warnings and token counts. Examples without a messages list were
    # already reported above and are skipped so malformed data cannot crash
    # the statistics.
    encoding = tiktoken.get_encoding("cl100k_base")
    n_missing_system = 0
    n_missing_user = 0
    n_messages = []
    convo_lens = []
    assistant_message_lens = []

    for ex in dataset:
        messages = ex.get("messages") if isinstance(ex, dict) else None
        if not isinstance(messages, list):
            continue
        messages = [m for m in messages if isinstance(m, dict)]
        if not any(m.get("role") == "system" for m in messages):
            n_missing_system += 1
        if not any(m.get("role") == "user" for m in messages):
            n_missing_user += 1
        n_messages.append(len(messages))
        convo_lens.append(_num_tokens_from_messages(encoding, messages))
        assistant_message_lens.append(
            _num_assistant_tokens_from_messages(encoding, messages)
        )

    print("Num examples missing system message:", n_missing_system)
    print("Num examples missing user message:", n_missing_user)
    _print_distribution(n_messages, "num_messages_per_example")
    _print_distribution(convo_lens, "num_total_tokens_per_example")
    _print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
    n_too_long = sum(length > _MAX_TOKENS_PER_EXAMPLE for length in convo_lens)
    print(
        f"\n{n_too_long} examples may be over the "
        f"{_MAX_TOKENS_PER_EXAMPLE} token limit, "
        "they will be truncated during fine-tuning"
    )

    # Pricing and default n_epochs estimate
    n_epochs = _estimate_n_epochs(len(dataset))
    n_billing_tokens_in_dataset = sum(
        min(_MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
    )
    print(
        f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will "
        "be charged for during training"
    )
    print(f"By default, you'll train for {n_epochs} epochs on this dataset")
    print(
        "By default, you'll be charged for "
        f"~{n_epochs * n_billing_tokens_in_dataset} tokens"
    )

    print("As of August 22, 2023, fine-tuning gpt-3.5-turbo is $0.008 / 1K Tokens.")
    print(
        "This means your total cost for training will be "
        f"${n_billing_tokens_in_dataset * 0.008 / 1000} per epoch."
    )
176

177

178
if __name__ == "__main__":
    # Command-line entry point: validate exactly one JSONL file.
    if len(sys.argv) != 2:
        # Fail with a usage line instead of an IndexError on sys.argv[1].
        sys.exit(
            f"Usage: python {os.path.basename(sys.argv[0])} <path_to_jsonl_file>"
        )
    data_path = sys.argv[1]
    # isfile (not exists) so a directory is rejected here rather than in open().
    if not os.path.isfile(data_path):
        raise ValueError(f"Path {data_path} does not exist or is not a file")
    validate_json(data_path)
183

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.