simpletransformers

Форк
0
75 строк · 1.8 Кб
1
import logging
2

3
import pandas as pd
4
import sklearn
5
import wandb
6

7
from simpletransformers.classification import ClassificationArgs, ClassificationModel
8

9
sweep_config = {
10
    "method": "bayes",  # grid, random
11
    "metric": {"name": "train_loss", "goal": "minimize"},
12
    "parameters": {
13
        "num_train_epochs": {"values": [2, 3, 5]},
14
        "learning_rate": {"min": 5e-5, "max": 4e-4},
15
    },
16
}
17

18
sweep_id = wandb.sweep(sweep_config, project="Simple Sweep")
19

20
logging.basicConfig(level=logging.INFO)
21
transformers_logger = logging.getLogger("transformers")
22
transformers_logger.setLevel(logging.WARNING)
23

24
# Preparing train data
25
train_data = [
26
    ["Aragorn was the heir of Isildur", "true"],
27
    ["Frodo was the heir of Isildur", "false"],
28
]
29
train_df = pd.DataFrame(train_data)
30
train_df.columns = ["text", "labels"]
31

32
# Preparing eval data
33
eval_data = [
34
    ["Theoden was the king of Rohan", "true"],
35
    ["Merry was the king of Rohan", "false"],
36
]
37
eval_df = pd.DataFrame(eval_data)
38
eval_df.columns = ["text", "labels"]
39

40
model_args = ClassificationArgs()
41
model_args.reprocess_input_data = True
42
model_args.overwrite_output_dir = True
43
model_args.evaluate_during_training = True
44
model_args.manual_seed = 4
45
model_args.use_multiprocessing = True
46
model_args.train_batch_size = 16
47
model_args.eval_batch_size = 8
48
model_args.labels_list = ["true", "false"]
49
model_args.wandb_project = "Simple Sweep"
50

51

52
def train():
53
    # Initialize a new wandb run
54
    wandb.init()
55

56
    # Create a TransformerModel
57
    model = ClassificationModel(
58
        "roberta",
59
        "roberta-base",
60
        use_cuda=True,
61
        args=model_args,
62
        sweep_config=wandb.config,
63
    )
64

65
    # Train the model
66
    model.train_model(train_df, eval_df=eval_df)
67

68
    # Evaluate the model
69
    model.eval_model(eval_df)
70

71
    # Sync wandb
72
    wandb.join()
73

74

75
wandb.agent(sweep_id, train)
76

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.