google-research
117 lines · 4.6 KB
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
"""Helper functions for creating exploration agents and experiment runners."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from bonus_based_exploration.intrinsic_motivation import intrinsic_dqn_agent
23from bonus_based_exploration.intrinsic_motivation import intrinsic_rainbow_agent
24from bonus_based_exploration.noisy_networks import noisy_dqn_agent
25from bonus_based_exploration.noisy_networks import noisy_rainbow_agent
26from dopamine.discrete_domains import run_experiment
27import gin
28
29
# Exploration agents built directly by `create_exploration_agent`, keyed by
# the `agent_name` that selects them. Each class is constructed with the
# session, the environment's number of actions and the summary writer.
# Names absent from this table are delegated to `run_experiment.create_agent`.
_EXPLORATION_AGENTS = {
    'dqn_cts': intrinsic_dqn_agent.CTSDQNAgent,
    'rainbow_cts': intrinsic_rainbow_agent.CTSRainbowAgent,
    'dqn_pixelcnn': intrinsic_dqn_agent.PixelCNNDQNAgent,
    'rainbow_pixelcnn': intrinsic_rainbow_agent.PixelCNNRainbowAgent,
    'dqn_rnd': intrinsic_dqn_agent.RNDDQNAgent,
    'rainbow_rnd': intrinsic_rainbow_agent.RNDRainbowAgent,
    'noisy_dqn': noisy_dqn_agent.NoisyDQNAgent,
    'noisy_rainbow': noisy_rainbow_agent.NoisyRainbowAgent,
}


@gin.configurable
def create_exploration_agent(sess, environment, agent_name=None,
                             summary_writer=None, debug_mode=False):
  """Creates an exploration agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600). For the exploration
      agents only `environment.action_space.n` is read here.
    agent_name: str, name of the agent to create. Exploration agents supported
      directly are 'dqn_cts', 'rainbow_cts', 'dqn_pixelcnn',
      'rainbow_pixelcnn', 'dqn_rnd', 'rainbow_rnd', 'noisy_dqn' and
      'noisy_rainbow'. Any other name is forwarded to
      `run_experiment.create_agent`.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is None. Names outside the list above are
      validated by `run_experiment.create_agent`, which is expected to raise
      ValueError for unknown agents (confirm against the installed Dopamine).
  """
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if agent_name is None:
    raise ValueError('agent_name must be specified.')
  # The summary writer is only handed to the agent in debug mode.
  if not debug_mode:
    summary_writer = None
  agent_cls = _EXPLORATION_AGENTS.get(agent_name)
  if agent_cls is not None:
    return agent_cls(sess, num_actions=environment.action_space.n,
                     summary_writer=summary_writer)
  # Not an exploration-specific agent: fall back to Dopamine's factory.
  return run_experiment.create_agent(sess, environment, agent_name,
                                     summary_writer, debug_mode)
90
91
@gin.configurable
def create_exploration_runner(base_dir, create_agent_fn,
                              schedule='continuous_train_and_eval'):
  """Creates an experiment Runner.

  Args:
    base_dir: Base directory for hosting all subdirectories. Must not be None.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use. Either
      'continuous_train_and_eval' (training and evaluation until the maximum
      number of iterations) or 'continuous_train' (training only).

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: If `base_dir` is None or an unknown schedule is encountered.
  """
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if base_dir is None:
    raise ValueError('base_dir must be specified.')
  # Continuously runs training and eval till max num_iterations is hit.
  if schedule == 'continuous_train_and_eval':
    return run_experiment.Runner(base_dir, create_agent_fn)
  # Continuously runs training till maximum num_iterations is hit.
  if schedule == 'continuous_train':
    return run_experiment.TrainRunner(base_dir, create_agent_fn)
  raise ValueError('Unknown schedule: {}'.format(schedule))
118