llama-index

Форк
0
179 строк · 6.2 Кб
1
# Validates training data and estimates token usage
2
# Copied from https://platform.openai.com/docs/guides/fine-tuning/preparing-your-dataset
3
# Usage:
4
#  python validate_json.py <path_to_jsonl_file>
5

6

7
# We start by importing the required packages
8

9
import json
10
import os
11
import sys
12
from collections import defaultdict
13
from typing import Dict, List
14

15
import numpy as np
16
import tiktoken
17

18

19
def validate_json(data_path: str) -> None:
    """Validate a JSONL chat fine-tuning dataset and estimate token usage.

    Checks every example against the Chat Completions message schema,
    prints format-error counts, token-count distributions, and a rough
    estimate of billable training tokens / default epoch count.

    Args:
        data_path: Path to a JSONL file where each line is a JSON object
            containing a "messages" list of {"role", "content", ...} dicts.
    """
    # Load dataset. JSON text is always UTF-8; be explicit so the read does
    # not fall back to a locale encoding (e.g. cp1252 on Windows).
    with open(data_path, encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f]

    # Guard against an empty file: dataset[0] below and the min()/max()
    # calls in print_distribution would otherwise raise.
    if not dataset:
        print("Num examples: 0")
        print("Dataset is empty; nothing to validate.")
        return

    # We can inspect the data quickly by checking the number
    # of examples and the first item

    # Initial dataset stats
    print("Num examples:", len(dataset))
    print("First example:")
    for message in dataset[0]["messages"]:
        print(message)

    # Now that we have a sense of the data, we need to go through all the different
    # examples and check to make sure the formatting is correct and matches the Chat
    # completions message structure

    # Format error checks
    format_errors: Dict[str, int] = defaultdict(int)

    for ex in dataset:
        if not isinstance(ex, dict):
            format_errors["data_type"] += 1
            continue

        messages = ex.get("messages", None)
        if not messages:
            format_errors["missing_messages_list"] += 1
            continue

        for message in messages:
            if "role" not in message or "content" not in message:
                format_errors["message_missing_key"] += 1

            if any(k not in ("role", "content", "name") for k in message):
                format_errors["message_unrecognized_key"] += 1

            if message.get("role", None) not in ("system", "user", "assistant"):
                format_errors["unrecognized_role"] += 1

            content = message.get("content", None)
            if not content or not isinstance(content, str):
                format_errors["missing_content"] += 1

        # Fine-tuning needs at least one assistant turn per example to learn from.
        if not any(message.get("role", None) == "assistant" for message in messages):
            format_errors["example_missing_assistant_message"] += 1

    if format_errors:
        print("Found errors:")
        for k, v in format_errors.items():
            print(f"{k}: {v}")
    else:
        print("No errors found")

    # Beyond the structure of the message, we also need to ensure that the length does
    # not exceed the 4096 token limit.

    # Token counting functions; cl100k_base is the tokenizer used by gpt-3.5-turbo.
    encoding = tiktoken.get_encoding("cl100k_base")

    # not exact!
    # simplified from https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
    def num_tokens_from_messages(
        messages: List[dict], tokens_per_message: int = 3, tokens_per_name: int = 1
    ) -> int:
        """Approximate prompt token count for a list of chat messages."""
        num_tokens = 0
        for message in messages:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Skip non-string values: they are already counted as format
                # errors above, and encoding.encode would raise on them.
                if not isinstance(value, str):
                    continue
                num_tokens += len(encoding.encode(value))
                if key == "name":
                    num_tokens += tokens_per_name
        num_tokens += 3  # every reply is primed with <|start|>assistant<|message|>
        return num_tokens

    def num_assistant_tokens_from_messages(messages: List[dict]) -> int:
        """Token count of assistant turns only (the tokens trained on)."""
        num_tokens = 0
        for message in messages:
            if message["role"] == "assistant":
                num_tokens += len(encoding.encode(message["content"]))
        return num_tokens

    def print_distribution(values: list, name: str) -> None:
        """Print min/max, mean/median and 5th/95th percentiles of *values*."""
        print(f"\n#### Distribution of {name}:")
        print(f"min / max: {min(values)}, {max(values)}")
        print(f"mean / median: {np.mean(values)}, {np.median(values)}")
        # Fix: the label reads p5/p95 but the original computed the 10th/90th
        # percentiles (np.quantile(..., 0.1) / 0.9); use 0.05 / 0.95 to match.
        print(f"p5 / p95: {np.quantile(values, 0.05)}, {np.quantile(values, 0.95)}")

    # Last, we can look at the results of the different formatting operations before
    # proceeding with creating a fine-tuning job:

    # Warnings and tokens counts
    n_missing_system = 0
    n_missing_user = 0
    n_messages = []
    convo_lens = []
    assistant_message_lens = []

    for ex in dataset:
        messages = ex["messages"]
        if not any(message["role"] == "system" for message in messages):
            n_missing_system += 1
        if not any(message["role"] == "user" for message in messages):
            n_missing_user += 1
        n_messages.append(len(messages))
        convo_lens.append(num_tokens_from_messages(messages))
        assistant_message_lens.append(num_assistant_tokens_from_messages(messages))

    print("Num examples missing system message:", n_missing_system)
    print("Num examples missing user message:", n_missing_user)
    print_distribution(n_messages, "num_messages_per_example")
    print_distribution(convo_lens, "num_total_tokens_per_example")
    print_distribution(assistant_message_lens, "num_assistant_tokens_per_example")
    n_too_long = sum(length > 4096 for length in convo_lens)
    print(
        f"\n{n_too_long} examples may be over the 4096 token limit, "
        "they will be truncated during fine-tuning"
    )

    # Pricing and default n_epochs estimate
    MAX_TOKENS_PER_EXAMPLE = 4096

    MIN_TARGET_EXAMPLES = 100
    MAX_TARGET_EXAMPLES = 25000
    TARGET_EPOCHS = 3
    MIN_EPOCHS = 1
    MAX_EPOCHS = 25

    # Scale epochs so the total number of trained examples lands within
    # [MIN_TARGET_EXAMPLES, MAX_TARGET_EXAMPLES], clamped to [MIN_EPOCHS, MAX_EPOCHS].
    n_epochs = TARGET_EPOCHS
    n_train_examples = len(dataset)
    if n_train_examples * TARGET_EPOCHS < MIN_TARGET_EXAMPLES:
        n_epochs = min(MAX_EPOCHS, MIN_TARGET_EXAMPLES // n_train_examples)
    elif n_train_examples * TARGET_EPOCHS > MAX_TARGET_EXAMPLES:
        n_epochs = max(MIN_EPOCHS, MAX_TARGET_EXAMPLES // n_train_examples)

    # Billing caps each example at MAX_TOKENS_PER_EXAMPLE tokens.
    n_billing_tokens_in_dataset = sum(
        min(MAX_TOKENS_PER_EXAMPLE, length) for length in convo_lens
    )
    print(
        f"Dataset has ~{n_billing_tokens_in_dataset} tokens that will "
        "be charged for during training"
    )
    print(f"By default, you'll train for {n_epochs} epochs on this dataset")
    print(
        "By default, you'll be charged for "
        f"~{n_epochs * n_billing_tokens_in_dataset} tokens"
    )

    print("As of August 22, 2023, fine-tuning gpt-3.5-turbo is $0.008 / 1K Tokens.")
    print(
        "This means your total cost for training will be "
        f"${n_billing_tokens_in_dataset * 0.008 / 1000} per epoch."
    )
175
if __name__ == "__main__":
    # Require exactly one CLI argument: the path to the JSONL dataset.
    # Without this guard, sys.argv[1] raises a bare IndexError.
    if len(sys.argv) < 2:
        raise SystemExit("Usage: python validate_json.py <path_to_jsonl_file>")
    data_path = sys.argv[1]
    if not os.path.exists(data_path):
        raise ValueError(f"Path {data_path} does not exist")
    validate_json(data_path)
Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.