# when-to-switch
# 43 lines · 1.3 KB
1from argparse import Namespace2
3import gym4import torch5from gym.spaces import Box6from sample_factory.algorithms.appo.model_utils import nonlinearity7from sample_factory.utils.timing import Timing8from torch import nn9
10from learning.encoder import ResnetEncoder11from learning.epom_config import ExperimentSettings12
13
class PolicyEstimationModel(nn.Module):
    """Scalar estimator: a ``ResnetEncoder`` followed by a small MLP head.

    The encoder is built for a fixed dict observation space (an image-like
    ``obs`` tensor plus integer ``xy`` and ``target_xy`` coordinates); the
    head maps the encoder output to one value per sample.
    """

    def __init__(self, cfg=None):
        """Build the encoder and the value head.

        Args:
            cfg: Configuration object passed to ``ResnetEncoder`` and
                ``nonlinearity``. When ``None``, a default configuration is
                built from ``ExperimentSettings()``.
        """
        super().__init__()
        if cfg is None:
            exp_set = ExperimentSettings()
            # Expose the settings both nested under 'full_config' and
            # flattened at the top level of the config.
            cfg = {'full_config': {'experiment_settings': exp_set.dict()}, **exp_set.dict()}
            # NOTE(review): the Namespace wrap is assumed to apply only to the
            # default config built here — confirm against the upstream file.
            cfg = Namespace(**cfg)

        # 11 x 11 spatial grid; presumably an observation radius of 5 around
        # the agent plus the centre cell — TODO confirm.
        full_size = 5 * 2 + 1
        observation_space = gym.spaces.Dict(
            obs=gym.spaces.Box(0.0, 1.0, shape=(3, full_size, full_size)),
            xy=Box(low=-1024, high=1024, shape=(2,), dtype=int),
            target_xy=Box(low=-1024, high=1024, shape=(2,), dtype=int),
        )

        self.encoder = ResnetEncoder(cfg, observation_space, Timing())
        # Three linear layers (encoder_out -> 512 -> 512 -> 1) with the
        # activation returned by ``nonlinearity(cfg)`` between them.
        self.value_head = nn.Sequential(
            nn.Linear(self.encoder.get_encoder_out_size(), 512),
            nonlinearity(cfg),
            nn.Linear(512, 512),
            nonlinearity(cfg),
            nn.Linear(512, 1),
        )

    def forward(self, x):
        """Encode ``x``, apply the value head and squeeze dimension 1.

        Args:
            x: Input accepted by ``self.encoder``.

        Returns:
            The head output with dimension 1 removed when it has size 1
            (the final linear layer emits a single feature).
        """
        x = self.encoder(x)
        x = self.value_head(x)
        x = torch.squeeze(x, 1)
        return x