# wandb
# 42 lines · 1.3 KB
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: N812
import wandb
from tensorboardX import SummaryWriter


class ConvNet(nn.Module):
    """Small two-conv / two-linear CNN for single-channel 28x28 inputs.

    The forward pass returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, 10) for ``x``.

        ``x`` is expected to be (batch, 1, 28, 28): two 5x5 convolutions,
        each followed by 2x2 max-pooling, leave a 20x4x4 = 320-feature map,
        which is what ``fc1`` consumes.
        """
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        # Flatten per sample (keeping the batch dim) rather than
        # view(-1, 320): a mis-sized input then fails loudly in fc1 instead
        # of being silently re-split into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def main():
    """Write a few fake training steps through a tensorboardX writer into W&B.

    Starts a W&B run with ``tensorboard=True``, watches a ``ConvNet`` with
    ``log_freq=2``, then runs 10 forward/backward passes on constant inputs.
    Each step writes a ``"loss"`` scalar and an ``"example"`` image to a
    ``SummaryWriter``.
    """
    wandb.init(tensorboard=True)

    # The inputs never change, so build them once instead of per step.
    inputs = torch.ones((64, 1, 28, 28))
    target = torch.ones((64, 10))
    grad_seed = torch.ones(64, 10)
    image = torch.ones((1, 28, 28))

    writer = SummaryWriter()
    try:
        net = ConvNet()
        wandb.watch(net, log_freq=2)
        for i in range(10):
            step = i + 1
            output = net(inputs)
            loss = F.mse_loss(output, target)
            # Backprop an all-ones gradient through the raw output (the
            # gradient of output.sum()), not through ``loss``. This only
            # serves to populate parameter gradients for wandb.watch.
            output.backward(grad_seed)
            # .item() logs a plain float instead of a graph-attached tensor.
            # NOTE(review): mse_loss already averages over all elements
            # (reduction="mean" by default), so the extra /64 only rescales
            # the logged value; kept for continuity with existing runs.
            writer.add_scalar("loss", loss.item() / 64, step)
            writer.add_image("example", image, step)
    finally:
        # Close the event writer even if a step raises.
        writer.close()


# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()