# Reconstructed from a whitespace-mangled `git format-patch` email:
#   commit  0d8d7097be0e9b12b1b86956d6d5e3956c90f232
#   Author  Meng Zhang
#   Date    Mon, 20 Mar 2023 22:57:29 +0800
#   Subject [PATCH] Integrate triton service with server
#   3 files changed, 135 insertions(+), 10 deletions(-)
# The sections below show the post-patch content of the touched regions,
# one section per file. Lines outside the diff hunks are NOT visible in
# the patch and are marked as such — TODO confirm against the repository.

# --- server/app.py (post-patch) --------------------------------------------
import os

import uvicorn
from fastapi import FastAPI, Response  # NOTE(review): Response appears unused
from fastapi.responses import JSONResponse  # NOTE(review): unused after the
# /select route was removed by this patch; kept because the full file is not
# visible here — verify and drop both if nothing else uses them.
from models import CompletionsRequest, CompletionsResponse

# NOTE(review): the local module name `triton` shadows the third-party
# `triton` compiler package if it is installed — consider renaming the module.
from triton import TritonService

app = FastAPI(
    title="TabbyServer",
    # ... one configuration line between the two hunks is not visible in the
    # patch context — TODO confirm against the repository ...
    docs_url="/",
)

# Fail fast with an actionable message instead of a bare KeyError when the
# required environment variable is missing. Module import still aborts, so
# callers observe the same "server refuses to start" behavior.
_tokenizer_name = os.environ.get("TOKENIZER_NAME")
if _tokenizer_name is None:
    raise RuntimeError("TOKENIZER_NAME environment variable must be set")
triton = TritonService(_tokenizer_name)


@app.post("/v1/completions")
async def completions(data: CompletionsRequest) -> CompletionsResponse:
    """Generate code completions for the request prompt via the Triton service."""
    return triton(data)


if __name__ == "__main__":
    ...  # body not visible in the patch context (presumably uvicorn.run) — TODO confirm


# --- server/models.py (post-patch, visible region) --------------------------
# The @@ hunk context names `class Choice(BaseModel)` above this class;
# `BaseModel`/`Field` imports are outside the visible hunk.
class CompletionsRequest(BaseModel):
    # The optional `suffix` field was deleted by this patch; `prompt` is now
    # the only field, with a more realistic example value.
    prompt: str = Field(
        example="def binarySearch(arr, left, right, x):\n mid = (left +",
        description="The context to generate completions for, encoded as a string.",
    )
# --- server/models.py (continued, post-patch) -------------------------------
class CompletionsResponse(BaseModel):
    id: str  # completion id, e.g. "cmpl-..." (see random_completion_id below)
    created: int  # unix timestamp (seconds), added by this patch
    choices: List[Choice]


# --- server/triton.py (new file introduced by this patch) -------------------
# NOTE(review): `import json` and `InferenceServerException` from the original
# patch were unused in this (fully visible) new file and have been dropped.
import random
import string
import time
from typing import List

import numpy as np
import tritonclient.grpc as client_util
from models import Choice, CompletionsRequest, CompletionsResponse
from transformers import AutoTokenizer
from tritonclient.utils import np_to_triton_dtype


class TritonService:
    """Thin client around a Triton inference server hosting a
    `fastertransformer` code-completion model."""

    def __init__(
        self,
        tokenizer_name,
        host: str = "localhost",
        port: int = 8001,
        verbose: bool = False,
    ):
        # Tokenizer must match the one the served model was trained with.
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.client = client_util.InferenceServerClient(
            url=f"{host}:{port}", verbose=verbose
        )

    def generate(self, data: CompletionsRequest) -> List[Choice]:
        """Run one inference for `data.prompt` and return decoded completions.

        Raises RuntimeError if the server response is missing the expected
        output tensors.
        """
        n = 1  # number of parallel samples per request
        np_type = np.uint32
        max_tokens = 128  # generation budget per request
        model_name = "fastertransformer"

        prompt = data.prompt
        input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
        input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
        prompt_len = input_start_ids.shape[1]
        input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np_type)

        # NOTE(review): the original also computed an unused local
        # `prompt_tokens = input_len[0][0]`; removed.
        output_len = np.ones_like(input_len).astype(np_type) * max_tokens

        # Stop generating when the model emits a blank line.
        stop_words = ["\n\n"]
        stop_word_list = np.repeat(
            to_word_list_format([stop_words], self.tokenizer),
            input_start_ids.shape[0],
            axis=0,
        )

        inputs = [
            prepare_tensor("input_ids", input_start_ids),
            prepare_tensor("input_lengths", input_len),
            prepare_tensor("request_output_len", output_len),
            prepare_tensor("stop_words_list", stop_word_list),
        ]

        result = self.client.infer(model_name, inputs)

        output_data = result.as_numpy("output_ids")
        if output_data is None:
            raise RuntimeError("No output data")

        # Drop the beam axis; beam width is 1 here.
        output_data = output_data.squeeze(1)
        sequence_lengths = result.as_numpy("sequence_length")
        # Guard the same way as output_ids: as_numpy returns None when the
        # tensor is absent, which previously crashed with AttributeError.
        if sequence_lengths is None:
            raise RuntimeError("No sequence length data")
        sequence_lengths = sequence_lengths.squeeze(1)
        gen_len = sequence_lengths - input_len.squeeze(1)

        # output_ids echoes the prompt tokens; decode only the generated tail.
        decoded = [
            self.tokenizer.decode(out[prompt_len : prompt_len + g])
            for g, out in zip(gen_len, output_data)
        ]
        trimmed = [trim_with_stopwords(d, stop_words) for d in decoded]
        return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]

    def __call__(self, data: CompletionsRequest) -> CompletionsResponse:
        """Generate completions and wrap them in an API response envelope."""
        choices = self.generate(data)
        return CompletionsResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )


def prepare_tensor(name: str, tensor_input):
    """Wrap a numpy array as a Triton InferInput with a matching dtype."""
    t = client_util.InferInput(
        name, tensor_input.shape, np_to_triton_dtype(tensor_input.dtype)
    )
    t.set_data_from_numpy(tensor_input)
    return t


def random_completion_id():
    """Return an OpenAI-style completion id: "cmpl-" + 29 alphanumerics."""
    return "cmpl-" + "".join(
        random.choice(string.ascii_letters + string.digits) for _ in range(29)
    )


def trim_with_stopwords(output: str, stopwords: list) -> str:
    """Strip at most one trailing stopword (longest match first) from output."""
    for w in sorted(stopwords, key=len, reverse=True):
        if output.endswith(w):
            output = output[: -len(w)]
            break
    return output


def to_word_list_format(word_dict, tokenizer):
    """Encode stop words into FasterTransformer's stop_words_list tensor.

    Returns an int32 array of shape (len(word_dict), 2, pad_to): row 0 holds
    the concatenated token ids of all stop words (zero-padded), row 1 the
    cumulative end offsets of each word (-1 padded).
    """
    flat_ids = []
    offsets = []
    for word_dict_item in word_dict:
        item_flat_ids = []
        item_offsets = []

        for word in word_dict_item:
            ids = tokenizer.encode(word)

            # Some tokenizers map a stop word to no tokens; skip those.
            if len(ids) == 0:
                continue

            item_flat_ids += ids
            item_offsets.append(len(ids))

        flat_ids.append(np.array(item_flat_ids))
        offsets.append(np.cumsum(np.array(item_offsets)))

    # Pad all rows to a common width so they stack into one tensor.
    pad_to = max(1, max(len(ids) for ids in flat_ids))

    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
        offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)

    return np.array([flat_ids, offsets], dtype=np.int32).transpose((1, 0, 2))