google-research
109 строк · 3.5 Кб
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
r"""Train an agent.

"""
19import os20
21
22from absl import app23from absl import flags24
25from dopamine.discrete_domains import run_experiment26import tensorflow.compat.v1 as tf27
28from experience_replay import run_experience_replay_experiment29
30
# Command-line flags that configure the experiment.
flags.DEFINE_string(
    'base_dir', None,
    'Base directory to host all required sub-directories.')
flags.DEFINE_multi_string(
    'gin_files', [],
    'List of paths to gin configuration files (e.g. '
    '"third_party/py/dopamine/agents/dqn/dqn.gin").')
flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files '
    '(e.g. "DQNAgent.epsilon_train=0.1",'
    ' "create_atari_environment.game_name="Pong"").')
# The help text lists only the schedule that create_runner() in this file
# accepts; any other value makes create_runner() raise ValueError.
flags.DEFINE_string(
    'schedule', 'continuous_train_and_eval',
    'The schedule with which to run the experiment and choose an appropriate '
    'Runner. The only supported choice is continuous_train_and_eval.')


FLAGS = flags.FLAGS
50
51
def create_runner(base_dir, create_agent_fn,
                  schedule='continuous_train_and_eval'):
  """Creates an experiment Runner.

  TODO(b/): Figure out the right idiom to create a Runner. The current mechanism
  of using a number of flags will not scale and is not elegant.

  Args:
    base_dir: Base directory for hosting all subdirectories. Must not be None.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use. Only
      'continuous_train_and_eval' is recognized.

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: When `base_dir` is None or an unknown schedule is encountered.
  """
  # Validate explicitly rather than with `assert`: assertions are removed when
  # Python runs with -O, which would let a missing base_dir slip through.
  if base_dir is None:
    raise ValueError('base_dir must not be None.')

  # Continuously runs training and eval till max num_iterations is hit.
  if schedule == 'continuous_train_and_eval':
    return run_experience_replay_experiment.ElephantRunner(
        base_dir, create_agent_fn)

  raise ValueError('Unknown schedule: {}'.format(schedule))
81
def launch_experiment(create_runner_fn, create_agent_fn):
  """Loads the gin configuration, builds a Runner and runs the experiment.

  Args:
    create_runner_fn: A function that takes as args a base directory and a
      function for creating an agent and returns a `Runner` like object. It is
      also passed the `schedule` flag as a keyword argument.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym environment, and returns an agent.
  """
  # Apply the gin files and bindings supplied on the command line.
  gin_files, gin_bindings = FLAGS.gin_files, FLAGS.gin_bindings
  run_experiment.load_gin_configs(gin_files, gin_bindings)

  # Build the runner from the remaining flags and hand control to it.
  experiment_runner = create_runner_fn(
      FLAGS.base_dir, create_agent_fn, schedule=FLAGS.schedule)
  experiment_runner.run_experiment()
96
def main(unused_argv):
  """Sets TF logging to INFO and launches the experiment.

  The experiment is launched with this module's `create_runner` and the
  `create_agent` function from `run_experience_replay_experiment`.

  Args:
    unused_argv: Arguments (unused).
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  agent_factory = run_experience_replay_experiment.create_agent
  launch_experiment(create_runner, agent_factory)
if __name__ == '__main__':
  # `base_dir` defaults to None, so it has to be given on the command line.
  flags.mark_flag_as_required('base_dir')
  app.run(main)