from transformers import GPTNeoXConfig, LlamaConfig, MistralConfig, AutoConfig, AutoModelForCausalLM
from typing import List, Literal

# make the repo root importable so that `megatron` can be found
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir)
)
from megatron.tokenizer import build_tokenizer
A script for converting saved NeoX checkpoints to Hugging Face (HF)-compatible GPT-NeoX-type models.

Note that this script does not support all NeoX features.
Please investigate carefully whether your model is compatible with all architectures supported by the GPTNeoXForCausalLM class in HF.
(e.g. position embeddings such as ALiBi may not be supported by Hugging Face's GPT-NeoX architecture).
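# A typical invocation might look like the following sketch. The --config_file, --output_dir,
# --precision, --architecture, and --no_save_tokenizer options correspond to the argparse
# arguments defined in main() below; the script name, the --input_dir flag, and the paths are
# illustrative assumptions only:
#
#   python convert_neox_to_hf.py \
#       --input_dir /path/to/model/global_step143000 \
#       --config_file /path/to/model/config.yml \
#       --output_dir hf_output/ \
#       --precision auto \
#       --architecture neox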
    # Parameter-name mapping from NeoX checkpoints to the HF GPT-NeoX ("neox") architecture:
    "COLUMN_PARALLEL_LINEAR_KEYS": {
        "mlp.dense_h_to_4h.weight": "mlp.dense_h_to_4h.weight",
        "mlp.dense_h_to_4h.bias": "mlp.dense_h_to_4h.bias",
        "attention.query_key_value.weight": "attention.query_key_value.weight",
        "attention.query_key_value.bias": "attention.query_key_value.bias",
    },
    "ROW_PARALLEL_LINEAR_KEYS": {
        "attention.dense.weight": "attention.dense.weight",
        "mlp.dense_4h_to_h.weight": "mlp.dense_4h_to_h.weight",
    },
    "ROW_PARALLEL_BIAS_KEYS": {
        "mlp.dense_4h_to_h.bias": "mlp.dense_4h_to_h.bias",
        "attention.dense.bias": "attention.dense.bias",
    },
    "NORM_KEYS": {
        "input_layernorm.weight": "input_layernorm.weight",
        "input_layernorm.bias": "input_layernorm.bias",
        "post_attention_layernorm.weight": "post_attention_layernorm.weight",
        "post_attention_layernorm.bias": "post_attention_layernorm.bias",
    },
    "FINAL_NORM_KEYS": {
        "norm.weight": "weight",
    },
    # Parameter-name mapping from NeoX checkpoints to the HF Llama ("llama") architecture:
    "COLUMN_PARALLEL_LINEAR_KEYS": {
        "mlp.w1.weight": "mlp.gate_proj.weight",
        "mlp.w3.weight": "mlp.up_proj.weight",
    },
    "ROW_PARALLEL_LINEAR_KEYS": {
        "attention.dense.weight": "self_attn.o_proj.weight",
        "mlp.w2.weight": "mlp.down_proj.weight",
    },
    "ROW_PARALLEL_BIAS_KEYS": {},
    "NORM_KEYS": {
        "input_layernorm.scale": "input_layernorm.weight",
        "post_attention_layernorm.scale": "post_attention_layernorm.weight",
    },
    "FINAL_NORM_KEYS": {
        "norm.scale": "weight",
    },
    "GQA_QKV_KEYS": {
        # the fused NeoX QKV weight is split into separate HF Q, K, V projections
        "attention.query_key_value.weight": [
            "self_attn.q_proj.weight",
            "self_attn.k_proj.weight",
            "self_attn.v_proj.weight",
        ],
    },

# Mistral shares the Llama parameter-name mapping.
MODEL_KEYS["mistral"] = MODEL_KEYS["llama"]
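# For example, the per-layer conversion loop below renames a NeoX row-parallel attention output
# weight for the HF Llama classes via a lookup of the form:
#
#   MODEL_KEYS["llama"]["ROW_PARALLEL_LINEAR_KEYS"]["attention.dense.weight"]
#   # -> "self_attn.o_proj.weight"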
def load_partitions(
    input_checkpoint_path: str, mp_partitions: int, layer_idx: int, sequential: bool
) -> List[dict]:
    """Returns a list containing all states from a model (across MP partitions)"""
    if sequential:
        filename_format = "mp_rank_{i:02}_model_states.pt"
    else:
        filename_format = f"layer_{layer_idx:02}-model_{{i:02}}-model_states.pt"

    loaded_tp_ranks = [
        torch.load(
            os.path.join(
                input_checkpoint_path,
                filename_format.format(i=i),
            ),
            map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        )
        for i in range(mp_partitions)
    ]

    return loaded_tp_ranks
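# For example, with mp_partitions=2 this loads, relative to input_checkpoint_path:
#   sequential=True  -> mp_rank_00_model_states.pt, mp_rank_01_model_states.pt
#   sequential=False -> layer_00-model_00-model_states.pt, layer_00-model_01-model_states.pt  (layer_idx=0)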
def get_state(
    state_dicts: List[dict], key: str, layer_idx: int, sequential: bool
) -> List[torch.Tensor]:
    """Helper that returns a list containing a given weight's state from each MP partition, for a given layer in the model."""
    if sequential:
        # sequential checkpoints flatten layers into "sequential.<layer_idx>.<key>" under "module"
        key = f"sequential.{layer_idx}.{key}"
        return [state_dict["module"][key] for state_dict in state_dicts]
    else:
        return [state_dict[key] for state_dict in state_dicts]
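# For example, with layer_idx=2 and key="attention.dense.weight", a sequential checkpoint is read as
#   state_dict["module"]["sequential.2.attention.dense.weight"]
# while a pipeline (non-sequential) checkpoint is read with the bare key from its per-layer file.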
def get_key(loaded_config, key, default=None):
    """Search for a given key in a NeoX yaml; checks both hyphenated and underscored forms, else returns `default`."""
    key = key.replace("_", "-")
    try:
        return loaded_config[key]
    except KeyError:
        key = key.replace("-", "_")
        try:
            return loaded_config[key]
        except KeyError:
            return default
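# For example, get_key({"hidden-size": 4096}, "hidden_size") and
# get_key({"hidden_size": 4096}, "hidden-size") both return 4096; a key missing under both
# spellings falls back to `default`.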
def create_config(neox_config, architecture="neox"):
    """Take in a loaded NeoX yaml and assign relevant values to a HF config.
    Returns: a GPTNeoXConfig, LlamaConfig, or MistralConfig object, depending on `architecture`.
    """

    def gated_size(hidden_dim):
        # LLaMA-style gated-MLP width: 8/3 * hidden_dim, rounded up to the next multiple of 256
        ff_dim = int(2 * hidden_dim * 4 / 3)
        ff_dim = 256 * ((ff_dim + 256 - 1) // 256)
        return ff_dim
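    # Worked example: gated_size(4096) = 256 * ((int(2 * 4096 * 4 / 3) + 255) // 256)
    #                 = 256 * ((10922 + 255) // 256) = 256 * 43 = 11008,
    # which is the familiar LLaMA-7B intermediate size.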
    class TokenizerArgs:
        # minimal stand-in for the args object expected by build_tokenizer()
        def __init__(self, neox_config):
            self.make_vocab_size_divisible_by = get_key(
                neox_config, "make-vocab-size-divisible-by", default=128
            )
            self.model_parallel_size = get_key(neox_config, "model-parallel-size")
            self.vocab_file = get_key(neox_config, "vocab-file")
            self.merge_file = get_key(neox_config, "merge-file")
            self.tokenizer_type = get_key(neox_config, "tokenizer-type")

    args = TokenizerArgs(neox_config)
    tokenizer = build_tokenizer(args)
    pad_token = tokenizer.pad

    use_tied_lns = get_key(neox_config, "gpt-j-tied", False)
    if use_tied_lns:
        raise NotImplementedError(
            """ERROR: Hugging Face Transformers does not yet support a single shared layernorm
            per transformer block for GPT-NeoX models trained w/ GPT-J parallel residuals.
            See https://github.com/EleutherAI/gpt-neox/pull/481 for further details."""
        )
        "vocab_size": args.padded_vocab_size,
        "hidden_size": get_key(neox_config, "hidden-size"),
        "num_hidden_layers": get_key(neox_config, "num-layers"),
        "num_attention_heads": get_key(neox_config, "num-attention-heads"),
        "max_position_embeddings": get_key(neox_config, "max-position-embeddings"),
        "initializer_range": get_key(neox_config, "init-method-std", 0.02),
        "tie_word_embeddings": (not get_key(neox_config, "no-weight-tying", False)),
    if architecture == "mistral" or architecture == "llama":
            "intermediate_size": get_key(
                gated_size(get_key(neox_config, "hidden-size")),
            "num_key_value_heads": get_key(
                get_key(neox_config, "num-attention-heads"),
            "hidden_act": get_key(neox_config, "activation", default="silu"),
            "rms_norm_eps": get_key(neox_config, "rms-norm-epsilon", 1.0e-6),
            "bos_token_id": tokenizer.eod,
            "eos_token_id": tokenizer.eod,
            "rope_theta": get_key(neox_config, "rotary-emb-base", 10000.0),

        if architecture == "mistral":
                "sliding_window": get_key(
                    neox_config, "sliding-window-width", 4096
            hf_config = MistralConfig(**args)

        elif architecture == "llama":
                "attention_bias": get_key(
                    neox_config, "use_bias_in_attn_linear", True
            hf_config = LlamaConfig(**args)
    else:
        # GPT-NeoX architecture settings
            "rotary_pct": get_key(neox_config, "rotary-pct", default=1.0),
            "rotary_emb_base": get_key(
                neox_config, "rotary-emb-base", default=10000.0
            ),
            "use_parallel_residual": get_key(neox_config, "gpt-j-residual", False),
            "layer_norm_eps": get_key(neox_config, "layernorm-epsilon", 1e-5),

        hf_config = GPTNeoXConfig(**args)
def reshard_and_split_qkv(
    param_mapping: dict,  # maps a NeoX QKV key -> list of HF [q, k, v] keys
    hf_config: AutoConfig,
    loaded_tp_ranks: List[dict],
    layer_idx: int,
    sequential: bool,
):
    """
    A helper function that reshapes and reshards the fused NeoX QKV projection so it is compatible
    with HF Llama-style models, splitting it into separate Q, K, and V projections, including when
    grouped-query attention (GQA) is used.
    """
    for key, hf_keys in param_mapping.items():
        assert (
            isinstance(hf_keys, list) and len(hf_keys) == 3
        ), "Must map QKV to precisely 3 resulting weight matrices."
    for key, hf_keys in param_mapping.items():
        # stack the TP-sharded, fused QKV weight from every MP partition
        sharded_qkv = torch.stack(
            get_state(loaded_tp_ranks, key, layer_idx, sequential), dim=0
        )

        sharded_qkv = sharded_qkv.view(
            len(loaded_tp_ranks),
            hf_config.num_attention_heads // len(loaded_tp_ranks),
                hf_config.hidden_size
                // hf_config.num_attention_heads
                + 2 * hf_config.num_key_value_heads / hf_config.num_attention_heads
            hf_config.hidden_size,
        )

        # split the fused projection into per-head Q, K, V chunks
        q, k, v = torch.split(
                hf_config.hidden_size // hf_config.num_attention_heads,
                (hf_config.num_key_value_heads / hf_config.num_attention_heads)
                * hf_config.hidden_size
                // hf_config.num_attention_heads
                (hf_config.num_key_value_heads / hf_config.num_attention_heads)
                * hf_config.hidden_size
                // hf_config.num_attention_heads
        )

        q, k, v = q.squeeze(dim=2), k.squeeze(dim=2), v.squeeze(dim=2)

        # merge the head dimension back: Q becomes (hidden_size, hidden_size);
        # K and V become (hidden_size // num_attention_heads * num_key_value_heads, hidden_size)
        q = q.reshape(
            hf_config.num_attention_heads,
            hf_config.hidden_size // hf_config.num_attention_heads,
            hf_config.hidden_size,
        ).reshape(hf_config.hidden_size, hf_config.hidden_size)
        k = k.reshape(
            hf_config.num_key_value_heads,
            hf_config.hidden_size // hf_config.num_attention_heads,
            hf_config.hidden_size,
        ).reshape(
            hf_config.hidden_size
            // hf_config.num_attention_heads
            * hf_config.num_key_value_heads,
            hf_config.hidden_size,
        )
        v = v.reshape(
            hf_config.num_key_value_heads,
            hf_config.hidden_size // hf_config.num_attention_heads,
            hf_config.hidden_size,
        ).reshape(
            hf_config.hidden_size
            // hf_config.num_attention_heads
            * hf_config.num_key_value_heads,
            hf_config.hidden_size,
        )

        for hf_key, proj in zip(hf_keys, [q, k, v]):
            state_dict[hf_key] = proj.clone()
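        # Shape check with illustrative numbers (not from any particular config): for
        # hidden_size=4096, num_attention_heads=32, num_key_value_heads=8, q ends up with shape
        # (4096, 4096) while k and v each end up (4096 // 32 * 8, 4096) = (1024, 4096),
        # matching the HF Llama q_proj/k_proj/v_proj weight shapes.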
def convert(
    input_checkpoint_path,
    loaded_config,
    output_checkpoint_path,
    sequential: bool = True,
    precision: Literal["auto", "fp16", "bf16", "fp32"] = "auto",
    architecture: Literal["neox", "llama", "mistral"] = "neox",
):
    """Convert a NeoX checkpoint to HF model format.

    Should perform model-parallel merging correctly, but only supports features allowed by the
    HF GPT-NeoX implementation (e.g. rotary embeddings).
    """
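    # Checkpoint layout assumed by the load_partitions/get_state calls below: the word embedding
    # is stored at layer_idx 0, transformer block i at layer_idx i + 2, the final layernorm at
    # num-layers + 3, and the LM head ("final_linear") at num-layers + 4.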
    ARCH = MODEL_KEYS[architecture]

    hf_config = create_config(loaded_config, architecture=architecture)

    hf_model = AutoModelForCausalLM.from_config(hf_config)

    if architecture == "neox":
        hf_transformer = hf_model.gpt_neox
    else:
        hf_transformer = hf_model.model

    if precision == "auto":
        print("Auto-detecting precision to save model into...")
        fp16 = get_key(loaded_config, "fp16")
            print("Saving weights in fp16 precision...")
            bf16 = get_key(loaded_config, "bf16")
                hf_model.to(dtype=torch.bfloat16)
                print("Saving weights in bf16 precision...")
                hf_model.to(dtype=torch.float)
                    "Model not trained in fp16 / bf16 mixed precision, saving weights in fp32..."
    else:
        name_to_dtype = {
            "bf16": torch.bfloat16,
            "fp16": torch.float16,
            "fp32": torch.float,
        }
        print(f"Saving model into specified {precision} precision...")
        hf_model.to(dtype=name_to_dtype[precision])
    mp_partitions = get_key(loaded_config, "model-parallel-size")

    # Word embeddings are stored at layer_idx 0; merge them across MP partitions.
    loaded_tp_ranks = load_partitions(
        input_checkpoint_path, mp_partitions, layer_idx=0, sequential=sequential
    )

    if architecture == "neox":
        embed_in = hf_transformer.embed_in
    else:
        embed_in = hf_transformer.embed_tokens
    embed_in.load_state_dict(
            "word_embeddings.weight",
            sequential=sequential,
    )
    assert (
        hf_config.vocab_size == embed_in.weight.shape[0]
    ), f"ERROR: calculated vocab size {hf_config.vocab_size} != embed param size {embed_in.weight.shape[0]}"
    for layer_i in tqdm(range(get_key(loaded_config, "num-layers"))):
        hf_layer = hf_transformer.layers[layer_i]

        loaded_tp_ranks = load_partitions(
            input_checkpoint_path,
            mp_partitions,
            layer_idx=layer_i + 2,
            sequential=sequential,
        )

        state_dict = {}
        # Row-parallel weights are sharded along the input dim; concatenate them along dim 1.
        for key, hf_key in ARCH["ROW_PARALLEL_LINEAR_KEYS"].items():
            state_dict[hf_key] = torch.cat(
                get_state(
                    loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
                ),
                dim=1,
            )

        for key, hf_key in ARCH["NORM_KEYS"].items():
            state_dict[hf_key] = sum(
                get_state(
                    loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
                )
            ) / len(loaded_tp_ranks)
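        # Layernorm parameters are replicated (not sharded) across model-parallel ranks, so summing
        # the copies and dividing by the number of ranks simply recovers the shared value.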
        # Column-parallel weights are sharded along the output dim; concatenate them along dim 0.
        for key, hf_key in ARCH["COLUMN_PARALLEL_LINEAR_KEYS"].items():
            state_dict[hf_key] = torch.cat(
                get_state(
                    loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
                ),
                dim=0,
            )

        for key, hf_key in ARCH["ROW_PARALLEL_BIAS_KEYS"].items():
            state_dict[hf_key] = sum(
                get_state(
                    loaded_tp_ranks, key, layer_idx=layer_i + 2, sequential=sequential
                )
            )

        # Keep the HF layer's own causal-mask buffers, if the HF implementation registers them.
        if "attention.bias" in hf_layer.state_dict():
            state_dict["attention.bias"] = hf_layer.state_dict()["attention.bias"]
        if "attention.masked_bias" in hf_layer.state_dict():
            state_dict["attention.masked_bias"] = hf_layer.state_dict()[
                "attention.masked_bias"
            ]

        if "GQA_QKV_KEYS" in ARCH:
            # fused QKV must be resharded and split into separate Q/K/V projections for HF
            reshard_and_split_qkv(
                param_mapping=ARCH["GQA_QKV_KEYS"],
                hf_config=hf_config,
                loaded_tp_ranks=loaded_tp_ranks,
                layer_idx=layer_i + 2,
                sequential=sequential,
            )

        hf_layer.load_state_dict(state_dict)
    # Load the final layer norm, stored at checkpoint index num-layers + 3.
    loaded_tp_ranks = load_partitions(
        input_checkpoint_path,
        mp_partitions,
        get_key(loaded_config, "num-layers") + 3,
        sequential=sequential,
    )

    if architecture == "neox":
        lm_head = hf_model.embed_out
    else:
        lm_head = hf_model.lm_head

    norm_state_dict = {}
    for key, hf_key in ARCH["FINAL_NORM_KEYS"].items():
        norm_state_dict[hf_key] = sum(
            get_state(
                loaded_tp_ranks,
                key,
                layer_idx=get_key(loaded_config, "num-layers") + 3,
                sequential=sequential,
            )
        ) / len(loaded_tp_ranks)

    if architecture == "neox":
        final_layer_norm = hf_transformer.final_layer_norm
    else:
        final_layer_norm = hf_transformer.norm

    final_layer_norm.load_state_dict(norm_state_dict)
    # Load the LM head, stored at checkpoint index num-layers + 4.
    loaded_tp_ranks = load_partitions(
        input_checkpoint_path,
        mp_partitions,
        get_key(loaded_config, "num-layers") + 4,
        sequential=sequential,
    )

    if architecture == "neox":
        lm_head = hf_model.embed_out
    else:
        lm_head = hf_model.lm_head

    lm_head.load_state_dict(
            "final_linear.weight",
            layer_idx=get_key(loaded_config, "num-layers") + 4,
            sequential=sequential,
    )
def main(input_args=None, overwrite_values=None):
    from huggingface_hub import create_repo, HfApi

    parser = argparse.ArgumentParser(
        description="Merge MP partitions and convert to HF Model."
    )
        help="Path to NeoX checkpoint, e.g. /path/to/model/global_step143000",
        help="Path to config file for the input NeoX checkpoint.",
        help="Output dir, where to save the HF Model, tokenizer, and configs",
        help="What precision to save the model into. Defaults to auto, which auto-detects which 16-bit dtype to save into, or falls back to fp32.",
        "--no_save_tokenizer",
        help="Whether to skip saving the tokenizer alongside a model.",
        help="What HF model class type to export into.",

    args = parser.parse_args(input_args)
    assert args.precision in [
        "auto",
        "fp16",
        "bf16",
        "fp32",
    ], f"expected --precision to be one of 'auto', 'fp16', 'bf16', 'fp32' but got '{args.precision}' !"
    assert args.architecture in [
        "neox",
        "mistral",
        "llama",
    ], f"expected --architecture to be one of 'neox', 'mistral', 'llama', but got '{args.architecture}' !"

    with open(args.config_file) as f:
        loaded_config = yaml.full_load(f)
        if overwrite_values:
            loaded_config.update(overwrite_values)
    pipeline_world_size = get_key(loaded_config, "pipe-parallel-size", 1)
    if pipeline_world_size == 0:
        sequential = True
        print(
            f"Detected 'pipe-parallel-size' of {pipeline_world_size}, assuming model is saved as Sequential..."
        )
    else:
        sequential = False
        print(
            f"Detected 'pipe-parallel-size' of {pipeline_world_size}, assuming model is saved as PipelineModule..."
        )

        sequential=sequential,
        architecture=args.architecture,
    hf_model.save_pretrained(args.output_dir)

    if not args.no_save_tokenizer:
        # also save the tokenizer alongside the model, if a HF tokenizer was used
        tokenizer_type = get_key(loaded_config, "tokenizer-type")

        if tokenizer_type == "HFTokenizer":
            print(f"saving tokenizer from file {get_key(loaded_config, 'vocab-file')}")
            print(
                "Warning: please check that your model config and tokenizer end with the correct special tokens (EOS, BOS)."
            )
            from transformers import PreTrainedTokenizerFast

            tokenizer = PreTrainedTokenizerFast(
                tokenizer_file=get_key(loaded_config, "vocab-file")
            )
            print("loaded tokenizer: ", tokenizer)
            tokenizer.save_pretrained(args.output_dir)
            print("tokenizer saved!")


if __name__ == "__main__":
    main()