# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# python3
r"""Example running contrastive RL in JAX.

Run using multi-processing (required for image-based experiments):
  python lp_contrastive.py --lp_launch_type=local_mp

Run using multi-threading:
  python lp_contrastive.py --lp_launch_type=local_mt
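
For a quick smoke test, add the --debug flag defined below; it shrinks the
network, batch size, and number of training steps:
  python lp_contrastive.py --lp_launch_type=local_mt --debug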
"""
import functools
import itertools
from typing import Any, Dict

from absl import app
from absl import flags
import jax
import launchpad as lp
from ml_collections import config_dict

from cvl_public import utils as contrastive_utils
from cvl_public.agents import DistributedContrastive
from cvl_public.config import ContrastiveConfig
from cvl_public.networks import make_networks

FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, 'Runs training for just a few steps.')


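# Cached so that the spec inspection in get_program() reuses one environment
# instance per (env_name, start_index, end_index) triple, instead of building
# a fresh environment for every sweep entry. The environment constructed here
# is only used to read specs, not for training.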
@functools.lru_cache()
def get_env(env_name, start_index, end_index):
  return contrastive_utils.make_environment(env_name, start_index, end_index,
                                            seed=0)


def get_program(params: Dict[str, Any]) -> lp.Program:
  """Constructs the program."""

  env_name = params['env_name']
  seed = params.pop('seed')

  if params.get('use_image_obs', False) and not params.get('local', False):
    print('WARNING: overwriting parameters for image-based tasks.')
    params['num_sgd_steps_per_step'] = 16
    params['prefetch_size'] = 16
    params['num_actors'] = 10

  if env_name.startswith('offline_'):
    # No actors are needed for the offline RL experiments. Evaluation is
    # handled separately.
    params['num_actors'] = 0

  config = ContrastiveConfig(**params)

  env_factory = lambda seed: contrastive_utils.make_environment(  # pylint: disable=g-long-lambda
      env_name, config.start_index, config.end_index, seed)

  env_factory_no_extra = lambda seed: env_factory(seed)[0]  # Remove obs_dim.
  environment, obs_dim = get_env(env_name, config.start_index,
                                 config.end_index)
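  # Sanity check: this setup assumes the action space is normalized to
  # [-1, 1].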
  assert (environment.action_spec().minimum == -1).all()
  assert (environment.action_spec().maximum == 1).all()
  config.obs_dim = obs_dim
  config.max_episode_steps = getattr(environment, '_step_limit') + 1
  if env_name == 'offline_ant_umaze_diverse':
    # This environment terminates after 700 steps, but demos have 1000 steps.
    config.max_episode_steps = 1000
  network_factory = functools.partial(
      make_networks,
      repr_dim=config.repr_dim,
      repr_norm=config.repr_norm,
      twin_q=config.twin_q,
      use_image_obs=config.use_image_obs,
      hidden_layer_sizes=config.hidden_layer_sizes)

  agent = DistributedContrastive(
      seed=seed,
      environment_factory=env_factory_no_extra,
      network_factory=network_factory,
      config=config,
      num_actors=config.num_actors,
      log_to_bigtable=True,
      max_number_of_steps=config.max_number_of_steps)
  return agent.build()


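# Example (hypothetical): get_program() can also build a single experiment
# directly, without the sweep machinery in main() below, e.g.
#   program = get_program({'seed': 0, 'env_name': 'sawyer_push'})
#   lp.launch(program, terminal='current_terminal')
# (This assumes ContrastiveConfig provides defaults for every other field.)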
def main(_):
  # Create experiment description.

  # 1. Select an environment.
  # Supported environments:
  #   Metaworld: sawyer_{push,drawer,bin,window}
  #   OpenAI Gym Fetch: fetch_{reach,push}
  #   D4RL AntMaze: ant_{umaze,medium,large}
  #   2D nav: point_{Small,Cross,FourRooms,U,Spiral11x11,Maze11x11}
  # Image observation environments:
  #   Metaworld: sawyer_image_{push,drawer,bin,window}
  #   OpenAI Gym Fetch: fetch_{reach,push}_image
  #   2D nav: point_image_{Small,Cross,FourRooms,U,Spiral11x11,Maze11x11}
  # Offline environments:
  #   AntMaze: offline_ant_{umaze,umaze_diverse,
  #                         medium_play,medium_diverse,
  #                         large_play,large_diverse}
  #   Metaworld: offline_metaworld_pick-place-v2
  #   D4RL locomotion: offline_{halfcheetah,walker2d,hopper}-medium-v2 and
  #                    offline_{halfcheetah,walker2d,hopper}-medium-replay-v2
  # env_name = 'fetch_reach'
  env_name = ['offline_metaworld_pick-place-v2']
  # env_name = [
  #     'offline_ant_umaze', 'offline_ant_umaze_diverse',
  #     'offline_ant_medium_play', 'offline_ant_medium_diverse',
  #     'offline_ant_large_play', 'offline_ant_large_diverse'
  # ]
  # env_name = [
  #     'offline_halfcheetah-medium-v2',
  #     'offline_halfcheetah-medium-replay-v2',
  #     'offline_walker2d-medium-v2',
  #     'offline_walker2d-medium-replay-v2',
  #     'offline_hopper-medium-v2',
  #     'offline_hopper-medium-replay-v2'
  # ]
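  # Note that env_name is a list: list-valued parameters are expanded into a
  # sweep at the bottom of this function, producing one program per
  # combination.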
  params = {
      'seed': 0,
      'use_random_actor': True,
      'entropy_coefficient': None if 'image' in env_name[0] else 0.0,
      'env_name': env_name,
      # For online RL experiments, max_number_of_steps is the number of
      # environment steps. For offline RL experiments, this is the number of
      # gradient steps.
      'max_number_of_steps': 1_000_000,
      'use_image_obs': 'image' in env_name[0],
  }
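  # For the AntMaze tasks, only the first two observation coordinates
  # (presumably the agent's (x, y) position) are used as the goal, hence
  # end_index = 2.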
  if 'ant_' in env_name[0]:
    params['end_index'] = 2

  # 2. Select an algorithm. The currently-supported algorithms are:
  # contrastive_nce, contrastive_cpc, c_learning, nce+c_learning, gcbc.
  # Many other algorithms can be implemented by passing other parameters
  # or adding a few lines of code, as sketched below.
  alg = 'contrastive_nce'
  if alg == 'contrastive_nce':
    pass  # Just use the default hyperparameters.
  elif alg == 'contrastive_cpc':
    params['use_cpc'] = True
  elif alg == 'c_learning':
    params['use_td'] = True
    params['twin_q'] = True
  elif alg == 'nce+c_learning':
    params['use_td'] = True
    params['twin_q'] = True
    params['add_mc_to_td'] = True
  elif alg == 'gcbc':
    params['use_gcbc'] = True
  else:
    raise NotImplementedError('Unknown method: %s' % alg)
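  # For example (untested sketch, combining knobs that already appear in this
  # file): a CPC variant with a behavioral-cloning term on the actor:
  #   params['use_cpc'] = True
  #   params['bc_coef'] = 0.1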

  # For the offline RL experiments, modify some hyperparameters.
  if env_name[0].startswith('offline_'):
    params.update({
        # Effectively remove the rate limiter by using very large values.
        'samples_per_insert': 1_000_000,
        'samples_per_insert_tolerance_rate': 100_000_000.0,
        # For the actor update, only use future states as goals.
        'random_goals': 0.0,
        'bc_coef': 0.05,  # Add a behavioral cloning term to the actor.
        'twin_q': False,  # Use a single critic (no twin-Q minimum).
        'batch_size': 1024,  # Increase the batch size 256 --> 1024.
        'repr_dim': 16,  # Decrease the representation size 64 --> 16.
        # Increase the policy network size (256, 256) --> (1024, 1024).
        'hidden_layer_sizes': (1024, 1024),
    })

  # 3. Select compute parameters. The default parameters are already tuned, so
  # use this mainly for debugging.
  if FLAGS.debug:
    params.update({
        'min_replay_size': 2_000,
        'local': True,
        'num_sgd_steps_per_step': 1,
        'prefetch_size': 1,
        'num_actors': 1,
        'batch_size': 32,
        'max_number_of_steps': 10_000,
        'hidden_layer_sizes': (32, 32),
    })

  # Expand any list-valued parameters into a hyperparameter sweep: build one
  # configuration for each element of the Cartesian product over all
  # list-valued entries.
  sweep_keys = list(params)
  sweep_values = [v if isinstance(v, list) else [v] for v in params.values()]
  sweep_params = [
      dict(zip(sweep_keys, combo))
      for combo in itertools.product(*sweep_values)
  ]
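  # For example, sweeping over seeds as well as environments, e.g.
  #   params['seed'] = [0, 1, 2]
  # would yield len(env_name) * 3 configurations, each launched as a separate
  # program below.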

  programs = []
  program_params = []
  for hypers in sweep_params:
    hypers = dict(config_dict.ConfigDict(hypers))
    program = get_program(hypers)
    programs.append(program)
    program_params.append(hypers)

  lp.launch(programs, terminal='current_terminal')


if __name__ == '__main__':
  app.run(main)