feat: support load_in_8bits in python backend
parent
82103e7280
commit
0a30165862
|
|
@ -15,14 +15,14 @@ from .utils import random_completion_id, trim_with_stop_words
|
||||||
|
|
||||||
|
|
||||||
class PythonModelService:
|
class PythonModelService:
|
||||||
def __init__(
|
def __init__(self, model_name, quantize=False):
|
||||||
self,
|
|
||||||
model_name,
|
|
||||||
):
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda")
|
device = torch.device("cuda")
|
||||||
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
|
||||||
else:
|
else:
|
||||||
|
if quantize:
|
||||||
|
raise ValueError("quantization on CPU is not implemented")
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
dtype = torch.float32
|
dtype = torch.float32
|
||||||
|
|
||||||
|
|
@ -36,6 +36,7 @@ class PythonModelService:
|
||||||
torch_dtype=dtype,
|
torch_dtype=dtype,
|
||||||
device_map="auto" if torch.cuda.is_available() else None,
|
device_map="auto" if torch.cuda.is_available() else None,
|
||||||
local_files_only=True,
|
local_files_only=True,
|
||||||
|
load_in_8bit=quantize,
|
||||||
)
|
)
|
||||||
.to(device)
|
.to(device)
|
||||||
.eval()
|
.eval()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue