Integrate triton service with server
parent
0a5de72191
commit
0d8d7097be
|
|
@ -1,7 +1,10 @@
|
||||||
|
import os
|
||||||
|
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI, Response
|
from fastapi import FastAPI, Response
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from models import CompletionsRequest, CompletionsResponse
|
from models import CompletionsRequest, CompletionsResponse
|
||||||
|
from triton import TritonService
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="TabbyServer",
|
title="TabbyServer",
|
||||||
|
|
@ -9,15 +12,12 @@ app = FastAPI(
|
||||||
docs_url="/",
|
docs_url="/",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
triton = TritonService(os.environ["TOKENIZER_NAME"])
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions")
|
@app.post("/v1/completions")
|
||||||
async def completions(data: CompletionsRequest) -> CompletionsResponse:
|
async def completions(data: CompletionsRequest) -> CompletionsResponse:
|
||||||
return CompletionsResponse()
|
return triton(data)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/completions/{id}/choices/{index}/select")
|
|
||||||
async def select(id: str, index: int):
|
|
||||||
return JSONResponse(content=dict(status="ok"))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -11,14 +11,12 @@ class Choice(BaseModel):
|
||||||
|
|
||||||
class CompletionsRequest(BaseModel):
|
class CompletionsRequest(BaseModel):
|
||||||
prompt: str = Field(
|
prompt: str = Field(
|
||||||
example="def fib(n):",
|
example="def binarySearch(arr, left, right, x):\n mid = (left +",
|
||||||
description="The context to generate completions for, encoded as a string.",
|
description="The context to generate completions for, encoded as a string.",
|
||||||
)
|
)
|
||||||
suffix: Optional[str] = Field(
|
|
||||||
description="The suffix that comes after a completion of inserted code."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CompletionsResponse(BaseModel):
|
class CompletionsResponse(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
|
created: int
|
||||||
choices: List[Choice]
|
choices: List[Choice]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,127 @@
|
||||||
|
import json
|
||||||
|
import random
|
||||||
|
import string
|
||||||
|
import time
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import tritonclient.grpc as client_util
|
||||||
|
from models import Choice, CompletionsRequest, CompletionsResponse
|
||||||
|
from transformers import AutoTokenizer
|
||||||
|
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
||||||
|
|
||||||
|
|
||||||
|
class TritonService:
    """Serve code completions by forwarding prompts to a Triton inference server.

    The service owns a tokenizer (used to encode prompts and decode generated
    token ids) and a gRPC Triton client. Instances are callable: calling one
    with a ``CompletionsRequest`` returns a ``CompletionsResponse``.
    """

    def __init__(
        self,
        tokenizer_name,
        host: str = "localhost",
        port: int = 8001,
        verbose: bool = False,
    ):
        """Load the tokenizer and create the Triton gRPC client.

        Args:
            tokenizer_name: Identifier passed to ``AutoTokenizer.from_pretrained``.
            host: Host name of the Triton server.
            port: Port of the Triton server endpoint.
            verbose: Forwarded to the Triton client constructor.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.client = client_util.InferenceServerClient(
            url=f"{host}:{port}", verbose=verbose
        )

    def generate(self, data: CompletionsRequest) -> List[Choice]:
        """Run one inference request for ``data.prompt`` and decode the result.

        Args:
            data: The completion request; only ``data.prompt`` is read.

        Returns:
            One ``Choice`` per generated sequence, indexed from 0, with a
            trailing stop word removed from the text.

        Raises:
            RuntimeError: If the inference result lacks the ``output_ids`` or
                ``sequence_length`` output.
        """
        # Generation settings are fixed for now.
        n = 1
        np_type = np.uint32
        max_tokens = 128
        model_name = "fastertransformer"

        # Encode the prompt as a (n, prompt_len) batch of token ids.
        prompt = data.prompt
        input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
        input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
        prompt_len = input_start_ids.shape[1]

        # Per-row prompt length and requested output length, both (n, 1).
        input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np_type)
        output_len = np.ones_like(input_len).astype(np_type) * max_tokens

        # The same stop-word table is repeated for every row of the batch.
        stop_words = ["\n\n"]
        stop_word_list = np.repeat(
            to_word_list_format([stop_words], self.tokenizer),
            input_start_ids.shape[0],
            axis=0,
        )

        inputs = [
            prepare_tensor("input_ids", input_start_ids),
            prepare_tensor("input_lengths", input_len),
            prepare_tensor("request_output_len", output_len),
            prepare_tensor("stop_words_list", stop_word_list),
        ]

        result = self.client.infer(model_name, inputs)

        output_data = result.as_numpy("output_ids")
        if output_data is None:
            raise RuntimeError("No output data")

        # Guard the second output the same way as the first instead of
        # failing with an AttributeError on ``None.squeeze``.
        sequence_lengths = result.as_numpy("sequence_length")
        if sequence_lengths is None:
            raise RuntimeError("No sequence length data")

        # NOTE(review): axis 1 is presumably the beam dimension - confirm
        # against the model configuration.
        output_data = output_data.squeeze(1)
        sequence_lengths = sequence_lengths.squeeze(1)

        # Number of newly generated tokens per row (total minus prompt).
        gen_len = sequence_lengths - input_len.squeeze(1)

        # Decode only the generated tail of each row, skipping the prompt.
        decoded = [
            self.tokenizer.decode(out[prompt_len : prompt_len + g])
            for g, out in zip(gen_len, output_data)
        ]
        trimmed = [trim_with_stopwords(d, stop_words) for d in decoded]
        return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]

    def __call__(self, data: CompletionsRequest) -> CompletionsResponse:
        """Generate completions for ``data`` and wrap them in a response.

        The response gets a fresh random id and the current Unix time (whole
        seconds) as its creation timestamp.
        """
        choices = self.generate(data)
        return CompletionsResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_tensor(name: str, tensor_input):
    """Wrap a NumPy array in a Triton ``InferInput`` named ``name``.

    The input is declared with the array's own shape and the Triton dtype
    derived from the array's NumPy dtype, then filled with the array's data.
    """
    triton_dtype = np_to_triton_dtype(tensor_input.dtype)
    infer_input = client_util.InferInput(name, tensor_input.shape, triton_dtype)
    infer_input.set_data_from_numpy(tensor_input)
    return infer_input
|
||||||
|
|
||||||
|
|
||||||
|
def random_completion_id():
    """Return a completion id: ``cmpl-`` followed by 29 random alphanumerics.

    Characters are drawn one at a time from ASCII letters and digits using the
    module-level ``random`` generator.
    """
    alphabet = string.ascii_letters + string.digits
    suffix = [random.choice(alphabet) for _ in range(29)]
    return "cmpl-" + "".join(suffix)
|
||||||
|
|
||||||
|
|
||||||
|
def trim_with_stopwords(output: str, stopwords: list) -> str:
    """Remove one trailing stop word from ``output``, if present.

    Stop words are tried longest first, so when several of them match the end
    of ``output`` the longest one is removed. At most one stop word is
    stripped.

    Args:
        output: The generated text.
        stopwords: Candidate stop words.

    Returns:
        ``output`` without the matching trailing stop word, or ``output``
        unchanged when none matches.
    """
    for w in sorted(stopwords, key=len, reverse=True):
        # An empty stop word "matches" every string, and ``output[:-0]`` is
        # ``output[:0]`` (the empty string), which would erase the whole
        # completion. Skip empty entries instead.
        if w and output.endswith(w):
            return output[: -len(w)]
    return output
|
||||||
|
|
||||||
|
|
||||||
|
def to_word_list_format(word_dict, tokenizer):
    """Encode groups of words into a padded ``int32`` array of ids and offsets.

    For every group in ``word_dict`` the words are tokenized with
    ``tokenizer.encode``; words that encode to no tokens are dropped. The
    token ids of a group are concatenated into one row, and a second row holds
    the cumulative end offset of each word within that concatenation. Both
    rows are right-padded to a common width (at least 1): ids with ``0``,
    offsets with ``-1``.

    Returns:
        An ``int32`` array of shape ``(len(word_dict), 2, width)`` where
        index 0 on the middle axis is the ids row and index 1 the offsets row.
    """
    token_rows = []
    offset_rows = []
    for words in word_dict:
        encoded = [tokenizer.encode(word) for word in words]
        kept = [ids for ids in encoded if len(ids) != 0]
        token_rows.append(np.array([tok for ids in kept for tok in ids]))
        offset_rows.append(np.cumsum(np.array([len(ids) for ids in kept])))

    # Every row is padded to the longest id row; width is never below 1.
    width = max(1, max(len(row) for row in token_rows))

    padded_tokens = [
        np.pad(row, (0, width - len(row)), constant_values=0) for row in token_rows
    ]
    padded_offsets = [
        np.pad(row, (0, width - len(row)), constant_values=-1) for row in offset_rows
    ]

    # Stack as (2, groups, width), then move the group axis to the front.
    stacked = np.array([padded_tokens, padded_offsets], dtype=np.int32)
    return stacked.transpose((1, 0, 2))
|
||||||
Loading…
Reference in New Issue