feat: support load_in_8bit in python backend

add-more-languages
Meng Zhang 2023-04-02 11:54:16 +08:00
parent 82103e7280
commit 0a30165862
1 changed file with 5 additions and 4 deletions

View File

@ -15,14 +15,14 @@ from .utils import random_completion_id, trim_with_stop_words
class PythonModelService:
def __init__(
self,
model_name,
):
def __init__(self, model_name, quantize=False):
if torch.cuda.is_available():
device = torch.device("cuda")
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
else:
if quantize:
raise ValueError("quantization on CPU is not implemented")
device = torch.device("cpu")
dtype = torch.float32
@ -36,6 +36,7 @@ class PythonModelService:
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
local_files_only=True,
load_in_8bit=quantize,
)
.to(device)
.eval()