feat(experimental): update tabby.py to use debug_options.raw_prompt

r0.4
Meng Zhang 2023-10-21 17:27:43 -07:00
parent 1d31b33ccc
commit 2dcb5599b3
2 changed files with 36 additions and 23 deletions

File diff suppressed because one or more lines are too long

View File

@@ -1,18 +1,12 @@
from pathlib import Path from pathlib import Path
import modal
from modal import Image, Mount, Secret, Stub, asgi_app, gpu, method from modal import Image, Mount, Secret, Stub, asgi_app, gpu, method
GPU_CONFIG = gpu.T4() GPU_CONFIG = gpu.T4()
MODEL_ID = "TabbyML/StarCoder-1B" MODEL_ID = "TabbyML/StarCoder-1B"
LAUNCH_FLAGS = [ LAUNCH_FLAGS = ["serve", "--model", MODEL_ID, "--port", "8000", "--device", "cuda"]
"serve",
"--model",
MODEL_ID,
"--port",
"8000",
"--device",
"cuda"
]
def download_model(): def download_model():
import subprocess import subprocess
@@ -28,10 +22,15 @@ def download_model():
image = ( image = (
Image.from_registry("tabbyml/tabby:0.3.0", add_python="3.11") Image.from_registry(
"tabbyml/tabby@sha256:64d71ec4c7d9ae7269e6301ad4106baad70ee997408691a6af17d7186283a856",
add_python="3.11",
)
.dockerfile_commands("ENTRYPOINT []") .dockerfile_commands("ENTRYPOINT []")
.run_function(download_model) .run_function(download_model)
.pip_install("git+https://github.com/TabbyML/tabby.git#egg=tabby-python-client&subdirectory=clients/tabby-python-client") .pip_install(
"git+https://github.com/TabbyML/tabby.git#egg=tabby-python-client&subdirectory=experimental/eval/tabby-python-client"
)
) )
stub = Stub("tabby-" + MODEL_ID.split("/")[-1], image=image) stub = Stub("tabby-" + MODEL_ID.split("/")[-1], image=image)
@@ -49,11 +48,9 @@ class Model:
import subprocess import subprocess
import time import time
from tabby_client import Client from tabby_python_client import Client
self.launcher = subprocess.Popen( self.launcher = subprocess.Popen(["/opt/tabby/bin/tabby"] + LAUNCH_FLAGS)
["/opt/tabby/bin/tabby"] + LAUNCH_FLAGS
)
self.client = Client("http://127.0.0.1:8000") self.client = Client("http://127.0.0.1:8000")
# Poll until webserver at 127.0.0.1:8000 accepts connections before running inputs. # Poll until webserver at 127.0.0.1:8000 accepts connections before running inputs.
@@ -79,15 +76,29 @@ class Model:
def __exit__(self, _exc_type, _exc_value, _traceback): def __exit__(self, _exc_type, _exc_value, _traceback):
self.launcher.terminate() self.launcher.terminate()
@method()
async def health(self):
    """Ask the Tabby server for its health report and return it as a dict.

    Calls the v1 health endpoint through ``self.client`` and converts the
    response with ``to_dict()`` before returning it.
    """
    # Imported inside the method, like the other tabby_python_client
    # imports in this file; aliased so the module is not confused with
    # this method's own name.
    from tabby_python_client.api.v1 import health as health_api

    status = await health_api.asyncio(client=self.client)
    return status.to_dict()
@method() @method()
async def complete(self, language: str, prompt: str): async def complete(self, language: str, prompt: str):
from tabby_client.api.v1 import completion from tabby_python_client.api.v1 import completion
from tabby_client.models import CompletionRequest, DebugOptions, CompletionResponse, Segments from tabby_python_client.models import (
CompletionRequest,
DebugOptions,
CompletionResponse,
Segments,
)
request = CompletionRequest( request = CompletionRequest(
language=language, debug_options=DebugOptions(raw_prompt=prompt) language=language, debug_options=DebugOptions(raw_prompt=prompt)
) )
resp: CompletionResponse = await completion.asyncio(client=self.client, json_body=request) resp: CompletionResponse = await completion.asyncio(
client=self.client, json_body=request
)
return resp.choices[0].text return resp.choices[0].text
@@ -96,12 +107,14 @@ def main():
import json import json
model = Model() model = Model()
print(model.health.remote())
with open("./output.jsonl", "w") as fout: with open("./output.jsonl", "w") as fout:
with open("./sample.jsonl") as fin: with open("./sample.jsonl") as fin:
for line in fin: for line in fin:
x = json.loads(line) x = json.loads(line)
prompt = x['crossfile_context']['text'] + x['prompt'] prompt = x["crossfile_context"]["text"] + x["prompt"]
label = x['groundtruth'] label = x["groundtruth"]
prediction = model.complete.remote("python", prompt) prediction = model.complete.remote("python", prompt)
json.dump(dict(prompt=prompt, label=label, prediction=prediction), fout) json.dump(dict(prompt=prompt, label=label, prediction=prediction), fout)