# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Contrastive RL networks definition."""
import dataclasses
from typing import Optional, Tuple, Callable

from acme import specs
from acme.agents.jax import actor_core as actor_core_lib
from acme.jax import networks as networks_lib
from acme.jax import utils
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np


@dataclasses.dataclass
class ContrastiveNetworks:
  """Network and pure functions for the Contrastive RL agent."""
  policy_network: networks_lib.FeedForwardNetwork
  q_network: networks_lib.FeedForwardNetwork
  log_prob: networks_lib.LogProbFn
  repr_fn: Callable[Ellipsis, networks_lib.NetworkOutput]
  sample: networks_lib.SampleFn
  sample_eval: Optional[networks_lib.SampleFn] = None
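
# Usage sketch for the container above (based on `make_networks` below):
# `policy_network.apply` returns an action distribution, and `sample`,
# `sample_eval`, and `log_prob` are pure functions over that output:
#
#   dist_params = networks.policy_network.apply(policy_params, obs)
#   action = networks.sample(dist_params, key)
#   logp = networks.log_prob(dist_params, action)
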
def apply_policy_and_sample(
    networks,
    eval_mode = False):
  """Returns a function that computes actions."""
  sample_fn = networks.sample if not eval_mode else networks.sample_eval
  if not sample_fn:
    raise ValueError('sample function is not provided')

  def apply_and_sample(params, key, obs):
    return sample_fn(networks.policy_network.apply(params, obs), key)
  return apply_and_sample
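
# Note: with eval_mode=True the returned policy is deterministic, since
# `sample_eval` in `make_networks` takes the distribution's mode:
#
#   policy = apply_policy_and_sample(networks, eval_mode=True)
#   action = policy(policy_params, jax.random.PRNGKey(0), obs_batch)
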
def make_networks(
    spec,
    obs_dim,
    repr_dim = 64,
    repr_norm = False,
    repr_norm_temp = True,
    hidden_layer_sizes = (256, 256),
    actor_min_std = 1e-6,
    twin_q = False,
    use_image_obs = False):
  """Creates networks used by the agent."""

  num_dimensions = np.prod(spec.actions.shape, dtype=int)
  TORSO = networks_lib.AtariTorso  # pylint: disable=invalid-name

  def _unflatten_obs(obs):
    # Image observations arrive flattened as [state; goal]: two 64x64x3
    # frames with pixel values in [0, 255], rescaled here to [0, 1].
    state = jnp.reshape(obs[:, :obs_dim], (-1, 64, 64, 3)) / 255.0
    goal = jnp.reshape(obs[:, obs_dim:], (-1, 64, 64, 3)) / 255.0
    return state, goal
  def _repr_fn(obs, action, hidden=None):
    # The optional input hidden is the image representations. We include this
    # as an input for the second Q value when twin_q = True, so that the two Q
    # values use the same underlying image representation.
    if hidden is None:
      if use_image_obs:
        state, goal = _unflatten_obs(obs)
        img_encoder = TORSO()
        state = img_encoder(state)
        goal = img_encoder(goal)
      else:
        state = obs[:, :obs_dim]
        goal = obs[:, obs_dim:]
    else:
      state, goal = hidden

    sa_encoder = hk.nets.MLP(
        list(hidden_layer_sizes) + [repr_dim],
        w_init=hk.initializers.VarianceScaling(1.0, 'fan_avg', 'uniform'),
        activation=jax.nn.relu,
        name='sa_encoder')
    sa_repr = sa_encoder(jnp.concatenate([state, action], axis=-1))

    g_encoder = hk.nets.MLP(
        list(hidden_layer_sizes) + [repr_dim],
        w_init=hk.initializers.VarianceScaling(1.0, 'fan_avg', 'uniform'),
        activation=jax.nn.relu,
        name='g_encoder')
    g_repr = g_encoder(goal)

    if repr_norm:
      # Project both representations onto the unit sphere; optionally rescale
      # the state-action representation by a learned temperature.
      sa_repr = sa_repr / jnp.linalg.norm(sa_repr, axis=1, keepdims=True)
      g_repr = g_repr / jnp.linalg.norm(g_repr, axis=1, keepdims=True)

      if repr_norm_temp:
        log_scale = hk.get_parameter('repr_log_scale', [], dtype=sa_repr.dtype,
                                     init=jnp.zeros)
        sa_repr = sa_repr / jnp.exp(log_scale)
    return sa_repr, g_repr, (state, goal)
  def _combine_repr(sa_repr, g_repr):
    # All pairwise inner products between the state-action and goal
    # representations in the batch: output[i, j] = <sa_repr[i], g_repr[j]>.
    return jax.numpy.einsum('ik,jk->ij', sa_repr, g_repr)
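
  # For a batch of B transitions the result is a [B, B] matrix. In the
  # contrastive critic loss, the diagonal entries score the observed
  # (state-action, goal) pairs, while off-diagonal entries act as negatives
  # obtained by re-pairing states and goals within the batch.
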
  def _critic_fn(obs, action):
    sa_repr, g_repr, hidden = _repr_fn(obs, action)
    outer = _combine_repr(sa_repr, g_repr)
    if twin_q:
      sa_repr2, g_repr2, _ = _repr_fn(obs, action, hidden=hidden)
      outer2 = _combine_repr(sa_repr2, g_repr2)
      # outer.shape = [batch_size, batch_size, 2]
      outer = jnp.stack([outer, outer2], axis=-1)
    return outer
  def _actor_fn(obs):
    if use_image_obs:
      state, goal = _unflatten_obs(obs)
      obs = jnp.concatenate([state, goal], axis=-1)
      obs = TORSO()(obs)
    network = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_dimensions,
                                            min_scale=actor_min_std),
    ])
    return network(obs)
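
  # The actor head is a tanh-squashed Normal distribution, so sampled actions
  # lie in (-1, 1) per dimension; `actor_min_std` lower-bounds its scale.
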
  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  repr_fn = hk.without_apply_rng(hk.transform(_repr_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return ContrastiveNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_network=networks_lib.FeedForwardNetwork(
          lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply),
      repr_fn=repr_fn.apply,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      sample_eval=lambda params, key: params.mode(),
      )
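
if __name__ == '__main__':
  # Minimal smoke test: a sketch, not part of the original library. It
  # assumes a toy state-based task whose observations concatenate
  # [state; goal] with `obs_dim` entries each, plus 2-dimensional actions.
  _obs_dim = 4
  _spec = specs.EnvironmentSpec(
      observations=specs.Array((2 * _obs_dim,), np.float32),
      actions=specs.BoundedArray((2,), np.float32, minimum=-1.0, maximum=1.0),
      rewards=specs.Array((), np.float32),
      discounts=specs.BoundedArray((), np.float32, minimum=0.0, maximum=1.0))
  _networks = make_networks(_spec, obs_dim=_obs_dim)
  _key = jax.random.PRNGKey(0)

  # Critic: one logit per (state-action, goal) pairing in the batch.
  _obs = jnp.zeros((3, 2 * _obs_dim))
  _act = jnp.zeros((3, 2))
  _q_params = _networks.q_network.init(_key)
  print(_networks.q_network.apply(_q_params, _obs, _act).shape)  # (3, 3)

  # Actor: a distribution over actions; sampling needs an explicit key.
  _policy_params = _networks.policy_network.init(_key)
  _dist = _networks.policy_network.apply(_policy_params, _obs)
  print(_networks.sample(_dist, _key).shape)  # (3, 2)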
