diff --git a/Makefile b/Makefile index e2b281c..ae07799 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,7 @@ POETRY_EXISTS := $(shell which poetry &> /dev/null) +PRE_COMMIT_HOOK := .git/hooks/pre-commit LOCAL_MODEL := testdata/tiny-70M/models/fastertransformer/1 -all: - pre-commit: poetry run pre-commit @@ -10,6 +9,10 @@ install-poetry: ifndef POETRY_EXISTS curl -sSL https://install.python-poetry.org | POETRY_VERSION=1.4.0 python3 - endif + poetry install + +$(PRE_COMMIT_HOOK): + poetry run pre-commit install --install-hooks $(LOCAL_MODEL): poetry run python converter/huggingface_gptneox_convert.py \ @@ -17,10 +20,16 @@ $(LOCAL_MODEL): -o $@ \ -i_g 1 -m_n tiny-70M -p 1 -w fp16 -setup-development-environment: install-poetry $(LOCAL_MODEL) +setup-development-environment: install-poetry $(PRE_COMMIT_HOOK) -up: $(LOCAL_MODEL) +up: docker-compose -f deployment/docker-compose.yml up -dev: $(setup-development-environment) $(LOCAL_MODEL) +up-triton: $(LOCAL_MODEL) + docker-compose -f deployment/docker-compose.yml -f deployment/docker-compose.triton.yml up + +dev: docker-compose -f deployment/docker-compose.yml -f deployment/docker-compose.dev.yml up --build + +dev-triton: $(LOCAL_MODEL) + docker-compose -f deployment/docker-compose.yml -f deployment/docker-compose.triton.yml -f deployment/docker-compose.dev.yml up --build diff --git a/README.md b/README.md index 58883c2..b05ac5e 100644 --- a/README.md +++ b/README.md @@ -23,4 +23,4 @@ Assuming Linux workstation with: 2. docker w/ gpu driver 3. python 3.10 -Use `make dev` to start local dev server. +Use `make setup-development-environment` to set up the basic development environment, and `make dev` to start the local development server. 
diff --git a/deployment/.gitignore b/deployment/.gitignore index 98d8a5a..f9a1874 100644 --- a/deployment/.gitignore +++ b/deployment/.gitignore @@ -1 +1,2 @@ logs +hf_cache diff --git a/deployment/docker-compose.triton.yml b/deployment/docker-compose.triton.yml new file mode 100644 index 0000000..f2c91d6 --- /dev/null +++ b/deployment/docker-compose.triton.yml @@ -0,0 +1,33 @@ +version: '3.3' + +services: + server: + image: tabbyml/tabby + environment: + - MODEL_BACKEND=triton + - TRITON_TOKENIZER_NAME=/tokenizer + volumes: + - ../testdata/tiny-70M/tokenizer:/tokenizer + links: + - triton + + admin: + links: + - triton + + + + triton: + image: tabbyml/fastertransformer_backend + container_name: tabby-triton + command: mpirun -n 1 --allow-run-as-root /opt/tritonserver/bin/tritonserver --model-repository=/model + shm_size: 1gb + volumes: + - ../testdata/tiny-70M/models:/model + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] diff --git a/deployment/docker-compose.yml b/deployment/docker-compose.yml index e2f4c6c..4b0e441 100644 --- a/deployment/docker-compose.yml +++ b/deployment/docker-compose.yml @@ -7,33 +7,16 @@ services: working_dir: /app/server command: uvicorn app:app --host 0.0.0.0 --port 5000 environment: - - TOKENIZER_NAME=/tokenizer - - TRITON_HOST=triton + - PYTHON_MODEL_NAME=EleutherAI/pythia-70m-deduped - EVENTS_LOG_DIR=/logs/tabby-server ports: - "5000:5000" volumes: - ./logs:/logs - - ../testdata/tiny-70M/tokenizer:/tokenizer + - ./hf_cache:/root/.cache/huggingface links: - - triton - vector - triton: - image: tabbyml/fastertransformer_backend - container_name: tabby-triton - command: mpirun -n 1 --allow-run-as-root /opt/tritonserver/bin/tritonserver --model-repository=/model - shm_size: 1gb - volumes: - - ../testdata/tiny-70M/models:/model - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: all - capabilities: [gpu] - admin: image: tabbyml/tabby container_name: 
tabby-admin @@ -43,7 +26,6 @@ services: - "8501:8501" links: - server - - triton - vector vector: diff --git a/server/app.py b/server/app.py index f5191b3..8ac9cec 100644 --- a/server/app.py +++ b/server/app.py @@ -6,6 +6,7 @@ import uvicorn from fastapi import FastAPI, Response from fastapi.responses import JSONResponse from models import CompletionRequest, CompletionResponse +from python import PythonModelService from triton import TritonService app = FastAPI( @@ -14,16 +15,21 @@ app = FastAPI( docs_url="/", ) -triton = TritonService( - tokenizer_name=os.environ.get("TOKENIZER_NAME", None), - host=os.environ.get("TRITON_HOST", "localhost"), - port=os.environ.get("TRITON_PORT", "8001"), -) +MODEL_BACKEND = os.environ.get("MODEL_BACKEND", "python") + +if MODEL_BACKEND == "triton": + model_backend = TritonService( + tokenizer_name=os.environ.get("TRITON_TOKENIZER_NAME", None), + host=os.environ.get("TRITON_HOST", "triton"), + port=os.environ.get("TRITON_PORT", "8001"), + ) +else: + model_backend = PythonModelService(os.environ["PYTHON_MODEL_NAME"]) @app.post("/v1/completions") async def completions(request: CompletionRequest) -> CompletionResponse: - response = triton(request) + response = model_backend(request) events.log_completions(request, response) return response diff --git a/server/python.py b/server/python.py new file mode 100644 index 0000000..bd08fc3 --- /dev/null +++ b/server/python.py @@ -0,0 +1,35 @@ +import random +import string +import time +from typing import List + +from models import Choice, CompletionRequest, CompletionResponse +from transformers import AutoModelForCausalLM, AutoTokenizer + + +class PythonModelService: + def __init__( + self, + model_name, + ): + self.tokenizer = AutoTokenizer.from_pretrained(model_name) + self.model = AutoModelForCausalLM.from_pretrained(model_name) + + def generate(self, request: CompletionRequest) -> List[Choice]: + input_ids = self.tokenizer.encode(request.prompt, return_tensors="pt") + res = 
self.model.generate(input_ids, max_length=64) + output_ids = res[0][len(input_ids[0]) :] + text = self.tokenizer.decode(output_ids) + return [Choice(index=0, text=text)] + + def __call__(self, request: CompletionRequest) -> CompletionResponse: + choices = self.generate(request) + return CompletionResponse( + id=random_completion_id(), created=int(time.time()), choices=choices + ) + + +def random_completion_id(): + return "cmpl-" + "".join( + random.choice(string.ascii_letters + string.digits) for _ in range(29) + )