update huggingface_gptneox_convert.py (#112)

* update huggingface_gptneox_convert.py

* fix format

* fix pre-commit
add-tracing
ADLIBS 2023-04-23 21:36:46 +08:00 committed by GitHub
parent 300fe8c2b0
commit 1b96c18ab8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 64 additions and 63 deletions

View File

@ -2,7 +2,6 @@ import argparse
import configparser import configparser
import multiprocessing import multiprocessing
import os import os
import shutil
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@ -24,10 +23,8 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
if ( if (
key.find("input_layernorm.weight") != -1 key.find("input_layernorm.weight") != -1
or key.find("input_layernorm.bias") != -1 or key.find("input_layernorm.bias") != -1
or key.find("attention.dense.bias") != -1
or key.find("post_attention_layernorm.weight") != -1 or key.find("post_attention_layernorm.weight") != -1
or key.find("post_attention_layernorm.bias") != -1 or key.find("post_attention_layernorm.bias") != -1
or key.find("mlp.dense_4h_to_h.bias") != -1
or key.find("final_layernorm.weight") != -1 or key.find("final_layernorm.weight") != -1
or key.find("final_layernorm.bias") != -1 or key.find("final_layernorm.bias") != -1
): ):
@ -35,23 +32,25 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
val.tofile(saved_path) val.tofile(saved_path)
elif ( elif (
key.find("attention.dense.bias") != -1
or key.find("mlp.dense_4h_to_h.bias") != -1
):
saved_path = saved_dir + f"/model.{key}.bin"
val = (val / factor) if factor > 1 else val
val.tofile(saved_path)
else:
if (
key.find("attention.dense.weight") != -1 key.find("attention.dense.weight") != -1
or key.find("mlp.dense_4h_to_h.weight") != -1 or key.find("mlp.dense_4h_to_h.weight") != -1
): ):
split_vals = np.split(val, factor, axis=0) split_vals = np.split(val, factor, axis=0)
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
elif ( elif (
key.find("mlp.dense_h_to_4h.weight") != -1 key.find("mlp.dense_h_to_4h.weight") != -1
or key.find("mlp.dense_h_to_4h.bias") != -1 or key.find("mlp.dense_h_to_4h.bias") != -1
): ):
split_vals = np.split(val, factor, axis=-1) split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
elif key.find("attention.query_key_value.bias") != -1: elif key.find("attention.query_key_value.bias") != -1:
local_dim = (int)(val.shape[-1] / 3) local_dim = (int)(val.shape[-1] / 3)
@ -61,10 +60,6 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
val = np.transpose(val, [1, 0, 2]).reshape(3, local_dim) val = np.transpose(val, [1, 0, 2]).reshape(3, local_dim)
split_vals = np.split(val, factor, axis=-1) split_vals = np.split(val, factor, axis=-1)
for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path)
elif key.find("attention.query_key_value.weight") != -1: elif key.find("attention.query_key_value.weight") != -1:
hidden_dim = val.shape[0] hidden_dim = val.shape[0]
local_dim = (int)(val.shape[-1] / 3) local_dim = (int)(val.shape[-1] / 3)
@ -77,13 +72,14 @@ def split_and_convert_process(saved_dir, factor, key, args, config, val):
# print(np.mean(np.abs(val[:, 0, :]))) # print(np.mean(np.abs(val[:, 0, :])))
split_vals = np.split(val, factor, axis=-1) split_vals = np.split(val, factor, axis=-1)
else:
print("[ERROR] cannot find key '{}'".format(key))
return
for j in range(factor): for j in range(factor):
saved_path = saved_dir + f"/model.{key}.{j}.bin" saved_path = saved_dir + f"/model.{key}.{j}.bin"
split_vals[j].tofile(saved_path) split_vals[j].tofile(saved_path)
else:
print("[ERROR] cannot find key '{}'".format(key))
def split_and_convert(args): def split_and_convert(args):
saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num saved_dir = args.saved_dir + "/%d-gpu/" % args.infer_gpu_num
@ -91,11 +87,7 @@ def split_and_convert(args):
if os.path.exists(saved_dir) == False: if os.path.exists(saved_dir) == False:
os.makedirs(saved_dir) os.makedirs(saved_dir)
t_gpu_num = args.trained_gpu_num factor = args.infer_gpu_num
i_gpu_num = args.infer_gpu_num
assert i_gpu_num % t_gpu_num == 0
factor = (int)(i_gpu_num / t_gpu_num)
# load position_embedding from rank 0 # load position_embedding from rank 0
# model = torch.load(ckpt_name) # model = torch.load(ckpt_name)
@ -145,8 +137,20 @@ def split_and_convert(args):
"mlp.dense_4h_to_h.weight", "mlp.dense_4h_to_h.weight",
] ]
huggingface_model_file_list = [
hf_file_name
for hf_file_name in os.listdir(args.in_file)
if hf_file_name.endswith(".bin")
]
if len(huggingface_model_file_list) > 1:
multiprocessing_context = multiprocessing.get_context()
pool_fn = multiprocessing_context.Pool
else:
torch.multiprocessing.set_start_method("spawn") torch.multiprocessing.set_start_method("spawn")
pool = multiprocessing.Pool(args.processes) pool_fn = multiprocessing.Pool
pool = pool_fn(args.processes)
for name, param in model.named_parameters(): for name, param in model.named_parameters():
array = param.detach().cpu().numpy().astype(np_weight_data_type) array = param.detach().cpu().numpy().astype(np_weight_data_type)
# print("input shape", name, array.shape) # print("input shape", name, array.shape)
@ -217,13 +221,6 @@ if __name__ == "__main__":
help="file name of input checkpoint file", help="file name of input checkpoint file",
required=True, required=True,
) )
parser.add_argument(
"-trained_gpu_num",
"-t_g",
type=int,
help="How many gpus for inference",
default=1,
)
parser.add_argument( parser.add_argument(
"-infer_gpu_num", "-infer_gpu_num",
"-i_g", "-i_g",
@ -251,5 +248,9 @@ if __name__ == "__main__":
print("{}: {}".format(key, vars(args)[key])) print("{}: {}".format(key, vars(args)[key]))
print("========================================") print("========================================")
shutil.rmtree(args.saved_dir, ignore_errors=True) target_dir_path = os.path.join(args.saved_dir, "%d-gpu" % args.infer_gpu_num)
assert not os.path.exists(target_dir_path), (
"target path has exist, please remove %s first." % target_dir_path
)
split_and_convert(args) split_and_convert(args)