diff --git a/tabby/server/backend/python.py b/tabby/server/backend/python.py index 40949dd..0af1d69 100644 --- a/tabby/server/backend/python.py +++ b/tabby/server/backend/python.py @@ -15,14 +15,14 @@ from .utils import random_completion_id, trim_with_stop_words class PythonModelService: - def __init__( - self, - model_name, - ): + def __init__(self, model_name, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: + if quantize: + raise ValueError("quantization on CPU is not implemented") + device = torch.device("cpu") dtype = torch.float32 @@ -36,6 +36,7 @@ class PythonModelService: torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, local_files_only=True, + load_in_8bit=quantize, ) .to(device) .eval()