# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Module defining classes and helper methods for general agents."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from bonus_based_exploration.intrinsic_motivation import intrinsic_dqn_agent
from bonus_based_exploration.intrinsic_motivation import intrinsic_rainbow_agent
from bonus_based_exploration.noisy_networks import noisy_dqn_agent
from bonus_based_exploration.noisy_networks import noisy_rainbow_agent
from dopamine.discrete_domains import run_experiment
import gin

@gin.configurable
def create_exploration_agent(sess, environment, agent_name=None,
                             summary_writer=None, debug_mode=False):
  """Creates an exploration agent.

  Args:
    sess: A `tf.Session` object for running associated ops.
    environment: A gym environment (e.g. Atari 2600).
    agent_name: str, name of the agent to create. Supported bonus-based
      agents are dqn_cts, rainbow_cts, dqn_pixelcnn, rainbow_pixelcnn,
      dqn_rnd, rainbow_rnd, noisy_dqn and noisy_rainbow; any other name is
      forwarded to `run_experiment.create_agent`.
    summary_writer: A Tensorflow summary writer to pass to the agent
      for in-agent training statistics in Tensorboard.
    debug_mode: bool, whether to output Tensorboard summaries. If set to true,
      the agent will output in-episode statistics to Tensorboard. Disabled by
      default as this results in slower training.

  Returns:
    agent: An RL agent.

  Raises:
    ValueError: If `agent_name` is None or (via the fallback) not supported.
  """
  # Raise the documented ValueError instead of asserting: `assert` is
  # stripped under `python -O` and would let a None name fall through.
  if agent_name is None:
    raise ValueError('agent_name must not be None.')
  if not debug_mode:
    summary_writer = None
  # Single dispatch table for all bonus-based agents; replaces the original
  # if/elif chain, which inconsistently restarted with a bare `if` at the
  # dqn_pixelcnn branch.
  agent_classes = {
      'dqn_cts': intrinsic_dqn_agent.CTSDQNAgent,
      'rainbow_cts': intrinsic_rainbow_agent.CTSRainbowAgent,
      'dqn_pixelcnn': intrinsic_dqn_agent.PixelCNNDQNAgent,
      'rainbow_pixelcnn': intrinsic_rainbow_agent.PixelCNNRainbowAgent,
      'dqn_rnd': intrinsic_dqn_agent.RNDDQNAgent,
      'rainbow_rnd': intrinsic_rainbow_agent.RNDRainbowAgent,
      'noisy_dqn': noisy_dqn_agent.NoisyDQNAgent,
      'noisy_rainbow': noisy_rainbow_agent.NoisyRainbowAgent,
  }
  if agent_name in agent_classes:
    return agent_classes[agent_name](
        sess, num_actions=environment.action_space.n,
        summary_writer=summary_writer)
  # Unknown names fall back to the standard Dopamine agent factory, which
  # raises ValueError for unsupported agents.
  return run_experiment.create_agent(sess, environment, agent_name,
                                     summary_writer, debug_mode)

91

92
@gin.configurable
def create_exploration_runner(base_dir, create_agent_fn,
                              schedule='continuous_train_and_eval'):
  """Creates an experiment Runner.

  Args:
    base_dir: Base directory for hosting all subdirectories.
    create_agent_fn: A function that takes as args a Tensorflow session and a
      Gym Atari 2600 environment, and returns an agent.
    schedule: string, which type of Runner to use.

  Returns:
    runner: A `run_experiment.Runner` like object.

  Raises:
    ValueError: When an unknown schedule is encountered.
  """
  assert base_dir is not None
  # Supported schedules: 'continuous_train_and_eval' alternates training and
  # evaluation, 'continuous_train' only trains; both run until the maximum
  # num_iterations is reached.
  runner_classes = {
      'continuous_train_and_eval': run_experiment.Runner,
      'continuous_train': run_experiment.TrainRunner,
  }
  runner_class = runner_classes.get(schedule)
  if runner_class is None:
    raise ValueError('Unknown schedule: {}'.format(schedule))
  return runner_class(base_dir, create_agent_fn)