when-to-switch

Форк
0
43 строки · 1.3 Кб
1
from argparse import Namespace
2

3
import gym
4
import torch
5
from gym.spaces import Box
6
from sample_factory.algorithms.appo.model_utils import nonlinearity
7
from sample_factory.utils.timing import Timing
8
from torch import nn
9

10
from learning.encoder import ResnetEncoder
11
from learning.epom_config import ExperimentSettings
12

13

14
class PolicyEstimationModel(nn.Module):
15

16
    def __init__(self, cfg=None):
17
        super().__init__()
18
        if cfg is None:
19
            exp_set = ExperimentSettings()
20
            cfg = {'full_config': {'experiment_settings': exp_set.dict()}, **exp_set.dict()}
21
            cfg = Namespace(**cfg)
22

23
        full_size = 5 * 2 + 1
24
        observation_space = gym.spaces.Dict(
25
            obs=gym.spaces.Box(0.0, 1.0, shape=(3, full_size, full_size)),
26
            xy=Box(low=-1024, high=1024, shape=(2,), dtype=int),
27
            target_xy=Box(low=-1024, high=1024, shape=(2,), dtype=int),
28
        )
29

30
        self.encoder = ResnetEncoder(cfg, observation_space, Timing())
31
        self.value_head = nn.Sequential(
32
            nn.Linear(self.encoder.get_encoder_out_size(), 512),
33
            nonlinearity(cfg),
34
            nn.Linear(512, 512),
35
            nonlinearity(cfg),
36
            nn.Linear(512, 1),
37
        )
38

39
    def forward(self, x):
40
        x = self.encoder(x)
41
        x = self.value_head(x)
42
        x = torch.squeeze(x, 1)
43
        return x
44

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.