refactor: remove unused python code, move trainer to python/
parent
0f72788d82
commit
2691b302f0
|
|
@ -1,69 +0,0 @@
|
|||
import os
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from . import events as events_lib
|
||||
from .backend import PythonModelService, TritonService
|
||||
from .models import (
|
||||
ChoiceEvent,
|
||||
CompletionEvent,
|
||||
CompletionRequest,
|
||||
CompletionResponse,
|
||||
EventTypeMapping,
|
||||
)
|
||||
|
||||
# FastAPI application serving code completion and usage-event endpoints.
app = FastAPI(
    title="TabbyServer",
    description="""
[](http://github.com/TabbyML/tabby)

TabbyServer is the backend for tabby, serving code completion requests from code editor / IDE.

* [Admin Panel](./_admin)
""",
    docs_url="/",  # serve the interactive API docs at the root path
)

# Allow browser-based editor clients on any origin to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Backend selection is driven entirely by environment variables.
MODEL_NAME = os.environ.get("MODEL_NAME")
MODEL_BACKEND = os.environ.get("MODEL_BACKEND", "python")  # "python" or "triton"

if MODEL_BACKEND == "triton":
    # Remote inference via a Triton server; tokenizer is still loaded locally.
    model_backend = TritonService(
        tokenizer_name=MODEL_NAME,
        host=os.environ.get("TRITON_HOST", "localhost"),
        port=os.environ.get("TRITON_PORT", "8001"),
    )
else:
    # In-process inference (default).
    model_backend = PythonModelService(MODEL_NAME)

# Structured event logging is optional; enabled only when LOGS_DIR is set.
LOGS_DIR = os.environ.get("LOGS_DIR", None)
if LOGS_DIR is not None:
    events_lib.setup_logging(os.path.join(LOGS_DIR, "tabby-server"))
|
||||
|
||||
|
||||
@app.post("/v1/completions")
|
||||
async def completions(request: CompletionRequest) -> CompletionResponse:
|
||||
response = model_backend(request)
|
||||
events_lib.log_completion(request, response)
|
||||
return response
|
||||
|
||||
|
||||
@app.post("/v1/events")
|
||||
async def events(e: ChoiceEvent | CompletionEvent) -> JSONResponse:
|
||||
if isinstance(e, EventTypeMapping[e.type]):
|
||||
events_lib.log_event(e)
|
||||
return JSONResponse(content="ok")
|
||||
else:
|
||||
print(type(e))
|
||||
return JSONResponse(content="invalid event", status_code=422)
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
from .python import PythonModelService
|
||||
from .triton import TritonService
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
from typing import List, Optional, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from ..models import Language
|
||||
|
||||
|
||||
class LanguagePreset(BaseModel):
    # Per-language decoding configuration.
    max_length: int  # generation budget handed to the model backend
    stop_words: List[str]  # strings that terminate a completion once generated
    # Language keywords; optional — pydantic defaults Optional fields to None
    # when a preset omits it (see UNKNOWN/JAVASCRIPT/TYPESCRIPT below).
    reserved_keywords: Optional[Set]
