258 lines
8.9 KiB
Python
258 lines
8.9 KiB
Python
|
|
import argparse
|
||
|
|
import configparser
|
||
|
|
import multiprocessing
|
||
|
|
import os
|
||
|
|
import shutil
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
from transformers import GPTNeoXForCausalLM
|
||
|
|
|
||
|
|
|
||
|
|
def get_weight_data_type(data_type):
|
||
|
|
if data_type == "fp32":
|
||
|
|
return np.float32
|
||
|
|
elif data_type == "fp16":
|
||
|
|
return np.float16
|
||
|
|
else:
|
||
|
|
assert False, f"Invalid weight data type {data_type}"
|
||
|
|
|
||
|
|
|
||
|
|
def split_and_convert_process(saved_dir, factor, key, args, config, val):
|
||
|
|
|
||
|
|
if (
|
||
|
|
key.find("input_layernorm.weight") != -1
|
||
|
|
or key.find("input_layernorm.bias") != -1
|
||
|
|
or key.find("attention.dense.bias") != -1
|
||
|
|
or key.find("post_attention_layernorm.weight") != -1
|
||
|
|
or key.find("post_attention_layernorm.bias") != -1
|
||
|
|
or key.find("mlp.dense_4h_to_h.bias") != -1
|
||
|
|
or key.find("final_layernorm.weight") != -1
|
||
|
|
or key.find("final_layernorm.bias") != -1
|
||
|
|
):
|
||
|
|
saved_path = saved_dir + f"/model.{key}.bin"
|
||
|
|
val.tofile(saved_path)
|
||
|
|
|
||
|
|
elif (
|
||
|
|
key.find("attention.dense.weight") != -1
|
||
|
|
or key.find("mlp.dense_4h_to_h.weight") != -1
|
||
|
|
):
|
||
|
|
split_vals = np.split(val, factor, axis=0)
|
||
|
|
for j in range(factor):
|
||
|
|
saved_path = saved_dir + f"/model.{key}.{j}.bin"
|
||
|
|
split_vals[j].tofile(saved_path)
|
||
|
|
|
||
|
|
elif (
|
||
|
|
key.find("mlp.dense_h_to_4h.weight") != -1
|
||
|
|
or key.find("mlp.dense_h_to_4h.bias") != -1
|
||
|
|
):
|
||
|
|
|
||
|
|
split_vals = np.split(val, factor, axis=-1)
|
||
|
|
for j in range(factor):
|
||
|
|
saved_path = saved_dir + f"/model.{key}.{j}.bin"
|
||
|
|
split_vals[j].tofile(saved_path)
|
||
|
|
|
||
|
|
elif key.find("attention.query_key_value.bias") != -1:
|
||
|
|
local_dim = (int)(val.shape[-1] / 3)
|
||
|
|
n_head = config["num_attention_heads"]
|
||
|
|
|
||
|
|
val = val.reshape(n_head, 3, local_dim // n_head)
|
||
|
|
val = np.transpose(val, [1, 0, 2]).reshape(3, local_dim)
|
||
|
|
split_vals = np.split(val, factor, axis=-1)
|
||
|
|
|
||
|
|
for j in range(factor):
|
||
|
|
saved_path = saved_dir + f"/model.{key}.{j}.bin"
|
||
|
|
split_vals[j].tofile(saved_path)
|
||
|
|
|
||
|
|
elif key.find("attention.query_key_value.weight") != -1:
|
||
|
|
hidden_dim = val.shape[0]
|
||
|
|
local_dim = (int)(val.shape[-1] / 3)
|
||
|
|
n_head = config["num_attention_heads"]
|
||
|
|
# Note that the HF qkv weight are stored as [hidden_size, num_heads, 3, head_hidden]
|
||
|
|
# FT needs the shape of [hidden_size, 3, num_heads, head_hidden]
|
||
|
|
val = val.reshape(hidden_dim, n_head, 3, local_dim // n_head)
|
||
|
|
val = np.transpose(val, [0, 2, 1, 3]).reshape(hidden_dim, 3, local_dim)
|
||
|
|
|
||
|
|
# print(np.mean(np.abs(val[:, 0, :])))
|
||
|
|
split_vals = np.split(val, factor, axis=-1)
|
||
|
|
|
||
|
|
for j in range(factor):
|
||
|
|
saved_path = saved_dir + f"/model.{key}.{j}.bin"
|
||
|
|
split_vals[j].tofile(saved_path)
|
||
|
|
|
||
|
|
else:
|
||
|
|
print("[ERROR] cannot find key '{}'".format(key))
|
||
|
|
|
||
|
|
|
||
|
|
def split_and_convert(args):
|
||
|
|
saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num
|
||
|
|
|
||
|
|
if os.path.exists(saved_dir) == False:
|
||
|
|
os.makedirs(saved_dir)
|
||
|
|
ckpt_name = args.in_file
|
||
|
|
|
||
|
|
t_gpu_num = args.trained_gpu_num
|
||
|
|
i_gpu_num = args.infer_gpu_num
|
||
|
|
assert i_gpu_num % t_gpu_num == 0
|
||
|
|
|
||
|
|
factor = (int)(i_gpu_num / t_gpu_num)
|
||
|
|
|
||
|
|
# load position_embedding from rank 0
|
||
|
|
# model = torch.load(ckpt_name)
|
||
|
|
model = GPTNeoXForCausalLM.from_pretrained(args.in_file)
|
||
|
|
hf_config = vars(model.config)
|
||
|
|
|
||
|
|
np_weight_data_type = get_weight_data_type(args.weight_data_type)
|
||
|
|
|
||
|
|
try:
|
||
|
|
model_name = args.model_name
|
||
|
|
n_heads = hf_config["num_attention_heads"]
|
||
|
|
head_size = hf_config["hidden_size"] // n_heads
|
||
|
|
rotary_dim = int(head_size * hf_config["rotary_pct"])
|
||
|
|
use_gptj_residual = int(hf_config["use_parallel_residual"])
|
||
|
|
|
||
|
|
config = configparser.ConfigParser()
|
||
|
|
config["gptneox"] = {}
|
||
|
|
config["gptneox"]["model_name"] = model_name
|
||
|
|
config["gptneox"]["head_num"] = str(n_heads)
|
||
|
|
config["gptneox"]["size_per_head"] = str(head_size)
|
||
|
|
config["gptneox"]["inter_size"] = str(hf_config["intermediate_size"])
|
||
|
|
config["gptneox"]["num_layer"] = str(hf_config["num_hidden_layers"])
|
||
|
|
config["gptneox"]["rotary_embedding"] = str(rotary_dim)
|
||
|
|
config["gptneox"]["vocab_size"] = str(hf_config["vocab_size"])
|
||
|
|
config["gptneox"]["start_id"] = str(hf_config["bos_token_id"])
|
||
|
|
config["gptneox"]["end_id"] = str(hf_config["eos_token_id"])
|
||
|
|
config["gptneox"]["use_gptj_residual"] = str(use_gptj_residual)
|
||
|
|
config["gptneox"]["weight_data_type"] = args.weight_data_type
|
||
|
|
|
||
|
|
with open((Path(saved_dir) / f"config.ini").as_posix(), "w") as configfile:
|
||
|
|
config.write(configfile)
|
||
|
|
except Exception as e:
|
||
|
|
print(f"Fail to save the config in config.ini.", e)
|
||
|
|
|
||
|
|
ft_model_name_pattern = [
|
||
|
|
"input_layernorm.bias",
|
||
|
|
"input_layernorm.weight",
|
||
|
|
"attention.query_key_value.bias",
|
||
|
|
"attention.query_key_value.weight",
|
||
|
|
"attention.dense.bias",
|
||
|
|
"attention.dense.weight",
|
||
|
|
"post_attention_layernorm.bias",
|
||
|
|
"post_attention_layernorm.weight",
|
||
|
|
"mlp.dense_h_to_4h.bias",
|
||
|
|
"mlp.dense_h_to_4h.weight",
|
||
|
|
"mlp.dense_4h_to_h.bias",
|
||
|
|
"mlp.dense_4h_to_h.weight",
|
||
|
|
]
|
||
|
|
|
||
|
|
torch.multiprocessing.set_start_method("spawn")
|
||
|
|
pool = multiprocessing.Pool(args.processes)
|
||
|
|
for name, param in model.named_parameters():
|
||
|
|
array = param.detach().cpu().numpy().astype(np_weight_data_type)
|
||
|
|
# print("input shape", name, array.shape)
|
||
|
|
if name.find("weight") == -1 and name.find("bias") == -1:
|
||
|
|
print("skipped", name)
|
||
|
|
continue
|
||
|
|
elif name == "gpt_neox.embed_in.weight":
|
||
|
|
array.tofile(saved_dir + "model.wte.bin")
|
||
|
|
elif name == "gpt_neox.final_layer_norm.bias":
|
||
|
|
array.tofile(saved_dir + "model.final_layernorm.bias.bin")
|
||
|
|
elif name == "gpt_neox.final_layer_norm.weight":
|
||
|
|
array.tofile(saved_dir + "model.final_layernorm.weight.bin")
|
||
|
|
elif name == "embed_out.weight":
|
||
|
|
array.tofile(saved_dir + "model.lm_head.weight.bin")
|
||
|
|
else:
|
||
|
|
processed = False
|
||
|
|
for i in range(len(ft_model_name_pattern)):
|
||
|
|
if name.find(ft_model_name_pattern[i]) != -1:
|
||
|
|
new_name = name.replace("gpt_neox.", "")
|
||
|
|
pool.starmap(
|
||
|
|
split_and_convert_process,
|
||
|
|
[
|
||
|
|
(
|
||
|
|
saved_dir,
|
||
|
|
factor,
|
||
|
|
new_name,
|
||
|
|
args,
|
||
|
|
vars(model.config),
|
||
|
|
array.T,
|
||
|
|
)
|
||
|
|
],
|
||
|
|
)
|
||
|
|
processed = True
|
||
|
|
break
|
||
|
|
|
||
|
|
if not processed:
|
||
|
|
print("Unused layer", name)
|
||
|
|
|
||
|
|
pool.close()
|
||
|
|
pool.join()
|
||
|
|
|
||
|
|
# Post-process biases if use_gptj_residual is True
|
||
|
|
if use_gptj_residual:
|
||
|
|
for layer_idx in range(hf_config["num_hidden_layers"]):
|
||
|
|
attn_bias = np.fromfile(
|
||
|
|
saved_dir + f"/model.layers.{layer_idx}.attention.dense.bias.bin",
|
||
|
|
dtype=np_weight_data_type,
|
||
|
|
)
|
||
|
|
mlp_bias = np.fromfile(
|
||
|
|
saved_dir + f"/model.layers.{layer_idx}.mlp.dense_4h_to_h.bias.bin",
|
||
|
|
dtype=np_weight_data_type,
|
||
|
|
)
|
||
|
|
|
||
|
|
(attn_bias + mlp_bias).astype(np_weight_data_type).tofile(
|
||
|
|
saved_dir + f"/model.layers.{layer_idx}.mlp.attention.bias.sum.bin"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
|
||
|
|
parser.add_argument(
|
||
|
|
"-saved_dir", "-o", type=str, help="file name of output file", required=True
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"-in_file",
|
||
|
|
"-i",
|
||
|
|
type=str,
|
||
|
|
help="file name of input checkpoint file",
|
||
|
|
required=True,
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"-trained_gpu_num",
|
||
|
|
"-t_g",
|
||
|
|
type=int,
|
||
|
|
help="How many gpus for inference",
|
||
|
|
default=1,
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"-infer_gpu_num",
|
||
|
|
"-i_g",
|
||
|
|
type=int,
|
||
|
|
help="How many gpus for inference",
|
||
|
|
required=True,
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"-processes",
|
||
|
|
"-p",
|
||
|
|
type=int,
|
||
|
|
help="How many processes to spawn for conversion (default: 4)",
|
||
|
|
default=4,
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]
|
||
|
|
)
|
||
|
|
parser.add_argument(
|
||
|
|
"-model_name", "-m_n", type=str, help="model name", required=True
|
||
|
|
)
|
||
|
|
|
||
|
|
args = parser.parse_args()
|
||
|
|
print("\n=============== Argument ===============")
|
||
|
|
for key in vars(args):
|
||
|
|
print("{}: {}".format(key, vars(args)[key]))
|
||
|
|
print("========================================")
|
||
|
|
|
||
|
|
shutil.rmtree(args.saved_dir, ignore_errors=True)
|
||
|
|
split_and_convert(args)
|