From 0a301658624f64d86f81c6e5815e7f9866b49716 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Sun, 2 Apr 2023 11:54:16 +0800 Subject: [PATCH] feat: support load_in_8bits in python backend --- tabby/server/backend/python.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tabby/server/backend/python.py b/tabby/server/backend/python.py index 40949dd..0af1d69 100644 --- a/tabby/server/backend/python.py +++ b/tabby/server/backend/python.py @@ -15,14 +15,14 @@ from .utils import random_completion_id, trim_with_stop_words class PythonModelService: - def __init__( - self, - model_name, - ): + def __init__(self, model_name, quantize=False): if torch.cuda.is_available(): device = torch.device("cuda") dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 else: + if quantize: + raise ValueError("quantization on CPU is not implemented") + device = torch.device("cpu") dtype = torch.float32 @@ -36,6 +36,7 @@ class PythonModelService: torch_dtype=dtype, device_map="auto" if torch.cuda.is_available() else None, local_files_only=True, + load_in_8bit=quantize, ) .to(device) .eval()