|
||||
|
||||
|
||||
# Decoding presets keyed by Language; languages absent from this table are
# rejected by the backends (generate() returns [] for them).
LanguagePresets = {
    Language.UNKNOWN: LanguagePreset(
        max_length=128,
        stop_words=["\n\n"],
    ),
    Language.PYTHON: LanguagePreset(
        max_length=128,
        # Stop at blank lines or at the start of a new top-level construct.
        stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"],
        reserved_keywords=set(
            [
                "False",
                "class",
                "from",
                "or",
                "None",
                "continue",
                "global",
                "pass",
                "True",
                "def",
                "if",
                "raise",
                "and",
                "del",
                "import",
                "return",
                "as",
                "elif",
                "in",
                "try",
                "assert",
                "else",
                "is",
                "while",
                "async",
                "except",
                "lambda",
                "with",
                "await",
                "finally",
                "nonlocal",
                "yield",
                "break",
                "for",
                "not",
            ]
        ),
    ),
    Language.JAVASCRIPT: LanguagePreset(
        max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"]
    ),
    Language.TYPESCRIPT: LanguagePreset(
        max_length=128,
        stop_words=[
            "\n\n",
            "\nfunction",
            "\n//",
            "\nimport",
            "\nclass",
            "\ninterface",
            "\ntype",
        ],
    ),
}
|
||||
|
|
@ -1,67 +0,0 @@
|
|||
import time
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
StoppingCriteria,
|
||||
StoppingCriteriaList,
|
||||
)
|
||||
|
||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||
from .language_presets import LanguagePresets
|
||||
from .utils import random_completion_id, trim_with_stop_words
|
||||
|
||||
|
||||
class PythonModelService:
    """In-process completion backend built on transformers.

    Loads *model_name* as a causal LM (CUDA with bf16 when supported,
    otherwise CPU fp32) and serves CompletionRequest -> CompletionResponse.
    """

    def __init__(self, model_name, quantize=False):
        if torch.cuda.is_available():
            device = torch.device("cuda")
            # Prefer bfloat16 on GPUs that support it; otherwise stay in fp32.
            dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
        else:
            if quantize:
                raise ValueError("quantization on CPU is not implemented")

            device = torch.device("cpu")
            dtype = torch.float32

        self.device = device
        # local_files_only: weights are expected to be pre-downloaded.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, local_files_only=True
        )
        self.model = (
            AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=dtype,
                device_map="auto" if torch.cuda.is_available() else None,
                local_files_only=True,
                load_in_8bit=quantize,
            )
            .to(device)
            .eval()
        )

    def generate(self, request: CompletionRequest) -> List[Choice]:
        """Return completion choices for *request*; [] if the language has no preset."""
        preset = LanguagePresets.get(request.language, None)
        if preset is None:
            return []

        input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt").to(
            self.device
        )
        # Fix: use max_new_tokens so preset.max_length bounds the *generated*
        # tokens, matching the Triton backend (which sends it as
        # request_output_len). The previous max_length= bounded prompt plus
        # completion, so prompts longer than the preset yielded empty output.
        # no_grad: inference only — skip building autograd state.
        with torch.no_grad():
            res = self.model.generate(
                input_ids,
                max_new_tokens=preset.max_length,
            )
        # Keep only the generated suffix; generate() echoes the prompt tokens.
        output_ids = res[0][len(input_ids[0]) :]
        # Cut the decoded text at the first stop word for the language.
        text = trim_with_stop_words(
            self.tokenizer.decode(output_ids), preset.stop_words
        )
        return [Choice(index=0, text=text)]

    def __call__(self, request: CompletionRequest) -> CompletionResponse:
        """Run generation and wrap the choices into an API response."""
        choices = self.generate(request)
        return CompletionResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )
