# simpletransformers: W&B hyperparameter sweep example
import logging

import pandas as pd
import wandb
from sklearn.metrics import accuracy_score

from simpletransformers.classification import ClassificationArgs, ClassificationModel

sweep_config = {
    "method": "bayes",  # alternatives: "grid", "random"
    "metric": {"name": "train_loss", "goal": "minimize"},
    "parameters": {
        "num_train_epochs": {"values": [2, 3, 5]},
        "learning_rate": {"min": 5e-5, "max": 4e-4},
    },
}
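# With "bayes", W&B runs Bayesian optimization over the search space:
# num_train_epochs is sampled from the listed values and learning_rate
# uniformly from [5e-5, 4e-4]. Each run is scored by train_loss, one of
# the metrics simpletransformers logs to W&B during training.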

sweep_id = wandb.sweep(sweep_config, project="Simple Sweep")
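# wandb.sweep registers the sweep with the W&B server and returns its ID;
# the agent at the bottom of the script uses this ID to fetch runs.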

logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)
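# The transformers library's logger is capped at WARNING so that
# simpletransformers' own INFO messages stay readable.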

# Preparing train data
train_data = [
    ["Aragorn was the heir of Isildur", "true"],
    ["Frodo was the heir of Isildur", "false"],
]
train_df = pd.DataFrame(train_data, columns=["text", "labels"])

# Preparing eval data
eval_data = [
    ["Theoden was the king of Rohan", "true"],
    ["Merry was the king of Rohan", "false"],
]
eval_df = pd.DataFrame(eval_data, columns=["text", "labels"])
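# Labels are plain strings here; they are mapped to class indices through
# the labels_list set on the model args below.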

model_args = ClassificationArgs()
model_args.reprocess_input_data = True
model_args.overwrite_output_dir = True
model_args.evaluate_during_training = True
model_args.manual_seed = 4
model_args.use_multiprocessing = True
model_args.train_batch_size = 16
model_args.eval_batch_size = 8
model_args.labels_list = ["true", "false"]
model_args.wandb_project = "Simple Sweep"
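# These args stay fixed across runs; the swept values (num_train_epochs,
# learning_rate) are injected per run via sweep_config=wandb.config in
# train() and take precedence over the corresponding model args.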


def train():
    # Initialize a new wandb run
    wandb.init()

    # Create a TransformerModel; sweep_config=wandb.config merges the
    # hyperparameters W&B chose for this run into the model args
    model = ClassificationModel(
        "roberta",
        "roberta-base",
        use_cuda=True,
        args=model_args,
        sweep_config=wandb.config,
    )

    # Train the model
    model.train_model(train_df, eval_df=eval_df)

    # Evaluate the model, also computing accuracy on the eval set
    model.eval_model(eval_df, acc=accuracy_score)

    # Mark the wandb run as finished so the agent can launch the next one
    wandb.finish()


wandb.agent(sweep_id, train)
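# The agent requests hyperparameter combinations from the W&B server and
# calls train() once per combination; pass count to cap the number of runs,
# e.g. wandb.agent(sweep_id, train, count=10).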