# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Learner config."""

import dataclasses
from typing import Optional

@dataclasses.dataclass
class LearnerConfig:
  """Configuration for the learner process.

  Groups together training-loop, replay, logging, and loss-scaling
  hyperparameters. All fields have defaults, so ``LearnerConfig()`` is a
  valid, fully-populated configuration.
  """
  # Checkpoint save period in seconds.
  save_checkpoint_secs: int = 1800
  # Total iterations to train for.
  total_iterations: int = int(1e6)
  # Batch size for training.
  batch_size: int = 64
  # Whether actors block when enqueueing.
  replay_queue_block: int = 0
  # Batch size for the recurrent inference.
  recurrent_inference_batch_size: int = 32
  # Batch size for initial inference.
  initial_inference_batch_size: int = 4
  # Number of TPUs for training.
  num_training_tpus: int = 1
  # Path to the checkpoint used to initialize the agent.
  init_checkpoint: Optional[str] = None
  # Size of the replay buffer (in number of batches stored).
  replay_buffer_size: int = 1000
  # Size of the replay queue.
  replay_queue_size: int = 100
  # After sampling an episode from the replay buffer, the corresponding priority
  # is set to this value. For a value < 1, no priority update will be done.
  replay_buffer_update_priority_after_sampling_value: float = 1e-6
  # Period, in seconds, at which learner logs are flushed.
  flush_learner_log_every_n_s: int = 60
  # If true, logs are written to tensorboard.
  enable_learner_logging: bool = True
  # Log frequency in number of training steps.
  log_frequency: int = 100
  # Exponent used when computing the importance sampling correction. 0 means no
  # importance sampling correction. 1 means full importance sampling correction.
  importance_sampling_exponent: float = 0.0
  # For sampling from priority queue. 0 for uniform. The higher this value the
  # more likely it is to sample an instance for which the model predicts a wrong
  # value.
  priority_sampling_exponent: float = 0.0
  # How many batches the learner skips.
  learner_skip: int = 0
  # Save the agent in ExportAgent format.
  export_agent: bool = False
  # L2 penalty.
  weight_decay: float = 1e-5
  # Scaling for the policy loss term.
  policy_loss_scaling: float = 1.0
  # Scaling for the reward loss term.
  reward_loss_scaling: float = 1.0
  # Entropy loss for the policy loss term.
  policy_loss_entropy_regularizer: float = 0.0
  # Gradient norm clip (0 for no clip).
  gradient_norm_clip: float = 0.0
  # Enables debugging.
  debug: bool = False

  # The fields below are defined in seed_rl/common/common_flags.py

  # TensorFlow log directory.
  logdir: str = '/tmp/agent'
  # Server address.
  server_address: str = 'unix:/tmp/agent_grpc'
84