|
||||
|
|
@ -1,121 +0,0 @@
|
|||
import time
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import tritonclient.grpc as client_util
|
||||
from transformers import AutoTokenizer
|
||||
from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
||||
|
||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||
from .language_presets import LanguagePresets
|
||||
from .utils import random_completion_id, trim_with_stop_words
|
||||
|
||||
|
||||
class TritonService:
    """Completion backend that delegates generation to a Triton server.

    Tokenization happens locally; token ids are sent over gRPC to a
    "fastertransformer" model hosted by Triton.
    """

    def __init__(
        self,
        tokenizer_name,
        host: str = "localhost",
        port: int = 8001,
        verbose: bool = False,
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.client = client_util.InferenceServerClient(
            url=f"{host}:{port}", verbose=verbose
        )

    def generate(self, data: CompletionRequest) -> List[Choice]:
        """Return completion choices for *data*; [] if the language has no preset."""
        n = 1  # batch size: one sample per request
        np_type = np.uint32
        model_name = "fastertransformer"

        preset = LanguagePresets.get(data.language, None)
        if preset is None:
            return []

        prompt = data.prompt

        # Batch of n identical prompt rows, shape (n, prompt_len).
        input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
        input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
        prompt_len = input_start_ids.shape[1]
        input_len = prompt_len * np.ones([input_start_ids.shape[0], 1]).astype(np_type)

        prompt_tokens: int = input_len[0][0]  # NOTE(review): unused local
        # Per-row generation budget; preset.max_length bounds the output tokens.
        output_len = np.ones_like(input_len).astype(np_type) * preset.max_length

        # Stop words encoded in FasterTransformer's tensor format, one copy
        # per batch row.
        stop_word_list = np.repeat(
            to_word_list_format([preset.stop_words], self.tokenizer),
            input_start_ids.shape[0],
            axis=0,
        )

        inputs = [
            prepare_tensor("input_ids", input_start_ids),
            prepare_tensor("input_lengths", input_len),
            prepare_tensor("request_output_len", output_len),
            prepare_tensor("stop_words_list", stop_word_list),
        ]

        result = self.client.infer(model_name, inputs)

        output_data = result.as_numpy("output_ids")
        if output_data is None:
            raise RuntimeError("No output data")

        # squeeze(1) drops axis 1 — presumably the beam dimension of
        # (batch, beam, seq) output; TODO confirm against the Triton model config.
        output_data = output_data.squeeze(1)
        sequence_lengths = result.as_numpy("sequence_length").squeeze(1)
        gen_len = sequence_lengths - input_len.squeeze(1)

        # Decode only each row's generated suffix (the prompt is echoed back).
        decoded = [
            self.tokenizer.decode(out[prompt_len : prompt_len + g])
            for g, out in zip(gen_len, output_data)
        ]
        trimmed = [trim_with_stop_words(d, preset.stop_words) for d in decoded]
        return [Choice(index=i, text=text) for i, text in enumerate(trimmed)]

    def __call__(self, data: CompletionRequest) -> CompletionResponse:
        """Run generation and wrap the choices into an API response."""
        choices = self.generate(data)
        return CompletionResponse(
            id=random_completion_id(), created=int(time.time()), choices=choices
        )
|
||||
|
||||
|
||||
def prepare_tensor(name: str, tensor_input):
    """Wrap a numpy array as a named Triton InferInput carrying its data."""
    triton_dtype = np_to_triton_dtype(tensor_input.dtype)
    infer_input = client_util.InferInput(name, tensor_input.shape, triton_dtype)
    infer_input.set_data_from_numpy(tensor_input)
    return infer_input
|
||||
|
||||
|
||||
def to_word_list_format(word_dict, tokenizer):
    """Encode stop words into FasterTransformer's stop_words_list tensor.

    *word_dict* has one entry per batch item; each entry is a list of
    stop-word strings. Returns an int32 array of shape (batch, 2, pad_to):
    plane 0 holds the concatenated token ids of all stop words, plane 1 the
    cumulative end offsets of each word, padded with -1.
    """
    flat_ids = []
    offsets = []
    for word_dict_item in word_dict:
        item_flat_ids = []
        item_offsets = []

        for word in word_dict_item:
            ids = tokenizer.encode(word)

            # A word that encodes to nothing cannot be matched; skip it.
            if len(ids) == 0:
                continue

            item_flat_ids += ids
            item_offsets.append(len(ids))

            # "\n\n" may also be emitted as two single-newline tokens, so
            # register that alternative encoding as an extra stop sequence.
            if word == "\n\n":
                ids = tokenizer.encode("\n") * 2
                item_flat_ids += ids
                item_offsets.append(len(ids))

        flat_ids.append(np.array(item_flat_ids))
        offsets.append(np.cumsum(np.array(item_offsets)))

    # Pad every row to the longest id list (at least 1 so np.pad is valid
    # even when all rows are empty).
    pad_to = max(1, max(len(ids) for ids in flat_ids))

    for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
        flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
        offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)

    # (2, batch, pad_to) -> (batch, 2, pad_to)
    return np.array([flat_ids, offsets], dtype=np.int32).transpose((1, 0, 2))
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
import random
|
||||
import string
|
||||
|
||||
|
||||
def random_completion_id():
    """Return a fresh completion id: 'cmpl-' plus 29 random alphanumerics."""
    alphabet = string.ascii_letters + string.digits
    suffix = [random.choice(alphabet) for _ in range(29)]
    return "cmpl-" + "".join(suffix)
|
||||
|
||||
|
||||
def trim_with_stop_words(output: str, stopwords: list) -> str:
    """Truncate *output* at the earliest occurrence of any stop word."""
    trimmed = output
    # Longest words first; each hit can only shorten the result further.
    for stop in sorted(stopwords, key=len, reverse=True):
        position = trimmed.find(stop)
        if position != -1:
            trimmed = trimmed[:position]
    return trimmed
|
||||
|
|
@ -1,37 +0,0 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import models
|
||||
|
||||
|
||||
def setup_logging(logdir):
    """Initialise the loguru file sink that records server events under *logdir*."""
    # Start from a clean directory; ignore the first run where it doesn't exist.
    try:
        shutil.rmtree(logdir)
    except FileNotFoundError:
        pass

    # NOTE(review): the original comment here said "Remove default handler",
    # but no logger.remove() call is made — the default stderr handler stays
    # installed; this add() only attaches the rotating file sink below.
    logger.add(
        os.path.join(logdir, "events.{time}.log"),
        rotation="1 hours",
        retention="2 hours",
        level="INFO",
        filter=__name__,  # only records emitted through this module's logger
        enqueue=True,  # queue records for thread/process safety
        delay=True,  # don't create the file until the first record
        serialize=True,  # write structured JSON, one object per line
    )
|
||||
|
||||
|
||||
def log_completion(
    request: models.CompletionRequest, response: models.CompletionResponse
) -> None:
    """Emit a structured completion event for a request/response pair."""
    logger.info(models.CompletionEvent.build(request, response).json())
