wandb

Форк
0
/
pytorch_tensorboard.py 
42 строки · 1.3 Кб
1
import torch
2
import torch.nn as nn
3
import torch.nn.functional as F  # noqa N812
4
import wandb
5
from torch.utils.tensorboard import SummaryWriter
6

7

8
def main():
9
    wandb.init(tensorboard=True)
10

11
    class ConvNet(nn.Module):
12
        def __init__(self):
13
            super().__init__()
14
            self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
15
            self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
16
            self.conv2_drop = nn.Dropout2d()
17
            self.fc1 = nn.Linear(320, 50)
18
            self.fc2 = nn.Linear(50, 10)
19

20
        def forward(self, x):
21
            x = F.relu(F.max_pool2d(self.conv1(x), 2))
22
            x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
23
            x = x.view(-1, 320)
24
            x = F.relu(self.fc1(x))
25
            x = F.dropout(x, training=self.training)
26
            x = self.fc2(x)
27
            return F.log_softmax(x, dim=1)
28

29
    writer = SummaryWriter()
30
    net = ConvNet()
31
    wandb.watch(net, log_freq=2)
32
    for i in range(10):
33
        output = net(torch.ones((64, 1, 28, 28)))
34
        loss = F.mse_loss(output, torch.ones((64, 10)))
35
        output.backward(torch.ones(64, 10))
36
        writer.add_scalar("loss", loss / 64, i + 1)
37
        writer.add_image("example", torch.ones((1, 28, 28)), i + 1)
38
    writer.close()
39

40

41
if __name__ == "__main__":
42
    main()
43

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.