"""Megatron initialization."""

import os
import random
import sys

import numpy as np
import torch

import deepspeed

from megatron import fused_kernels
from megatron import mpu
from megatron import print_rank_0
from megatron.mpu import set_model_parallel_rank, set_model_parallel_world_size


def initialize_megatron(neox_args, allow_no_cuda=False):
    """Initialize distributed training, autoresume, and random seeds.

    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.

    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True).
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), "Megatron requires CUDA."

    # torch.distributed initialization lives in this closure so it can be
    # deferred when lazy_mpu_init is set.
    def finish_mpu_init():
        # Pytorch distributed.
        _initialize_distributed(neox_args=neox_args)

        # Random seeds for reproducibility.
        if neox_args.rank == 0:
            print("> setting random seeds to {} ...".format(neox_args.seed))
        _set_random_seed(neox_args.seed)

    # Load fused CUDA kernels if any kernel-fusion flag is enabled.
    if (
        neox_args.scaled_upper_triang_masked_softmax_fusion
        or neox_args.scaled_masked_softmax_fusion
        or neox_args.rope_fusion
    ):
        fused_kernels.load_fused_kernels()

    if neox_args.lazy_mpu_init:
        neox_args.use_cpu_initialization = True
        # Delayed initialization of DDP-related state: only set the basic
        # model-parallel globals here, then hand back a function for the
        # external DDP manager to call once it has DDP initialized.
        set_model_parallel_world_size(neox_args.model_parallel_size)
        set_model_parallel_rank(neox_args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master; complete initialization right away.
        finish_mpu_init()

        # Compile dataset C++ helpers.
        if neox_args.local_rank == 0:
            from megatron.data.data_utils import compile_helper

            compile_helper()

        # Write arguments to tensorboard.
        _write_args_to_tensorboard(neox_args=neox_args)
        # No continuation function needed.
        return None
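

# Minimal call-site sketch (hedged; constructing the NeoXArgs instance is
# assumed to happen elsewhere in the codebase):
#
#     neox_args = ...  # fully populated NeoXArgs
#     finish = initialize_megatron(neox_args)
#     if finish is not None:
#         # lazy_mpu_init path: the external DDP manager calls this once
#         # it has DDP initialized.
#         finish()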


def setup_deepspeed_random_and_activation_checkpointing(neox_args):
    """Optional DeepSpeed Activation Checkpointing features.

    Gives access to partition activations, contiguous memory optimizations,
    and cpu checkpointing.

    Activation checkpointing requires keeping track of the random states
    and setting the random seed for each MP process. Megatron uses
    mpu.get_cuda_rng_tracker and mpu.model_parallel_cuda_manual_seed
    for keeping track of the random states and setting the random seeds.
    Since they are used in places outside of activation checkpointing,
    we overwrite them to maintain consistency.

    This must be called before all calls to mpu.model_parallel_cuda_manual_seed.
    """
    # Number of checkpointed segments: ceil(num_layers / checkpoint_num_layers).
    num_layers = neox_args.num_layers // neox_args.checkpoint_num_layers
    num_layers = (
        num_layers
        if neox_args.num_layers % neox_args.checkpoint_num_layers == 0
        else num_layers + 1
    )

    deepspeed.checkpointing.configure(
        mpu,
        partition_activations=neox_args.partition_activations,
        contiguous_checkpointing=neox_args.contiguous_checkpointing,
        num_checkpoints=num_layers,
        checkpoint_in_cpu=neox_args.checkpoint_in_cpu,
        synchronize=neox_args.synchronize_each_layer,
        profile=neox_args.profile_backward,
    )
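
    # Worked example of the segment count above (illustrative numbers only):
    # num_layers=24 with checkpoint_num_layers=4 divides evenly into 6
    # segments; with checkpoint_num_layers=5 it rounds up to 24 // 5 + 1 = 5.
    # After configure(), forward segments are typically wrapped with
    # deepspeed.checkpointing.checkpoint(...) so that activations inside each
    # segment are recomputed in the backward pass instead of stored.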


def _initialize_distributed(neox_args):
    """Initialize torch.distributed and mpu."""

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if neox_args.rank == 0:
            print(
                "torch distributed is already initialized, "
                "skipping initialization ...",
                flush=True,
            )
        neox_args.rank = torch.distributed.get_rank()
        neox_args.world_size = torch.distributed.get_world_size()

    else:

        if neox_args.rank == 0:
            print("> initializing torch distributed ...", flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = neox_args.rank % device_count
            if neox_args.local_rank is not None:
                assert (
                    neox_args.local_rank == device
                ), "expected local-rank to be the same as rank % device-count."
            else:
                neox_args.local_rank = device
            torch.cuda.set_device(device)

        # Call the init process.
        deepspeed.init_distributed(
            dist_backend=neox_args.distributed_backend,
            auto_mpi_discovery=True,
            distributed_port=os.getenv("MASTER_PORT", "6000"),
        )

    # Set up the 3D (pipe, data, model) parallel topology.
    pp = neox_args.pipe_parallel_size if neox_args.pipe_parallel_size >= 1 else 1
    mp = neox_args.model_parallel_size if neox_args.model_parallel_size >= 1 else 1
    assert (
        neox_args.world_size % (pp * mp) == 0
    ), f"world_size={neox_args.world_size}, pp={pp}, mp={mp}"
    dp = neox_args.world_size // (pp * mp)
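    # Example (illustrative numbers only): with world_size=16, pp=4, and mp=2,
    # the data-parallel degree is dp = 16 // (4 * 2) = 2.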

    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

    # This places pipe on the outermost axis, then data, then model;
    # PipeModelDataParallelTopology is a ProcessTopology wrapper that
    # predefines this axis order.
    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

    # Offset base seeds for the interior pipeline stages.
    # TODO: adjust last stage too once IO is improved.
    stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe
    if 0 < stage_id < topo.get_dim("pipe") - 1:
        offset = neox_args.seed + 1138
        neox_args.seed = offset + (stage_id * mp)
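    # Worked example (illustrative numbers only): with a base seed of 1234,
    # mp=2, and 4 pipeline stages, interior stage 1 gets
    # 1234 + 1138 + (1 * 2) = 2374 and stage 2 gets 2376; the first and last
    # stages keep the base seed.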

    # Set the model-parallel / data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print(
                "_initialize_distributed() model parallel is already initialized",
                flush=True,
            )
        else:
            mpu.initialize_model_parallel(
                neox_args.model_parallel_size,
                topology=topo,
                fp32_allreduce=neox_args.fp32_allreduce,
            )

    # Init DeepSpeed Activation Checkpointing features.
    setup_deepspeed_random_and_activation_checkpointing(neox_args=neox_args)
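
    # Rank-layout sketch (hedged; follows the pipe-data-model axis order noted
    # above): with pp=2, mp=2, dp=2, rank 0 maps to coordinate
    # (pipe=0, data=0, model=0) and rank 1 to (pipe=0, data=0, model=1),
    # since the model axis varies fastest.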


def _init_autoresume(neox_args):
    """Set autoresume start time."""

    if neox_args.adlr_autoresume:
        print_rank_0("> enabling autoresume ...")
        sys.path.append(os.environ.get("SUBMIT_SCRIPTS", "."))
        try:
            from userlib.auto_resume import AutoResume
        except BaseException:
            print("> ADLR autoresume is not available, exiting ...", flush=True)
            sys.exit()
        neox_args.adlr_autoresume_object = AutoResume

    if neox_args.adlr_autoresume_object:
        torch.distributed.barrier()
        neox_args.adlr_autoresume_object.init()
        torch.distributed.barrier()


def _set_random_seed(seed):
    """Set random seed for reproducibility."""
    if seed is not None and seed > 0:
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError("Seed ({}) should be a positive integer.".format(seed))
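

# Note: the seeding above covers Python's `random`, NumPy, and torch; on GPU
# runs, mpu.model_parallel_cuda_manual_seed additionally derives distinct CUDA
# seeds per model-parallel rank, so that e.g. dropout differs across
# model-parallel partitions while data-parallel replicas stay in sync.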


def _write_args_to_tensorboard(neox_args):
    """Write arguments to tensorboard."""
    if neox_args.tensorboard_writer:
        for arg_name in vars(neox_args):
            neox_args.tensorboard_writer.add_text(
                arg_name, str(getattr(neox_args, arg_name))
            )
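

# Assuming neox_args.tensorboard_writer is a torch.utils.tensorboard
# SummaryWriter (a hedged assumption; it is constructed elsewhere), each
# argument appears as one entry under TensorBoard's TEXT tab, tagged with the
# argument's name.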