train_acme.py
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# python3
r"""Example running contrastive RL in JAX.

Run using multi-processing (required for image-based experiments):
  python train_acme.py --lp_launch_type=local_mp

Run using multi-threading:
  python train_acme.py --lp_launch_type=local_mt
"""
import functools

from absl import app
from absl import flags
import launchpad as lp
from ml_collections import config_dict

from cvl_public import utils as contrastive_utils
from cvl_public.agents import DistributedContrastive
from cvl_public.config import ContrastiveConfig
from cvl_public.networks import make_networks

FLAGS = flags.FLAGS
flags.DEFINE_bool('debug', False, 'Runs training for just a few steps.')
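

# NOTE: `hyper` below is a minimal local stand-in for the internal
# hyperparameter-sweep utility this script was written against, which is not
# part of the open-source release. Its semantics are an assumption, covering
# only the two calls made in main(): hyper.sweep(name, values) builds one
# sweep axis as a list of {name: value} dicts, and hyper.product(axes) takes
# the cross-product of those axes.
import itertools  # Used only by the stand-in below.


class hyper:  # pylint: disable=invalid-name
  """Minimal stand-in for the internal sweep utility (assumed semantics)."""

  @staticmethod
  def sweep(name, values):
    # One sweep axis: one dict per candidate value of `name`.
    return [{name: value} for value in values]

  @staticmethod
  def product(axes):
    # Cross-product of the axes, merging one dict taken from each axis.
    return [
        {k: v for d in combo for k, v in d.items()}
        for combo in itertools.product(*axes)
    ]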


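# get_env is memoized with functools.lru_cache, so the spec-probing
# environment below is constructed only once per
# (env_name, start_index, end_index), even when get_program runs repeatedly
# during a sweep.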
@functools.lru_cache()
def get_env(env_name, start_index, end_index):
  return contrastive_utils.make_environment(env_name, start_index, end_index,
                                            seed=0)


def get_program(params):
  """Constructs the program."""

  env_name = params['env_name']
  seed = params.pop('seed')

  if params.get('use_image_obs', False) and not params.get('local', False):
    print('WARNING: overwriting parameters for image-based tasks.')
    params['num_sgd_steps_per_step'] = 16
    params['prefetch_size'] = 16
    params['num_actors'] = 10

  if env_name.startswith('offline_'):
    # No actors needed for the offline RL experiments. Evaluation is
    # handled separately.
    params['num_actors'] = 0

  config = ContrastiveConfig(**params)

  env_factory = lambda seed: contrastive_utils.make_environment(  # pylint: disable=g-long-lambda
      env_name, config.start_index, config.end_index, seed)

  env_factory_no_extra = lambda seed: env_factory(seed)[0]  # Remove obs_dim.
  environment, obs_dim = get_env(env_name, config.start_index,
                                 config.end_index)
  assert (environment.action_spec().minimum == -1).all()
  assert (environment.action_spec().maximum == 1).all()
  config.obs_dim = obs_dim
  config.max_episode_steps = getattr(environment, '_step_limit') + 1
  if env_name == 'offline_ant_umaze_diverse':
    # This environment terminates after 700 steps, but demos have 1000 steps.
    config.max_episode_steps = 1000
  network_factory = functools.partial(
      make_networks,
      repr_dim=config.repr_dim,
      repr_norm=config.repr_norm,
      twin_q=config.twin_q,
      use_image_obs=config.use_image_obs,
      hidden_layer_sizes=config.hidden_layer_sizes)

  agent = DistributedContrastive(
      seed=seed,
      environment_factory=env_factory_no_extra,
      network_factory=network_factory,
      config=config,
      num_actors=config.num_actors,
      log_to_bigtable=True,
      max_number_of_steps=config.max_number_of_steps)
  return agent.build()


def main(_):
  # Create experiment description.

  # 1. Select an environment.
  # Supported environments:
  #   Metaworld: sawyer_{push,drawer,bin,window}
  #   OpenAI Gym Fetch: fetch_{reach,push}
  #   D4RL AntMaze: ant_{umaze,medium,large}
  #   2D nav: point_{Small,Cross,FourRooms,U,Spiral11x11,Maze11x11}
  # Image observation environments:
  #   Metaworld: sawyer_image_{push,drawer,bin,window}
  #   OpenAI Gym Fetch: fetch_{reach,push}_image
  #   2D nav: point_image_{Small,Cross,FourRooms,U,Spiral11x11,Maze11x11}
  # Offline environments:
  #   antmaze: offline_ant_{umaze,umaze_diverse,
  #                         medium_play,medium_diverse,
  #                         large_play,large_diverse}
  # env_name = ['fetch_reach']
  env_name = ['offline_metaworld_pick-place-v2']
  # env_name = [
  #     'offline_ant_umaze', 'offline_ant_umaze_diverse',
  #     'offline_ant_medium_play', 'offline_ant_medium_diverse',
  #     'offline_ant_large_play', 'offline_ant_large_diverse'
  # ]
  # env_name = [
  #     'offline_halfcheetah-medium-v2',
  #     'offline_halfcheetah-medium-replay-v2',
  #     'offline_walker2d-medium-v2',
  #     'offline_walker2d-medium-replay-v2',
  #     'offline_hopper-medium-v2',
  #     'offline_hopper-medium-replay-v2'
  # ]
  params = {
      'seed': 0,
      'use_random_actor': True,
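      # For image-based tasks, entropy_coefficient=None is taken to request
      # an adaptive (learned) entropy coefficient, while 0.0 disables the
      # entropy bonus (an assumption based on the usual ACME/SAC convention).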
      'entropy_coefficient': None if 'image' in env_name[0] else 0.0,
      'env_name': env_name,
      # For online RL experiments, max_number_of_steps is the number of
      # environment steps. For offline RL experiments, this is the number of
      # gradient steps.
      'max_number_of_steps': 1_000_000,
      'use_image_obs': 'image' in env_name[0],
  }
  if 'ant_' in env_name[0]:
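    # Setting end_index = 2 keeps only the first two observation coordinates
    # for the goal, i.e. the ant's x-y position (an assumption based on the
    # start_index/end_index slicing in make_environment).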
    params['end_index'] = 2

  # 2. Select an algorithm. The currently-supported algorithms are:
  # contrastive_nce, contrastive_cpc, c_learning, nce+c_learning, gcbc.
  # Many other algorithms can be implemented by passing other parameters
  # or adding a few lines of code.
  alg = 'contrastive_nce'
  if alg == 'contrastive_nce':
    pass  # Just use the default hyperparameters.
  elif alg == 'contrastive_cpc':
    params['use_cpc'] = True
  elif alg == 'c_learning':
    params['use_td'] = True
    params['twin_q'] = True
  elif alg == 'nce+c_learning':
    params['use_td'] = True
    params['twin_q'] = True
    params['add_mc_to_td'] = True
  elif alg == 'gcbc':
    params['use_gcbc'] = True
  else:
    raise NotImplementedError('Unknown method: %s' % alg)

  # For the offline RL experiments, modify some hyperparameters.
  if env_name[0].startswith('offline_'):
    params.update({
        # Effectively remove the rate limiter by using very large values.
        'samples_per_insert': 1_000_000,
        'samples_per_insert_tolerance_rate': 100_000_000.0,
        # For the actor update, only use future states as goals.
        'random_goals': 0.0,
        'bc_coef': 0.05,  # Add a behavioral cloning term to the actor.
        'twin_q': False,  # Use a single critic (no twin-Q minimum).
        'batch_size': 1024,  # Increase the batch size 256 --> 1024.
        'repr_dim': 16,  # Decrease the representation size 64 --> 16.
        # Increase the policy network size (256, 256) --> (1024, 1024).
        'hidden_layer_sizes': (1024, 1024),
    })

  # 3. Select compute parameters. The default parameters are already tuned, so
  # use this mainly for debugging.
  if FLAGS.debug:
    params.update({
        'min_replay_size': 2_000,
        'local': True,
        'num_sgd_steps_per_step': 1,
        'prefetch_size': 1,
        'num_actors': 1,
        'batch_size': 32,
        'max_number_of_steps': 10_000,
        'hidden_layer_sizes': (32, 32),
    })

  sweep_params = []
  for k, v in params.items():
    if not isinstance(v, list):
      v = [v]
    sweep_params.append(hyper.sweep(k, v))
  sweep_params = hyper.product(sweep_params)
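  # Each list-valued entry in params becomes one sweep axis; scalar entries
  # become singleton axes. For example (hypothetical values), 'seed': [0, 1]
  # together with 'env_name': ['a', 'b'] would expand into four
  # configurations: (0, a), (0, b), (1, a), (1, b).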

  programs = []
  program_params = []
  for hypers in sweep_params:
    hypers = dict(config_dict.ConfigDict(hypers))
    program = get_program(hypers)
    programs.append(program)
    program_params.append(hypers)

  lp.launch(programs, terminal='current_terminal')


if __name__ == '__main__':
  app.run(main)