tabby/scripts/huggingface_gptneox_convert.py

258 lines
8.9 KiB
Python
Raw Normal View History

2023-03-18 14:58:53 +00:00
import argparse
import configparser
import multiprocessing
import os
2023-03-20 08:51:28 +00:00
import shutil
2023-03-18 14:58:53 +00:00
import sys
from pathlib import Path
import numpy as np
import torch
2023-03-20 08:51:28 +00:00
from transformers import GPTNeoXForCausalLM
2023-03-18 14:58:53 +00:00
def get_weight_data_type(data_type):
    """Map a weight precision name to the matching NumPy scalar type.

    Args:
        data_type: Precision name, either ``"fp32"`` or ``"fp16"``.

    Returns:
        ``np.float32`` for ``"fp32"`` and ``np.float16`` for ``"fp16"``.

    Raises:
        AssertionError: If ``data_type`` is any other value.
    """
    if data_type == "fp32":
        return np.float32
    if data_type == "fp16":
        return np.float16
    # Raised explicitly instead of via ``assert False`` so the check is not
    # stripped (silently returning None) when Python runs with -O. The
    # exception type is kept so existing callers behave the same.
    raise AssertionError(f"Invalid weight data type {data_type}")
2023-03-20 08:51:28 +00:00
# Tensors written whole (one file, no per-rank split).
_UNSPLIT_KEY_PATTERNS = (
    "input_layernorm.weight",
    "input_layernorm.bias",
    "attention.dense.bias",
    "post_attention_layernorm.weight",
    "post_attention_layernorm.bias",
    "mlp.dense_4h_to_h.bias",
    "final_layernorm.weight",
    "final_layernorm.bias",
)


def _write_shards(saved_dir, key, val, factor, axis):
    """Split ``val`` into ``factor`` equal parts along ``axis``; write part j to ``model.{key}.{j}.bin``."""
    for rank, shard in enumerate(np.split(val, factor, axis=axis)):
        shard.tofile(saved_dir + f"/model.{key}.{rank}.bin")


