google-research
162 lines · 5.6 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Contrastive RL networks definition."""
17import dataclasses
18from typing import Optional, Tuple, Callable
19
20from acme import specs
21from acme.agents.jax import actor_core as actor_core_lib
22from acme.jax import networks as networks_lib
23from acme.jax import utils
24import haiku as hk
25import jax
26import jax.numpy as jnp
27import numpy as np
28
29
@dataclasses.dataclass
class ContrastiveNetworks:
  """Network and pure functions for the Contrastive RL agent."""
  # Actor network: maps observations to an action distribution.
  policy_network: networks_lib.FeedForwardNetwork
  # Critic network: maps (obs, action) to a [batch, batch] (or
  # [batch, batch, 2] when twin Q is used) matrix of contrastive logits.
  q_network: networks_lib.FeedForwardNetwork
  # Pure function computing log-probability of actions under the policy's
  # output distribution.
  log_prob: networks_lib.LogProbFn
  # Pure function returning (sa_repr, g_repr, hidden) representations;
  # see make_networks for the exact signature.
  repr_fn: Callable[Ellipsis, networks_lib.NetworkOutput]
  # Stochastic action sampler used during training.
  sample: networks_lib.SampleFn
  # Deterministic (mode) sampler used during evaluation; optional.
  sample_eval: Optional[networks_lib.SampleFn] = None
40
def apply_policy_and_sample(
    networks,
    eval_mode = False):
  """Returns a function that computes actions.

  Args:
    networks: a ContrastiveNetworks-like bundle exposing `policy_network`,
      `sample`, and `sample_eval`.
    eval_mode: when True use the deterministic evaluation sampler,
      otherwise use the stochastic training sampler.

  Raises:
    ValueError: if the selected sampler is not provided.
  """
  if eval_mode:
    sample_fn = networks.sample_eval
  else:
    sample_fn = networks.sample
  if not sample_fn:
    raise ValueError('sample function is not provided')

  def apply_and_sample(params, key, obs):
    # Run the policy to get an action distribution, then sample from it.
    distribution = networks.policy_network.apply(params, obs)
    return sample_fn(distribution, key)

  return apply_and_sample
52
53
def make_networks(
    spec,
    obs_dim,
    repr_dim = 64,
    repr_norm = False,
    repr_norm_temp = True,
    hidden_layer_sizes = (256, 256),
    actor_min_std = 1e-6,
    twin_q = False,
    use_image_obs = False):
  """Creates networks used by the agent.

  Args:
    spec: environment spec; `spec.actions` / `spec.observations` are used to
      build dummy batched inputs for parameter initialization.
    obs_dim: number of leading entries of each flattened observation that
      form the state; the remaining entries form the goal.
    repr_dim: output dimension of the state-action and goal encoders.
    repr_norm: if True, L2-normalize both representations along axis 1.
    repr_norm_temp: if True (only takes effect when repr_norm is True),
      divide the state-action representation by a learned temperature
      `exp(repr_log_scale)`.
    hidden_layer_sizes: hidden layer sizes for the encoder and actor MLPs.
    actor_min_std: minimum scale for the actor's tanh-normal distribution.
    twin_q: if True, the critic outputs two logit matrices stacked on the
      last axis, sharing one image representation.
    use_image_obs: if True, observations are flattened 64x64x3 images
      (state image concatenated with goal image).

  Returns:
    A ContrastiveNetworks bundle of policy, critic, representation and
    sampling functions.
  """

  num_dimensions = np.prod(spec.actions.shape, dtype=int)
  TORSO = networks_lib.AtariTorso  # pylint: disable=invalid-name

  def _unflatten_obs(obs):
    # Split the flat observation into state and goal images and rescale
    # pixels from [0, 255] to [0, 1].
    state = jnp.reshape(obs[:, :obs_dim], (-1, 64, 64, 3)) / 255.0
    goal = jnp.reshape(obs[:, obs_dim:], (-1, 64, 64, 3)) / 255.0
    return state, goal

  def _repr_fn(obs, action, hidden=None):
    # The optional input hidden is the image representations. We include this
    # as an input for the second Q value when twin_q = True, so that the two Q
    # values use the same underlying image representation.
    if hidden is None:
      if use_image_obs:
        state, goal = _unflatten_obs(obs)
        img_encoder = TORSO()
        state = img_encoder(state)
        goal = img_encoder(goal)
      else:
        state = obs[:, :obs_dim]
        goal = obs[:, obs_dim:]
    else:
      state, goal = hidden

    # Encode (state, action) into the contrastive representation space.
    sa_encoder = hk.nets.MLP(
        list(hidden_layer_sizes) + [repr_dim],
        w_init=hk.initializers.VarianceScaling(1.0, 'fan_avg', 'uniform'),
        activation=jax.nn.relu,
        name='sa_encoder')
    sa_repr = sa_encoder(jnp.concatenate([state, action], axis=-1))

    # Encode the goal into the same representation space.
    g_encoder = hk.nets.MLP(
        list(hidden_layer_sizes) + [repr_dim],
        w_init=hk.initializers.VarianceScaling(1.0, 'fan_avg', 'uniform'),
        activation=jax.nn.relu,
        name='g_encoder')
    g_repr = g_encoder(goal)

    if repr_norm:
      # L2-normalize so the critic logits become cosine similarities.
      sa_repr = sa_repr / jnp.linalg.norm(sa_repr, axis=1, keepdims=True)
      g_repr = g_repr / jnp.linalg.norm(g_repr, axis=1, keepdims=True)

      # NOTE: the learned temperature only applies when repr_norm is True;
      # with the default repr_norm=False, repr_norm_temp has no effect.
      if repr_norm_temp:
        log_scale = hk.get_parameter('repr_log_scale', [], dtype=sa_repr.dtype,
                                     init=jnp.zeros)
        sa_repr = sa_repr / jnp.exp(log_scale)
    return sa_repr, g_repr, (state, goal)

  def _combine_repr(sa_repr, g_repr):
    # Pairwise inner products between all state-action and goal
    # representations in the batch: result[i, j] = <sa_i, g_j>.
    return jax.numpy.einsum('ik,jk->ij', sa_repr, g_repr)

  def _critic_fn(obs, action):
    sa_repr, g_repr, hidden = _repr_fn(obs, action)
    outer = _combine_repr(sa_repr, g_repr)
    if twin_q:
      # Second critic reuses the image representations via `hidden` so both
      # Q values share the same torso output.
      sa_repr2, g_repr2, _ = _repr_fn(obs, action, hidden=hidden)
      outer2 = _combine_repr(sa_repr2, g_repr2)
      # outer.shape = [batch_size, batch_size, 2]
      outer = jnp.stack([outer, outer2], axis=-1)
    return outer

  def _actor_fn(obs):
    if use_image_obs:
      # Stack state and goal images along channels and encode with the torso.
      state, goal = _unflatten_obs(obs)
      obs = jnp.concatenate([state, goal], axis=-1)
      obs = TORSO()(obs)
    network = hk.Sequential([
        hk.nets.MLP(
            list(hidden_layer_sizes),
            w_init=hk.initializers.VarianceScaling(1.0, 'fan_in', 'uniform'),
            activation=jax.nn.relu,
            activate_final=True),
        networks_lib.NormalTanhDistribution(num_dimensions,
                                            min_scale=actor_min_std),
    ])
    return network(obs)

  # hk.without_apply_rng: the apply functions are deterministic given params,
  # so no RNG key needs to be threaded through them.
  policy = hk.without_apply_rng(hk.transform(_actor_fn))
  critic = hk.without_apply_rng(hk.transform(_critic_fn))
  repr_fn = hk.without_apply_rng(hk.transform(_repr_fn))

  # Create dummy observations and actions to create network parameters.
  dummy_action = utils.zeros_like(spec.actions)
  dummy_obs = utils.zeros_like(spec.observations)
  dummy_action = utils.add_batch_dim(dummy_action)
  dummy_obs = utils.add_batch_dim(dummy_obs)

  return ContrastiveNetworks(
      policy_network=networks_lib.FeedForwardNetwork(
          lambda key: policy.init(key, dummy_obs), policy.apply),
      q_network=networks_lib.FeedForwardNetwork(
          lambda key: critic.init(key, dummy_obs, dummy_action), critic.apply),
      repr_fn=repr_fn.apply,
      log_prob=lambda params, actions: params.log_prob(actions),
      sample=lambda params, key: params.sample(seed=key),
      # Evaluation uses the distribution mode (deterministic); key is unused.
      sample_eval=lambda params, key: params.mode(),
  )
163