feat: support load_in_8bit in python backend
parent
82103e7280
commit
0a30165862
|
|
@ -15,14 +15,14 @@ from .utils import random_completion_id, trim_with_stop_words
|
|||
|
||||
|
||||
class PythonModelService:
|
||||
def __init__(
|
||||
self,
|
||||
model_name,
|
||||
):
|
||||
def __init__(self, model_name, quantize=False):
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda")
|
||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||
else:
|
||||
if quantize:
|
||||
raise ValueError("quantization on CPU is not implemented")
|
||||
|
||||
device = torch.device("cpu")
|
||||
dtype = torch.float32
|
||||
|
||||
|
|
@ -36,6 +36,7 @@ class PythonModelService:
|
|||
torch_dtype=dtype,
|
||||
device_map="auto" if torch.cuda.is_available() else None,
|
||||
local_files_only=True,
|
||||
load_in_8bit=quantize,
|
||||
)
|
||||
.to(device)
|
||||
.eval()
|
||||
|
|
|
|||
Loading…
Reference in New Issue