def split_and_convert_process(saved_dir, factor, key, args, config, val):
    """Write one tensor to ``saved_dir``, split into ``factor`` files where needed.

    The substring found in ``key`` selects the layout:

    * layernorm weights/biases, ``attention.dense.bias`` and
      ``mlp.dense_4h_to_h.bias`` are written whole to ``model.{key}.bin``;
    * ``attention.dense.weight`` and ``mlp.dense_4h_to_h.weight`` are split
      along axis 0;
    * ``mlp.dense_h_to_4h`` weight and bias are split along the last axis;
    * ``attention.query_key_value`` weight and bias are first regrouped from a
      per-head ``(head, 3, head_dim)`` ordering to ``(3, head, head_dim)`` and
      then split along the last axis.

    Split tensors are written to ``model.{key}.{j}.bin`` for ``j`` in
    ``range(factor)``. A key matching none of the patterns only prints an
    error message; nothing is written.

    Args:
        saved_dir: Output directory path (string); file names are appended to it.
        factor: Number of equal pieces to split into.
        key: Tensor name used both for pattern matching and in the file name.
        args: Unused here; kept so the call signature stays stable.
        config: Mapping providing ``"num_attention_heads"``.
        val: NumPy array holding the tensor data.

    Raises:
        ValueError: From ``np.split``/``reshape`` when the tensor cannot be
            divided evenly.
    """
    if any(pattern in key for pattern in _UNSPLIT_KEY_PATTERNS):
        val.tofile(saved_dir + f"/model.{key}.bin")
    elif "attention.dense.weight" in key or "mlp.dense_4h_to_h.weight" in key:
        _write_shards(saved_dir, key, val, factor, axis=0)
    elif "mlp.dense_h_to_4h.weight" in key or "mlp.dense_h_to_4h.bias" in key:
        _write_shards(saved_dir, key, val, factor, axis=-1)
    elif "attention.query_key_value.bias" in key:
        local_dim = val.shape[-1] // 3
        n_head = config["num_attention_heads"]
        # Regroup (head, 3, head_dim) -> (3, head * head_dim).
        val = val.reshape(n_head, 3, local_dim // n_head)
        val = np.transpose(val, [1, 0, 2]).reshape(3, local_dim)
        _write_shards(saved_dir, key, val, factor, axis=-1)
    elif "attention.query_key_value.weight" in key:
        hidden_dim = val.shape[0]
        local_dim = val.shape[-1] // 3
        n_head = config["num_attention_heads"]
        # Note that the HF qkv weight are stored as
        # [hidden_size, num_heads, 3, head_hidden];
        # FT needs the shape of [hidden_size, 3, num_heads, head_hidden].
        val = val.reshape(hidden_dim, n_head, 3, local_dim // n_head)
        val = np.transpose(val, [0, 2, 1, 3]).reshape(hidden_dim, 3, local_dim)
        _write_shards(saved_dir, key, val, factor, axis=-1)
    else:
        print("[ERROR] cannot find key '{}'".format(key))
def split_and_convert(args):
    """Convert a Hugging Face GPT-NeoX checkpoint into per-tensor ``.bin`` files.

    Steps, in order:

    1. Create ``{args.saved_dir}/{args.infer_gpu_num}-gpu/``.
    2. Load the model with ``GPTNeoXForCausalLM.from_pretrained(args.in_file)``.
    3. Best-effort write of a ``config.ini`` with a ``[gptneox]`` section; a
       failure here is printed and conversion continues.
    4. Write every parameter: embeddings, final layernorm and LM head directly,
       per-layer tensors through ``split_and_convert_process``.
    5. When the model config enables ``use_parallel_residual``, write the sum
       of each layer's attention-dense and MLP output biases to
       ``model.layers.{i}.mlp.attention.bias.sum.bin``.

    Args:
        args: Parsed CLI namespace providing ``saved_dir``, ``in_file``,
            ``trained_gpu_num``, ``infer_gpu_num``, ``weight_data_type``,
            ``processes`` and ``model_name``.

    Raises:
        AssertionError: If ``infer_gpu_num`` is not a multiple of
            ``trained_gpu_num``.
        KeyError: If the model config lacks ``num_attention_heads``,
            ``hidden_size``, ``rotary_pct`` or ``use_parallel_residual``.
    """
    saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num
    os.makedirs(saved_dir, exist_ok=True)

    t_gpu_num = args.trained_gpu_num
    i_gpu_num = args.infer_gpu_num
    # Explicit raise (same exception type as the former bare assert) so the
    # check still runs under ``python -O``.
    if i_gpu_num % t_gpu_num != 0:
        raise AssertionError(
            f"infer_gpu_num ({i_gpu_num}) must be a multiple of "
            f"trained_gpu_num ({t_gpu_num})"
        )
    factor = i_gpu_num // t_gpu_num

    model = GPTNeoXForCausalLM.from_pretrained(args.in_file)
    hf_config = vars(model.config)
    np_weight_data_type = get_weight_data_type(args.weight_data_type)

    # Derived outside the best-effort ``try`` below: ``use_gptj_residual`` is
    # needed after conversion, and leaving it inside meant any config.ini
    # failure surfaced later as a NameError once all tensors were written.
    n_heads = hf_config["num_attention_heads"]
    head_size = hf_config["hidden_size"] // n_heads
    rotary_dim = int(head_size * hf_config["rotary_pct"])
    use_gptj_residual = int(hf_config["use_parallel_residual"])

    try:
        config = configparser.ConfigParser()
        config["gptneox"] = {
            "model_name": args.model_name,
            "head_num": str(n_heads),
            "size_per_head": str(head_size),
            "inter_size": str(hf_config["intermediate_size"]),
            "num_layer": str(hf_config["num_hidden_layers"]),
            "rotary_embedding": str(rotary_dim),
            "vocab_size": str(hf_config["vocab_size"]),
            "start_id": str(hf_config["bos_token_id"]),
            "end_id": str(hf_config["eos_token_id"]),
            "use_gptj_residual": str(use_gptj_residual),
            "weight_data_type": args.weight_data_type,
        }
        with open((Path(saved_dir) / "config.ini").as_posix(), "w") as configfile:
            config.write(configfile)
    except Exception as e:
        # Deliberately best-effort: the weights are still converted.
        print("Fail to save the config in config.ini.", e)

    # Parameters written as-is, keyed by their full Hugging Face name.
    direct_outputs = {
        "gpt_neox.embed_in.weight": "model.wte.bin",
        "gpt_neox.final_layer_norm.bias": "model.final_layernorm.bias.bin",
        "gpt_neox.final_layer_norm.weight": "model.final_layernorm.weight.bin",
        "embed_out.weight": "model.lm_head.weight.bin",
    }
    # Per-layer tensors handled by split_and_convert_process.
    ft_model_name_pattern = [
        "input_layernorm.bias",
        "input_layernorm.weight",
        "attention.query_key_value.bias",
        "attention.query_key_value.weight",
        "attention.dense.bias",
        "attention.dense.weight",
        "post_attention_layernorm.bias",
        "post_attention_layernorm.weight",
        "mlp.dense_h_to_4h.bias",
        "mlp.dense_h_to_4h.weight",
        "mlp.dense_4h_to_h.bias",
        "mlp.dense_4h_to_h.weight",
    ]

    torch.multiprocessing.set_start_method("spawn")
    pool = multiprocessing.Pool(args.processes)
    try:
        for name, param in model.named_parameters():
            array = param.detach().cpu().numpy().astype(np_weight_data_type)

            if "weight" not in name and "bias" not in name:
                print("skipped", name)
                continue

            direct_file = direct_outputs.get(name)
            if direct_file is not None:
                array.tofile(saved_dir + direct_file)
            elif any(pattern in name for pattern in ft_model_name_pattern):
                new_name = name.replace("gpt_neox.", "")
                # The worker receives the transposed array. starmap blocks
                # until this single task has finished.
                pool.starmap(
                    split_and_convert_process,
                    [(saved_dir, factor, new_name, args, hf_config, array.T)],
                )
            else:
                print("Unused layer", name)
    finally:
        # Always release the worker processes, even if conversion failed.
        pool.close()
        pool.join()

    # Post-process biases if use_gptj_residual is True
    if use_gptj_residual:
        for layer_idx in range(hf_config["num_hidden_layers"]):
            layer_prefix = saved_dir + f"/model.layers.{layer_idx}."
            attn_bias = np.fromfile(
                layer_prefix + "attention.dense.bias.bin",
                dtype=np_weight_data_type,
            )
            mlp_bias = np.fromfile(
                layer_prefix + "mlp.dense_4h_to_h.bias.bin",
                dtype=np_weight_data_type,
            )
            (attn_bias + mlp_bias).astype(np_weight_data_type).tofile(
                layer_prefix + "mlp.attention.bias.sum.bin"
            )
if __name__ == "__main__":
    # Command-line entry point: parse arguments, echo them, clear the output
    # directory, then run the conversion.
    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument(
        "-saved_dir", "-o", type=str, help="file name of output file", required=True
    )
    parser.add_argument(
        "-in_file",
        "-i",
        type=str,
        help="file name of input checkpoint file",
        required=True,
    )
    parser.add_argument(
        "-trained_gpu_num",
        "-t_g",
        type=int,
        # Previously a copy of the inference help text.
        help="How many gpus were used for training",
        default=1,
    )
    parser.add_argument(
        "-infer_gpu_num",
        "-i_g",
        type=int,
        help="How many gpus for inference",
        required=True,
    )
    parser.add_argument(
        "-processes",
        "-p",
        type=int,
        help="How many processes to spawn for conversion (default: 4)",
        default=4,
    )
    parser.add_argument(
        "-weight_data_type", type=str, default="fp32", choices=["fp32", "fp16"]
    )
    parser.add_argument(
        "-model_name", "-m_n", type=str, help="model name", required=True
    )

    args = parser.parse_args()
    print("\n=============== Argument ===============")
    for key, value in vars(args).items():
        print(f"{key}: {value}")
    print("========================================")

    # NOTE: this deletes the entire output root, not just the "<N>-gpu"
    # sub-directory that split_and_convert writes into.
    shutil.rmtree(args.saved_dir, ignore_errors=True)
    split_and_convert(args)