google-research

Форк
0
187 строк · 6.9 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Example running Concept PPO in meltingpot with Acme."""
17

18
from typing import Callable, Dict
19

20
from absl import app
21
from absl import flags
22
from acme import specs
23
from acme.jax import experiments
24
from acme.jax import types as jax_types
25
from acme.multiagent import types as ma_types
26
import dm_env
27
import optax
28

29
from concept_marl.experiments import helpers
30
from concept_marl.experiments.meltingpot.wrappers import wrapper_utils
31
from concept_marl.utils import builder as mp_builder
32
from concept_marl.utils import factories as mp_factories
33

34

35
_ENV_NAME = flags.DEFINE_string('env_name', 'cooking_basic',
                                'Name of the environment to run.')
_EPISODE_LENGTH = flags.DEFINE_integer('episode_length', 100,
                                       'Max number of steps in episode.')
_NUM_STEPS = flags.DEFINE_integer('num_steps', 10000,
                                  'Number of env steps to run training for.')
_EVAL_EVERY = flags.DEFINE_integer('eval_every', 1000,
                                   'How often to run evaluation.')
_SEED = flags.DEFINE_integer('seed', 0, 'Random seed.')
# Learning-rate schedule: decays from learning_rate_start to learning_rate_end
# over learning_rate_decay_steps (the schedule is built in
# build_experiment_config()).
_LR_START = flags.DEFINE_float('learning_rate_start', 5e-4,
                               'Initial learning rate.')
_LR_END = flags.DEFINE_float('learning_rate_end', 5e-7,
                             'Final learning rate, reached after decay.')
_LR_DECAY = flags.DEFINE_integer(
    'learning_rate_decay_steps', 100_000,
    'Number of schedule steps over which the learning rate decays from '
    'learning_rate_start to learning_rate_end.')
# PPO hyper-parameter flags.
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.')
_UNROLL_LENGTH = flags.DEFINE_integer('unroll_length', 16,
                                      'Unroll length for PPO.')
_NUM_EPOCHS = flags.DEFINE_integer('num_epochs', 5, 'Num epochs for PPO.')
_NUM_MINIBATCHES = flags.DEFINE_integer('num_minibatches', 32,
                                        'Num minibatches for PPO.')
_PPO_CLIPPING_EPSILON = flags.DEFINE_float('ppo_clipping_epsilon', 0.2,
                                           'Clipping epsilon for PPO.')
_ENTROPY_COST = flags.DEFINE_float('entropy_cost', 0.01,
                                   'Entropy cost weight for PPO.')
_VALUE_COST = flags.DEFINE_float('value_cost', 1.0,
                                 'Value cost weight for PPO.')
_MAX_GRAD_NORM = flags.DEFINE_float('max_gradient_norm', 0.5,
                                    'Global gradient clip for PPO.')
# Concept-related flags.
_CONCEPT_COST = flags.DEFINE_float('concept_cost', 0.1,
                                   'Concept cost weight for PPO.')
# Clean Up reward-shaping flags; only read by get_env() when env_name
# contains 'clean'.
_EAT_REWARD = flags.DEFINE_float('eat_reward', 0.1,
                                 'Local reward for eating an apple.')
_CLEAN_REWARD = flags.DEFINE_float('clean_reward', 0.005,
                                   'Local reward for cleaning river.')
70

71

72
def get_env(env_name, seed, episode_length):
  """Initializes and returns a meltingpot environment.

  The substrate is selected by substring match on `env_name`, checked in this
  order:
    * 'cooking' -> cooking environment with dense rewards;
    * 'clean'   -> Clean Up environment with dense rewards and reward shaping
                   taken from the --clean_reward / --eat_reward flags;
    * 'capture' -> capture environment.

  Args:
    env_name: Name of the environment to build.
    seed: Random seed forwarded to the environment constructor.
    episode_length: Maximum number of steps per episode.

  Returns:
    A tuple `(env, env_config)`. `env` is the wrapped environment, with an
    `n_agents` attribute mirroring its `num_agents`. `env_config` is the dict
    of keyword arguments used to build it, extended with the envlogger
    book-keeping keys 'env_name_for_get_env' and 'num_steps'.

  Raises:
    ValueError: If `env_name` does not match any supported substrate.
  """
  env_config = dict(
      env_name=env_name,
      action_type='flat',
      grayscale=False,
      scale_dims=(40, 40),
      episode_length=episode_length,
      seed=seed)

  # standard melting pot wrapper
  if 'cooking' in env_name:
    env_config['dense_rewards'] = True
    env = wrapper_utils.make_and_wrap_cooking_environment(**env_config)
  elif 'clean' in env_name:
    env_config['dense_rewards'] = True
    env_config['clean_reward'] = _CLEAN_REWARD.value
    env_config['eat_reward'] = _EAT_REWARD.value
    env = wrapper_utils.make_and_wrap_cleanup_environment(**env_config)
  elif 'capture' in env_name:
    env = wrapper_utils.make_and_wrap_capture_environment(**env_config)
  else:
    # Report the offending name and the accepted choices to ease debugging.
    raise ValueError(
        f'Invalid environment choice! Got {env_name!r}; expected a name '
        "containing 'cooking', 'clean' or 'capture'.")

  # envlogger book-keeping
  env_config['env_name_for_get_env'] = env_name
  env_config['num_steps'] = episode_length
  # NOTE(review): alias presumably expected by downstream multi-agent code
  # that reads `n_agents` rather than `num_agents`; confirm.
  env.n_agents = env.num_agents
  return env, env_config
101

102

103
def _make_environment_factory(env_name):
104

105
  def environment_factory(seed):
106
    environment, _ = get_env(env_name, seed, _EPISODE_LENGTH.value)
107
    return environment
108

109
  return environment_factory
110

111

112
def _make_network_factory(
113
    agent_types
114
):
115
  """Returns a network factory for meltingpot experiments."""
116

117
  def network_factory(
118
      environment_spec):
119
    return mp_factories.network_factory(
120
        environment_spec,
121
        agent_types,
122
        init_network_fn=helpers.init_default_meltingpot_network)
123

124
  return network_factory
125

126

127
def build_experiment_config():
  """Returns a config for meltingpot Concept PPO experiments.

  All hyper-parameters are read from command-line flags. One environment is
  instantiated up front only to count the players, so that one Concept PPO
  agent can be created per player. That probe environment is closed before
  the config is built. The experiment itself receives `environment_factory`
  and creates its own environments.

  Returns:
    An `experiments.ExperimentConfig` wiring together the decentralized
    multi-agent builder, the environment factory and the network factory.
  """
  # init environment
  environment_factory = _make_environment_factory(_ENV_NAME.value)
  # The probe instance is only needed to read the number of agents. Close it
  # so it does not hold on to its resources for the whole run.
  probe_environment = environment_factory(_SEED.value)
  try:
    num_agents = probe_environment.num_agents  # pytype: disable=attribute-error
  finally:
    probe_environment.close()

  # init learning rate schedule (power=1, i.e. a linear decay)
  learning_rate = optax.polynomial_schedule(
      init_value=_LR_START.value,
      end_value=_LR_END.value,
      power=1,
      transition_steps=_LR_DECAY.value)

  # init Concept PPO agent: one per player, keyed by the player index as str.
  agent_types = {
      str(i): mp_factories.DefaultSupportedAgent.CONCEPT_PPO
      for i in range(num_agents)
  }
  # Every agent gets the same flag-derived hyper-parameters.
  shared_overrides = {
      'learning_rate': learning_rate,
      'batch_size': _BATCH_SIZE.value,
      'unroll_length': _UNROLL_LENGTH.value,
      'num_minibatches': _NUM_MINIBATCHES.value,
      'num_epochs': _NUM_EPOCHS.value,
      'ppo_clipping_epsilon': _PPO_CLIPPING_EPSILON.value,
      'entropy_cost': _ENTROPY_COST.value,
      'value_cost': _VALUE_COST.value,
      'concept_cost': _CONCEPT_COST.value,
      'max_gradient_norm': _MAX_GRAD_NORM.value,
      'clip_value': False,
  }
  # Give each agent its own dict copy so per-agent configs never alias.
  config_overrides = {
      agent_id: dict(shared_overrides) for agent_id in agent_types
  }

  # init configs from agents
  configs = mp_factories.default_config_factory(agent_types, _BATCH_SIZE.value,
                                                config_overrides)

  builder = mp_builder.DecentralizedMultiAgentBuilder(
      agent_types=agent_types,
      agent_configs=configs,
      init_policy_network_fn=mp_factories.init_default_policy_network,
      init_builder_fn=mp_factories.init_default_builder)

  return experiments.ExperimentConfig(
      builder=builder,
      environment_factory=environment_factory,
      network_factory=_make_network_factory(agent_types=agent_types),
      seed=_SEED.value,
      max_num_actor_steps=_NUM_STEPS.value)
178

179

180
def main(_):
  """Entry point: builds the flag-driven experiment config and runs it."""
  experiment_config = build_experiment_config()
  experiments.run_experiment(
      experiment=experiment_config,
      eval_every=_EVAL_EVERY.value,
      num_eval_episodes=5,
  )


if __name__ == '__main__':
  app.run(main)
188

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.