tabby/server/triton.py

129 lines
4.1 KiB
Python

import json
import random
import string
import time
from typing import List
import numpy as np
import tritonclient.grpc as client_util
from models import Choice, CompletionsRequest, CompletionsResponse
from transformers import AutoTokenizer
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
class TritonService:
    """Completion backend that forwards prompts to a Triton inference server.

    The prompt is tokenized locally, sent to the configured model through the
    Triton gRPC client, and the returned token ids are decoded back to text.
    """

    def __init__(
        self,
        tokenizer_name,
        host: str = "localhost",
        port: int = 8001,
        verbose: bool = False,
        model_name: str = "fastertransformer",
        max_tokens: int = 128,
        stop_words=("\n\n",),
    ):
        """Load the tokenizer and create the Triton gRPC client.

        Args:
            tokenizer_name: Identifier passed to
                ``AutoTokenizer.from_pretrained``.
            host: Host part of the server URL.
            port: Port part of the server URL.
            verbose: Forwarded to ``InferenceServerClient``.
            model_name: Name of the model passed to ``client.infer``.
            max_tokens: Value sent as ``request_output_len`` for every
                sequence of a request.
            stop_words: Strings sent to the model as ``stop_words_list`` and
                trimmed from the end of each decoded completion. A single
                string is treated as one stop word.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.client = client_util.InferenceServerClient(
            url=f"{host}:{port}", verbose=verbose
        )
        self.model_name = model_name
        self.max_tokens = max_tokens
        # Guard against list("\n\n") silently splitting one stop word into
        # individual characters.
        if isinstance(stop_words, str):
            stop_words = [stop_words]
        self.stop_words = list(stop_words)

    def generate(self, data: CompletionsRequest) -> List[Choice]:
        """Run one inference request for ``data.prompt`` and decode the output.

        Args:
            data: Request whose ``prompt`` attribute is the text to complete.

        Returns:
            One ``Choice`` per generated sequence, containing only the text
            decoded from the tokens that follow the prompt, with a trailing
            stop word removed.

        Raises:
            RuntimeError: If the response has no ``output_ids`` or no
                ``sequence_length`` output.
        """
        # FIXME(meng): Make following vars configurable
        n = 1
        np_type = np.uint32

        token_ids = self.tokenizer.encode(data.prompt)
        # Shape (n, prompt_len): the same prompt repeated once per completion.
        input_start_ids = np.expand_dims(token_ids, 0)
        input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
        batch_size = input_start_ids.shape[0]
        prompt_len = input_start_ids.shape[1]

        # Per-sequence column vectors of shape (batch_size, 1).
        input_len = np.full((batch_size, 1), prompt_len, dtype=np_type)
        output_len = np.full((batch_size, 1), self.max_tokens, dtype=np_type)
        stop_word_list = np.repeat(
            to_word_list_format([self.stop_words], self.tokenizer),
            batch_size,
            axis=0,
        )

        inputs = [
            prepare_tensor("input_ids", input_start_ids),
            prepare_tensor("input_lengths", input_len),
            prepare_tensor("request_output_len", output_len),
            prepare_tensor("stop_words_list", stop_word_list),
        ]
        result = self.client.infer(self.model_name, inputs)

        output_data = result.as_numpy("output_ids")
        if output_data is None:
            raise RuntimeError("No output data")
        sequence_lengths = result.as_numpy("sequence_length")
        if sequence_lengths is None:
            raise RuntimeError("No sequence length data")

        # NOTE(review): axis 1 is presumably the beam dimension (size 1) of
        # both outputs -- confirm against the model configuration.
        output_data = output_data.squeeze(1)
        sequence_lengths = sequence_lengths.squeeze(1)

        # The reported lengths are treated as prompt + generated tokens, so
        # subtracting the prompt length leaves the generated token count.
        gen_len = sequence_lengths - input_len.squeeze(1)
        decoded = [
            self.tokenizer.decode(out[prompt_len : prompt_len + g])
            for g, out in zip(gen_len, output_data)
        ]
        trimmed = [trim_with_stopwords(d, self.stop_words) for d in decoded]
        return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]

    def __call__(self, data: CompletionsRequest) -> CompletionsResponse:
        """Generate completions for ``data`` and wrap them in a response.

        The response gets a fresh random id and the current Unix time (whole
        seconds) as its creation timestamp.
        """
        choices = self.generate(data)
        return CompletionsResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )
def prepare_tensor(name: str, tensor_input):
    """Wrap a NumPy array in a Triton ``InferInput`` named ``name``.

    The input's shape and Triton dtype are taken from ``tensor_input``, and
    its contents are attached to the returned object.
    """
    triton_dtype = np_to_triton_dtype(tensor_input.dtype)
    infer_input = client_util.InferInput(name, tensor_input.shape, triton_dtype)
    infer_input.set_data_from_numpy(tensor_input)
    return infer_input
def random_completion_id():
    """Return a random completion id of the form ``cmpl-<29 characters>``.

    The suffix is drawn from ASCII letters and digits using the ``random``
    module, so the id is an opaque label, not a security token.
    """
    # Build the alphabet once instead of once per generated character.
    alphabet = string.ascii_letters + string.digits
    return "cmpl-" + "".join(random.choice(alphabet) for _ in range(29))
def trim_with_stopwords(output: str, stopwords: list) -> str:
    """Strip at most one trailing stop word from ``output``.

    Stop words are tried from longest to shortest, so when several of them
    match the end of ``output`` the longest one is removed. Stop words that
    occur elsewhere in the text are left alone.

    Args:
        output: Text to trim.
        stopwords: Candidate stop words; empty strings are ignored.

    Returns:
        ``output`` without the matched trailing stop word, or ``output``
        unchanged when none matches.
    """
    for word in sorted(stopwords, key=len, reverse=True):
        # An empty stop word matches every string, and ``output[:-0]`` would
        # discard the whole text, so it must never be treated as a match.
        if word and output.endswith(word):
            return output[: -len(word)]
    return output
def to_word_list_format(word_dict, tokenizer):
    """Encode batches of words into a padded ``int32`` token/offset array.

    Args:
        word_dict: Iterable of word lists, one list per batch item.
        tokenizer: Object whose ``encode(word)`` returns a sequence of
            integer token ids for a word.

    Returns:
        An ``int32`` array of shape ``(len(word_dict), 2, pad_to)`` where
        ``pad_to`` is the largest per-item token count (at least 1). For each
        item, row 0 holds the concatenated token ids of its words padded with
        ``0``, and row 1 holds the cumulative end offset of each word within
        row 0 padded with ``-1``. Words that encode to no tokens are skipped.
        An empty ``word_dict`` yields an array of shape ``(0, 2, 1)``.
    """
    per_item_ids = []
    per_item_offsets = []
    for words in word_dict:
        token_ids = []
        end_offsets = []
        for word in words:
            ids = tokenizer.encode(word)
            if len(ids) == 0:
                continue
            token_ids.extend(ids)
            # Running total of tokens so far == end offset of this word.
            end_offsets.append(len(token_ids))
        per_item_ids.append(token_ids)
        per_item_offsets.append(end_offsets)

    # ``default=0`` keeps an empty batch from raising inside ``max``.
    pad_to = max(1, max((len(ids) for ids in per_item_ids), default=0))

    # Pre-fill with the padding values (0 for ids, -1 for offsets) and then
    # copy each item's data into the leading slots of its rows.
    result = np.zeros((len(per_item_ids), 2, pad_to), dtype=np.int32)
    result[:, 1, :] = -1
    for row, (ids, offs) in enumerate(zip(per_item_ids, per_item_offsets)):
        result[row, 0, : len(ids)] = ids
        result[row, 1, : len(offs)] = offs
    return result