google-research

Форк
0
/
learner_config.py 
83 строки · 3.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Learner config."""
17

18
import dataclasses
19
from typing import Optional
20

21

22
@dataclasses.dataclass
class LearnerConfig:
  """Configuration for the learner.

  Plain value object: every field has a default, so `LearnerConfig()` yields a
  usable baseline config that callers can override field-by-field.
  """
  # Checkpoint save period in seconds.
  save_checkpoint_secs: int = 1800
  # Total iterations to train for.
  total_iterations: int = int(1e6)
  # Batch size for training.
  batch_size: int = 64
  # Whether actors block when enqueueing.
  replay_queue_block: int = 0
  # Batch size for the recurrent inference.
  recurrent_inference_batch_size: int = 32
  # Batch size for initial inference.
  initial_inference_batch_size: int = 4
  # Number of TPUs for training.
  num_training_tpus: int = 1
  # Path to the checkpoint used to initialize the agent.
  init_checkpoint: Optional[str] = None
  # Size of the replay buffer.
  replay_buffer_size: int = 1000
  # Size of the replay queue.
  replay_queue_size: int = 100
  # After sampling an episode from the replay buffer, the corresponding priority
  # is set to this value. For a value < 1, no priority update will be done.
  replay_buffer_update_priority_after_sampling_value: float = 1e-6
  # How often (in seconds) the learner log is flushed.
  flush_learner_log_every_n_s: int = 60
  # If true, logs are written to tensorboard.
  enable_learner_logging: bool = True
  # Log frequency in number of training steps.
  log_frequency: int = 100
  # Exponent used when computing the importance sampling correction. 0 means no
  # importance sampling correction. 1 means full importance sampling correction.
  importance_sampling_exponent: float = 0.0
  # For sampling from priority queue. 0 for uniform. The higher this value the
  # more likely it is to sample an instance for which the model predicts a wrong
  # value.
  priority_sampling_exponent: float = 0.0
  # How many batches the learner skips.
  learner_skip: int = 0
  # Save the agent in ExportAgent format.
  export_agent: bool = False
  # L2 penalty.
  weight_decay: float = 1e-5
  # Scaling for the policy loss term.
  policy_loss_scaling: float = 1.0
  # Scaling for the reward loss term.
  reward_loss_scaling: float = 1.0
  # Entropy loss for the policy loss term.
  policy_loss_entropy_regularizer: float = 0.0
  # Gradient norm clip (0 for no clip).
  gradient_norm_clip: float = 0.0
  # Enables debugging.
  debug: bool = False

  # The fields below are defined in seed_rl/common/common_flags.py

  # TensorFlow log directory.
  logdir: str = '/tmp/agent'
  # Server address.
  server_address: str = 'unix:/tmp/agent_grpc'
84

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.