google-research
187 строк · 6.9 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example running Concept PPO in meltingpot with Acme."""
17
18from typing import Callable, Dict
19
20from absl import app
21from absl import flags
22from acme import specs
23from acme.jax import experiments
24from acme.jax import types as jax_types
25from acme.multiagent import types as ma_types
26import dm_env
27import optax
28
29from concept_marl.experiments import helpers
30from concept_marl.experiments.meltingpot.wrappers import wrapper_utils
31from concept_marl.utils import builder as mp_builder
32from concept_marl.utils import factories as mp_factories
33
34
# Experiment / environment flags.
_ENV_NAME = flags.DEFINE_string('env_name', 'cooking_basic',
                                'Name of the environment to run.')
_EPISODE_LENGTH = flags.DEFINE_integer('episode_length', 100,
                                       'Max number of steps in episode.')
_NUM_STEPS = flags.DEFINE_integer('num_steps', 10000,
                                  'Number of env steps to run training for.')
_EVAL_EVERY = flags.DEFINE_integer('eval_every', 1000,
                                   'How often to run evaluation.')
_SEED = flags.DEFINE_integer('seed', 0, 'Random seed.')
# Learning-rate schedule flags (linear decay from start to end value).
# NOTE(fix): these three previously shared the identical help string
# 'Learning rate.', which made `--help` output ambiguous.
_LR_START = flags.DEFINE_float('learning_rate_start', 5e-4,
                               'Initial learning rate.')
_LR_END = flags.DEFINE_float('learning_rate_end', 5e-7,
                             'Final learning rate after decay.')
_LR_DECAY = flags.DEFINE_integer('learning_rate_decay_steps', 100_000,
                                 'Number of steps over which the learning '
                                 'rate decays.')
# PPO hyperparameter flags.
_BATCH_SIZE = flags.DEFINE_integer('batch_size', 256, 'Batch size.')
_UNROLL_LENGTH = flags.DEFINE_integer('unroll_length', 16,
                                      'Unroll length for PPO.')
_NUM_EPOCHS = flags.DEFINE_integer('num_epochs', 5, 'Num epochs for PPO.')
_NUM_MINIBATCHES = flags.DEFINE_integer('num_minibatches', 32,
                                        'Num minibatches for PPO.')
_PPO_CLIPPING_EPSILON = flags.DEFINE_float('ppo_clipping_epsilon', 0.2,
                                           'Clipping epsilon for PPO.')
_ENTROPY_COST = flags.DEFINE_float('entropy_cost', 0.01,
                                   'Entropy cost weight for PPO.')
_VALUE_COST = flags.DEFINE_float('value_cost', 1.0,
                                 'Value cost weight for PPO.')
_MAX_GRAD_NORM = flags.DEFINE_float('max_gradient_norm', 0.5,
                                    'Global gradient clip for PPO.')
# concept related flags
_CONCEPT_COST = flags.DEFINE_float('concept_cost', 0.1,
                                   'Concept cost weight for PPO.')
# clean up flags
_EAT_REWARD = flags.DEFINE_float('eat_reward', 0.1,
                                 'Local reward for eating an apple.')
_CLEAN_REWARD = flags.DEFINE_float('clean_reward', 0.005,
                                   'Local reward for cleaning river.')
70
71
def get_env(env_name, seed, episode_length):
  """Initializes and returns a wrapped meltingpot environment.

  Args:
    env_name: Name of the meltingpot substrate. Must contain one of
      'cooking', 'clean', or 'capture', which selects the wrapper used.
    seed: Random seed passed to the environment.
    episode_length: Maximum number of steps per episode.

  Returns:
    A tuple `(env, env_config)` where `env_config` is the dict of settings
    the environment was built with, plus book-keeping entries for envlogger.

  Raises:
    ValueError: If `env_name` does not match any known environment family.
  """
  env_config = dict(
      env_name=env_name,
      action_type='flat',
      grayscale=False,
      scale_dims=(40, 40),
      episode_length=episode_length,
      seed=seed)

  # standard melting pot wrapper, selected by environment family
  if 'cooking' in env_name:
    env_config['dense_rewards'] = True
    env = wrapper_utils.make_and_wrap_cooking_environment(**env_config)
  elif 'clean' in env_name:
    env_config['dense_rewards'] = True
    # Clean-up uses local reward shaping controlled by flags.
    env_config['clean_reward'] = _CLEAN_REWARD.value
    env_config['eat_reward'] = _EAT_REWARD.value
    env = wrapper_utils.make_and_wrap_cleanup_environment(**env_config)
  elif 'capture' in env_name:
    env = wrapper_utils.make_and_wrap_capture_environment(**env_config)
  else:
    # Include the offending name so the failure is diagnosable from logs.
    raise ValueError(f'Invalid environment choice: {env_name!r}')

  # envlogger book-keeping
  env_config['env_name_for_get_env'] = env_name
  env_config['num_steps'] = episode_length
  env.n_agents = env.num_agents
  return env, env_config
101
102
def _make_environment_factory(env_name):
  """Returns a factory mapping a seed to the named environment."""

  def environment_factory(seed):
    # Episode length is read from the flag at call time.
    env, _ = get_env(env_name, seed, _EPISODE_LENGTH.value)
    return env

  return environment_factory
110
111
def _make_network_factory(agent_types):
  """Returns a network factory for meltingpot experiments.

  Args:
    agent_types: Mapping from agent id to agent type, forwarded to the
      multiagent network factory.

  Returns:
    A callable taking an environment spec and returning per-agent networks.
  """

  def network_factory(environment_spec):
    # Networks are initialized with the default meltingpot initializer.
    return mp_factories.network_factory(
        environment_spec,
        agent_types,
        init_network_fn=helpers.init_default_meltingpot_network)

  return network_factory
125
126
def build_experiment_config():
  """Returns a config for meltingpot experiments."""

  # Build one environment up front so we can read off the number of agents.
  environment_factory = _make_environment_factory(_ENV_NAME.value)
  environment = environment_factory(_SEED.value)
  num_agents = environment.num_agents  # pytype: disable=attribute-error

  # Linearly decayed learning-rate schedule (power=1).
  learning_rate = optax.polynomial_schedule(
      init_value=_LR_START.value,
      end_value=_LR_END.value,
      power=1,
      transition_steps=_LR_DECAY.value)

  # One Concept PPO agent per player, keyed by stringified index.
  agent_types = {
      str(agent_idx): mp_factories.DefaultSupportedAgent.CONCEPT_PPO
      for agent_idx in range(num_agents)
  }

  # Every agent shares the same hyperparameter overrides; each gets its own
  # shallow copy, matching the original per-agent dict literals.
  shared_overrides = {
      'learning_rate': learning_rate,
      'batch_size': _BATCH_SIZE.value,
      'unroll_length': _UNROLL_LENGTH.value,
      'num_minibatches': _NUM_MINIBATCHES.value,
      'num_epochs': _NUM_EPOCHS.value,
      'ppo_clipping_epsilon': _PPO_CLIPPING_EPSILON.value,
      'entropy_cost': _ENTROPY_COST.value,
      'value_cost': _VALUE_COST.value,
      'concept_cost': _CONCEPT_COST.value,
      'max_gradient_norm': _MAX_GRAD_NORM.value,
      'clip_value': False,
  }
  config_overrides = {
      agent_id: dict(shared_overrides) for agent_id in agent_types
  }

  # init configs from agents
  configs = mp_factories.default_config_factory(agent_types, _BATCH_SIZE.value,
                                                config_overrides)

  builder = mp_builder.DecentralizedMultiAgentBuilder(
      agent_types=agent_types,
      agent_configs=configs,
      init_policy_network_fn=mp_factories.init_default_policy_network,
      init_builder_fn=mp_factories.init_default_builder)

  return experiments.ExperimentConfig(
      builder=builder,
      environment_factory=environment_factory,
      network_factory=_make_network_factory(agent_types=agent_types),
      seed=_SEED.value,
      max_num_actor_steps=_NUM_STEPS.value)
178
179
def main(_):
  """Builds the experiment config and runs training with periodic evals."""
  experiment_config = build_experiment_config()
  experiments.run_experiment(
      experiment=experiment_config,
      eval_every=_EVAL_EVERY.value,
      num_eval_episodes=5)


if __name__ == '__main__':
  app.run(main)
188