129 lines
4.1 KiB
Python
129 lines
4.1 KiB
Python
import json
|
|
import random
|
|
import string
|
|
import time
|
|
from typing import List
|
|
|
|
import numpy as np
|
|
import tritonclient.grpc as client_util
|
|
from models import Choice, CompletionsRequest, CompletionsResponse
|
|
from transformers import AutoTokenizer
|
|
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
|
|
|
|
|
class TritonService:
    """Serve completion requests through a Triton inference server.

    Prompts are tokenized locally with a Hugging Face tokenizer, sent to Triton
    over gRPC, and the returned token ids are decoded back into text.
    """

    def __init__(
        self,
        tokenizer_name,
        host: str = "localhost",
        port: int = 8001,
        verbose: bool = False,
        *,
        model_name: str = "fastertransformer",
        max_tokens: int = 128,
        stop_words: tuple = ("\n\n",),
    ):
        """Load the tokenizer and create the Triton gRPC client.

        Args:
            tokenizer_name: Name or path passed to ``AutoTokenizer.from_pretrained``.
            host: Triton server host name.
            port: Triton gRPC port.
            verbose: Forwarded to the Triton client's ``verbose`` flag.
            model_name: Name of the Triton model passed to ``infer``.
            max_tokens: Value sent as ``request_output_len`` for every row.
            stop_words: Strings sent to the model as ``stop_words_list``; one
                matching suffix is also trimmed from each decoded completion.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.client = client_util.InferenceServerClient(
            url=f"{host}:{port}", verbose=verbose
        )
        self.model_name = model_name
        self.max_tokens = max_tokens
        # Private copy so later changes to a caller-owned sequence have no effect.
        self.stop_words = list(stop_words)

    def generate(self, data: CompletionsRequest) -> List[Choice]:
        """Run ``data.prompt`` through the model and return the decoded choices.

        Returns:
            One ``Choice`` per batch row, with stop-word suffixes trimmed.

        Raises:
            RuntimeError: If the response lacks ``output_ids`` or
                ``sequence_length``.
        """
        # Number of copies of the prompt sent as one batch (one Choice per row).
        n = 1
        np_type = np.uint32

        # Encoded prompt repeated n times -> 2-D array (n, prompt_len).
        input_start_ids = np.expand_dims(self.tokenizer.encode(data.prompt), 0)
        input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
        batch_size, prompt_len = input_start_ids.shape

        # Per-row prompt length and requested output length, both (n, 1).
        input_len = np.full((batch_size, 1), prompt_len, dtype=np_type)
        output_len = np.full((batch_size, 1), self.max_tokens, dtype=np_type)

        # The same encoded stop-word list is repeated for every batch row.
        stop_word_list = np.repeat(
            to_word_list_format([self.stop_words], self.tokenizer),
            batch_size,
            axis=0,
        )

        inputs = [
            prepare_tensor("input_ids", input_start_ids),
            prepare_tensor("input_lengths", input_len),
            prepare_tensor("request_output_len", output_len),
            prepare_tensor("stop_words_list", stop_word_list),
        ]
        result = self.client.infer(self.model_name, inputs)

        output_data = result.as_numpy("output_ids")
        if output_data is None:
            raise RuntimeError("No output data")
        sequence_lengths = result.as_numpy("sequence_length")
        if sequence_lengths is None:
            raise RuntimeError("No sequence length data")

        # Drop axis 1 of both outputs (presumably a beam dimension of size 1 —
        # confirm against the model configuration).
        output_data = output_data.squeeze(1)
        sequence_lengths = sequence_lengths.squeeze(1)
        # Tokens generated after the prompt; signed so that a length shorter
        # than the prompt yields an empty slice instead of an unsigned wrap.
        gen_len = sequence_lengths.astype(np.int64) - prompt_len

        decoded = [
            self.tokenizer.decode(out[prompt_len : prompt_len + g])
            for g, out in zip(gen_len, output_data)
        ]
        return [
            Choice(index=i, text=trim_with_stopwords(text, self.stop_words))
            for i, text in enumerate(decoded)
        ]

    def __call__(self, data: CompletionsRequest) -> CompletionsResponse:
        """Generate completions for *data* and wrap them in a response object.

        The response gets a fresh random id and the current Unix time in seconds.
        """
        choices = self.generate(data)
        return CompletionsResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )
|
|
|
|
|
|
def prepare_tensor(name: str, tensor_input):
    """Wrap a NumPy array as a named Triton gRPC ``InferInput``.

    The Triton datatype string is derived from the array's dtype, the input is
    declared with the array's shape, and the array's data is attached to it.
    """
    triton_dtype = np_to_triton_dtype(tensor_input.dtype)
    infer_input = client_util.InferInput(name, tensor_input.shape, triton_dtype)
    infer_input.set_data_from_numpy(tensor_input)
    return infer_input
|
|
|
|
|
|
def random_completion_id():
    """Return a pseudo-random id: ``"cmpl-"`` followed by 29 ASCII alphanumerics."""
    alphabet = string.ascii_letters + string.digits
    suffix = "".join([random.choice(alphabet) for _ in range(29)])
    return f"cmpl-{suffix}"
|
|
|
|
|
|
def trim_with_stopwords(output: str, stopwords: list) -> str:
    """Remove at most one trailing stop word from *output*.

    Stop words are tried longest first, so when several match only the longest
    is stripped. Empty stop words are ignored.

    Args:
        output: Decoded completion text.
        stopwords: Candidate suffixes to strip.

    Returns:
        *output* without the matching suffix, or unchanged when none matches.
    """
    for word in sorted(stopwords, key=len, reverse=True):
        # Skip "": it always matches endswith(), and output[:-0] == output[:0]
        # would discard the whole completion.
        if word and output.endswith(word):
            return output[: -len(word)]
    return output
|
|
|
|
|
|
def to_word_list_format(word_dict, tokenizer):
    """Encode per-row word lists into a padded (ids, offsets) int32 array.

    For each item of *word_dict* (an iterable of strings), every word is
    encoded with ``tokenizer.encode``; words that encode to no tokens are
    skipped. The item's token ids are concatenated, and the cumulative token
    counts give the end offset of each word.

    NOTE(review): the layout presumably matches what the model's
    ``stop_words_list`` input expects — confirm against the model docs.

    Args:
        word_dict: Iterable of rows, each an iterable of words.
        tokenizer: Object with an ``encode(str)`` method returning token ids.

    Returns:
        int32 array of shape ``(len(word_dict), 2, pad_to)``, where ``pad_to``
        is the longest row's token count (at least 1). ``[:, 0]`` holds the
        token ids zero-padded; ``[:, 1]`` holds the end offsets padded with -1.
        An empty *word_dict* yields shape ``(0, 2, 1)``.
    """
    flat_ids = []
    offsets = []
    for words in word_dict:
        item_ids = []
        item_lens = []
        for word in words:
            ids = tokenizer.encode(word)
            # len() rather than truthiness: encode may return a NumPy array.
            if len(ids) == 0:
                continue
            item_ids.extend(ids)
            item_lens.append(len(ids))
        flat_ids.append(item_ids)
        offsets.append(np.cumsum(item_lens))

    if not flat_ids:
        # max() over no rows would raise; return an empty, well-shaped batch.
        return np.zeros((0, 2, 1), dtype=np.int32)

    pad_to = max(1, max(len(ids) for ids in flat_ids))
    result = np.zeros((len(flat_ids), 2, pad_to), dtype=np.int32)
    result[:, 1, :] = -1  # offset padding value
    for row, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        result[row, 0, : len(ids)] = ids
        result[row, 1, : len(offs)] = offs
    return result
|