|
||||
|
||||
|
||||
def log_event(event: models.Event):
    """Serialize *event* and write it to the event log sink."""
    payload = event.json()
    logger.info(payload)
|
||||
|
|
@ -1,77 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Choice(BaseModel):
    # One generated completion candidate.
    index: int  # position of this choice within the response list
    text: str  # the generated completion text
|
||||
|
||||
|
||||
# https://code.visualstudio.com/docs/languages/identifiers
|
||||
class Language(str, Enum):
    # Language identifiers follow VS Code's convention
    # (https://code.visualstudio.com/docs/languages/identifiers).
    UNKNOWN = "unknown"
    PYTHON = "python"
    JAVASCRIPT = "javascript"
    TYPESCRIPT = "typescript"
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
    # Incoming payload for /v1/completions.
    # NOTE(review): typed as str rather than Language, so unrecognized
    # identifiers are accepted as-is and simply miss the preset lookup
    # downstream (generate() then returns no choices).
    language: str = Field(
        example=Language.PYTHON,
        default=Language.UNKNOWN,
        description="Language for completion request",
    )

    prompt: str = Field(
        example="def binarySearch(arr, left, right, x):\n mid = (left +",
        description="The context to generate completions for, encoded as a string.",
    )
|
||||
|
||||
|
||||
class CompletionResponse(BaseModel):
    # Body returned by /v1/completions.
    id: str  # unique completion id ("cmpl-..." — see random_completion_id)
    created: int  # creation time, unix timestamp in seconds
    choices: List[Choice]  # zero or more generated completions
|
||||
|
||||
|
||||
class EventType(str, Enum):
    # Kinds of usage events clients can report; VIEW and SELECT carry
    # ChoiceEvent payloads, COMPLETION a CompletionEvent (see EventTypeMapping).
    COMPLETION = "completion"
    VIEW = "view"
    SELECT = "select"
|
||||
|
||||
|
||||
class Event(BaseModel):
    # Base class for all logged events; `type` discriminates the payload.
    type: EventType
|
||||
|
||||
|
||||
class CompletionEvent(Event):
    # Snapshot of a full completion exchange (request + response fields).
    id: str
    language: str
    prompt: str
    created: int
    choices: List[Choice]

    @classmethod
    def build(cls, request: CompletionRequest, response: CompletionResponse):
        """Merge a request/response pair into a single loggable event."""
        return cls(
            type=EventType.COMPLETION,
            id=response.id,
            language=request.language,
            prompt=request.prompt,
            created=response.created,
            choices=response.choices,
        )
|
||||
|
||||
|
||||
class ChoiceEvent(Event):
    # A client-side interaction with one choice of an earlier completion.
    completion_id: str  # id of the CompletionResponse the choice came from
    choice_index: int  # index of that choice within the response
|
||||
|
||||
|
||||
# Payload model expected for each event type; used by the server's /v1/events
# handler to validate that an event body matches its declared type.
EventTypeMapping = {
    EventType.COMPLETION: CompletionEvent,
    EventType.VIEW: ChoiceEvent,
    EventType.SELECT: ChoiceEvent,
}
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
schedule: "*/3 * * * *" # Run every 3rd minute
|
||||
|
||||
env:
|
||||
- PATH: "$PATH"
|
||||
- LOGS_DIR: "$LOGS_DIR"
|
||||
- DB_FILE: "$DB_FILE"
|
||||
- APP_DIR: /home/app
|
||||
steps:
|
||||
- name: Collect Tabby
|
||||
dir: $APP_DIR
|
||||
command: ./tabby/tools/analytic/main.sh collect_tabby_server_logs
|
||||
|
|
@ -1,19 +0,0 @@
|
|||
schedule: "5 4 * * *" # Run daily at 04:05.
|
||||
|
||||
env:
|
||||
- PATH: "$PATH"
|
||||
- APP_DIR: /home/app
|
||||
- CONFIG_FILE: "$CONFIG_FILE"
|
||||
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
|
||||
- DATASET_DIR: "$DATASET_DIR"
|
||||
- HOME: "$HOME"
|
||||
steps:
|
||||
- name: update repositories
|
||||
dir: $APP_DIR
|
||||
command: python -m tabby.tools.repository.updater --data_dir=$GIT_REPOSITORIES_DIR --config_file=$CONFIG_FILE
|
||||
|
||||
- name: generate dataset
|
||||
dir: $APP_DIR
|
||||
command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
|
||||
depends:
|
||||
- update repositories
|
||||
Loading…
Reference in New Issue