update huggingface_gptneox_convert.py (#112)

* update huggingface_gptneox_convert.py

* fix format

* fix pre-commit
add-tracing
ADLIBS 2023-04-23 21:36:46 +08:00 committed by GitHub
parent 300fe8c2b0
commit 1b96c18ab8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 64 additions and 63 deletions

View File

@ -2,7 +2,6 @@ import argparse
import configparser
import multiprocessing
import os
import shutil
from pathlib import Path
import numpy as np
@ -24,10 +23,8 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
if (
key.find("input_layernorm.weight") != -1
or key.find("input_layernorm.bias") != -1
or key.find("attention.dense.bias") != -1
or key.find("post_attention_layernorm.weight") != -1
or key.find("post_attention_layernorm.bias") != -1
or key.find("mlp.dense_4h_to_h.bias") != -1
or key.find("final_layernorm.weight") != -1
or key.find("final_layernorm.bias") != -1
):
@ -35,23 +32,25 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
val.tofile(saved_path)
elif (
key.find("attention.dense.bias") != -1
or key.find("mlp.dense_4h_to_h.bias") != -1
):
saved_path = saved_dir + f"/model.{key}.bin"
val = (val / factor) if factor > 1 else val
val.tofile(saved_path)
else:
if (
key.find("attention.dense.weight") != -1
or key.find("mlp.dense_4h_to_h.weight") != -1
):
split_vals = np.split(val, factor, axis=0)
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
elif (
key.find("mlp.dense_h_to_4h.weight") != -1
or key.find("mlp.dense_h_to_4h.bias") != -1
):
split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
elif key.find("attention.query_key_value.bias") != -1:
local_dim = (int)(val.shape[-1] / 3)
@ -61,10 +60,6 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
val = np.transpose(val, [1, 0, 2]).reshape(3, local_dim)
split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
elif key.find("attention.query_key_value.weight") != -1:
hidden_dim = val.shape[0]
local_dim = (int)(val.shape[-1] / 3)
@ -77,13 +72,14 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
# print(np.mean(np.abs(val[:, 0, :])))
split_vals = np.split(val, factor, axis=-1)
else:
print("[ERROR] cannot find key '{}'".format(key))
return
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
else:
print("[ERROR] cannot find key '{}'".format(key))
def split_and_convert(args):
saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num
@ -91,11 +87,7 @@ def split_and_convert(args):
if os.path.exists(saved_dir) == False:
os.makedirs(saved_dir)
t_gpu_num = args.trained_gpu_num
i_gpu_num = args.infer_gpu_num
assert i_gpu_num % t_gpu_num == 0
factor = (int)(i_gpu_num / t_gpu_num)
factor = args.infer_gpu_num
# load position_embedding from rank 0
# model = torch.load(ckpt_name)
@ -145,8 +137,20 @@ def split_and_convert(args):
"mlp.dense_4h_to_h.weight",
]
huggingface_model_file_list = [
hf_file_name
for hf_file_name in os.listdir(args.in_file)
if hf_file_name.endswith(".bin")
]
if len(huggingface_model_file_list) > 1:
multiprocessing_context = multiprocessing.get_context()
pool_fn = multiprocessing_context.Pool
else:
torch.multiprocessing.set_start_method("spawn")
pool = multiprocessing.Pool(args.processes)
pool_fn = multiprocessing.Pool
pool = pool_fn(args.processes)
for name, param in model.named_parameters():
array = param.detach().cpu().numpy().astype(np_weight_data_type)
# print("input shape", name, array.shape)
@ -217,13 +221,6 @@ if __name__ == "__main__":
help="file name of input checkpoint file",
required=True,
)
parser.add_argument(
"-trained_gpu_num",
"-t_g",
type=int,
help="How many gpus for inference",
default=1,
)
parser.add_argument(
"-infer_gpu_num",
"-i_g",
@ -251,5 +248,9 @@ if __name__ == "__main__":
print("{}: {}".format(key, vars(args)[key]))
print("========================================")
shutil.rmtree(args.saved_dir, ignore_errors=True)
target_dir_path = os.path.join(args.saved_dir, "%d-gpu" % args.infer_gpu_num)
assert not os.path.exists(target_dir_path), (
"target path has exist, please remove %s first." % target_dir_path
)
split_and_convert(args)