feat: FLAGS_rewrite_prompt_with_search_snippet (#98)

* feat: FLAGS_rewrite_prompt_with_search_snippet

* cleanup
add-tracing
Meng Zhang 2023-04-13 15:02:12 +08:00 committed by GitHub
parent 394bfd50e0
commit 3a85b94bcc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 174 additions and 6 deletions

View File

@@ -1 +1,3 @@
FROM tabbyml/tabby
ARG PYPI_INDEX_URL=https://pypi.org/simple

View File

@ -8,7 +8,6 @@ services:
dockerfile: ./development/Dockerfile
args:
PYPI_INDEX_URL: https://mirrors.aliyun.com/pypi/simple
PYTHON_BUILD_MIRROR_URL: https://repo.huaweicloud.com/python
environment:
UVICORN_RELOAD: true
VECTOR_WATCH_CONFIG: true

36
poetry.lock generated
View File

@ -356,6 +356,24 @@ files = [
{file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"},
]
[[package]]
name = "camel-converter"
version = "3.0.0"
description = "Converts a string from snake case to camel case or camel case to snake case"
category = "main"
optional = false
python-versions = ">=3.7,<4.0"
files = [
{file = "camel_converter-3.0.0-py3-none-any.whl", hash = "sha256:4b01725c8ccf918752436a8aab595fa153c5123c147225434bf1f40041acb3c7"},
{file = "camel_converter-3.0.0.tar.gz", hash = "sha256:7f200e1d1067245f39ae2df6c547dc3de8619060012679702f6471187280e6eb"},
]
[package.dependencies]
pydantic = {version = ">=1.8.2", optional = true, markers = "extra == \"pydantic\""}
[package.extras]
pydantic = ["pydantic (>=1.8.2)"]
[[package]]
name = "certifi"
version = "2022.12.7"
@ -1514,6 +1532,22 @@ files = [
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
]
[[package]]
name = "meilisearch"
version = "0.26.0"
description = "The python client for Meilisearch API."
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "meilisearch-0.26.0-py3-none-any.whl", hash = "sha256:e35ad840a42554b575ddb72c6aa5d9f7202ab2a6777fada5752269fe453379f4"},
{file = "meilisearch-0.26.0.tar.gz", hash = "sha256:597ec5eaf6b726428c0e04ecf05af19379aa097093225dab0c6bfb892d03a90d"},
]
[package.dependencies]
camel-converter = {version = "*", extras = ["pydantic"]}
requests = "*"
[[package]]
name = "mpmath"
version = "1.3.0"
@ -3547,4 +3581,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
content-hash = "df2b6abe8684f976034d432760b5475e48695b2d2557ca4ddf32cf591c5d4475"
content-hash = "8ec52f1fb00d48606fd7784f29649d7b9288c2637c6abdd222d331ed61440d96"

View File

@ -22,6 +22,7 @@ peft = {git = "https://github.com/huggingface/peft.git", rev = "v0.2.0"}
duckdb = "^0.7.1"
torch = "^2.0.0"
bitsandbytes = "^0.37.2"
meilisearch = "^0.26.0"
[tool.poetry.group.dev.dependencies]
pre-commit = "^3.1.1"

View File

@ -1 +1,5 @@
export FLAGS_enable_meilisearch=""
### Experimental feature flags ###
# export FLAGS_enable_meilisearch="1"
# export FLAGS_rewrite_prompt_with_search_snippet="1"
### Released feature flags ###

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import List, Optional, Set
from pydantic import BaseModel, Field
@ -8,6 +8,7 @@ from ..models import Language
class LanguagePreset(BaseModel):
    """Per-language code-generation settings.

    Consumed by the completion service to control decoding, and by the
    prompt rewriter to build search queries.
    """

    # Maximum number of tokens to generate for this language.
    max_length: int
    # Generated text is trimmed at the first occurrence of any of these.
    stop_words: List[str]
    # Language keywords to exclude from search queries. None means the
    # language has no keyword list, which disables prompt rewriting for it
    # (PromptRewriter.rewrite raises when this is None). Typed Set[str]
    # and given an explicit default instead of relying on pydantic's
    # implicit None for Optional fields.
    reserved_keywords: Optional[Set[str]] = None
LanguagePresets = {
@ -18,6 +19,45 @@ LanguagePresets = {
Language.PYTHON: LanguagePreset(
max_length=128,
stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"],
reserved_keywords=set(
[
"False",
"class",
"from",
"or",
"None",
"continue",
"global",
"pass",
"True",
"def",
"if",
"raise",
"and",
"del",
"import",
"return",
"as",
"elif",
"in",
"try",
"assert",
"else",
"is",
"while",
"async",
"except",
"lambda",
"with",
"await",
"finally",
"nonlocal",
"yield",
"break",
"for",
"not",
]
),
),
Language.JAVASCRIPT: LanguagePreset(
max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"]

View File

@@ -0,0 +1,74 @@
import os
import re
import meilisearch
from loguru import logger
from .language_presets import LanguagePreset
# Experimental feature flag, read once at import time from the environment
# (see development/.env); truthy when the variable is set to any value.
FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch", None)
class PromptRewriter:
    """Rewrites a completion prompt by prepending related code snippets
    retrieved from a Meilisearch index ("dataset") of repository code.

    Calling the instance is best-effort: any rewrite failure returns the
    original prompt unchanged.
    """

    def __init__(self, meili_addr: str = "http://localhost:8084"):
        # NOTE(review): default address assumes the local dev Meilisearch
        # listens on 8084 — confirm against the deployment config.
        self.meili_client = meilisearch.Client(meili_addr)

    def create_query(self, preset: "LanguagePreset", prompt: str) -> str:
        """Build a search query string from *prompt*.

        Strips punctuation, lowercases, drops tokens shorter than 3 chars
        and language keywords, and joins the remaining unique tokens.

        Raises:
            PromptRewriteFailed: if fewer than 4 usable tokens remain.
        """
        # Replace all punctuation with spaces and tokenize the lowercased prompt.
        tokens = re.sub(r"[^\w\s]", " ", prompt.lower()).split()
        # Lowercase the keyword list too: tokens are already lowercased, so
        # capitalized keywords such as "False"/"True"/"None" would otherwise
        # never be filtered out.
        keywords = {kw.lower() for kw in (preset.reserved_keywords or set())}
        # Keep unique tokens that are long enough and not language keywords.
        unique_tokens = {t for t in tokens if len(t) >= 3 and t not in keywords}
        if len(unique_tokens) <= 3:
            raise PromptRewriteFailed("Too few tokens extracted from prompt")
        # Sort for a deterministic query string (set iteration order varies
        # between runs, which made queries and cache behavior nondeterministic).
        return " ".join(sorted(unique_tokens))

    def rewrite(self, preset: "LanguagePreset", prompt: str) -> str:
        """Return *prompt* prefixed with up to 3 related code snippets.

        Raises:
            PromptRewriteFailed: if the language has no keyword list, too few
                query tokens can be extracted, or the search returns no hits.
        """
        if preset.reserved_keywords is None:
            raise PromptRewriteFailed("Rewrite requires language keywords list")
        index = self.meili_client.index("dataset")
        query = self.create_query(preset, prompt)
        logger.debug("query: {}", query)
        search_results = index.search(
            query,
            {
                "limit": 3,
                "attributesToCrop": ["content"],
                "cropLength": 32,
                "cropMarker": "",
                "attributesToRetrieve": ["content"],
            },
        )
        if not search_results["hits"]:
            raise PromptRewriteFailed("No related snippets")

        def make_snippet(i, hit):
            # Cropped text is returned by Meilisearch under "_formatted".
            content = hit["_formatted"]["content"]
            return f"== snippet {i+1} ==\n{content}"

        snippets = "\n".join(
            make_snippet(i, hit) for i, hit in enumerate(search_results["hits"])
        )
        prompt = f"""Given following relevant code snippet, generate code completion based on context.
{snippets}
== context ==
{prompt}"""
        logger.debug("prompt: {}", prompt)
        return prompt

    def __call__(self, preset: "LanguagePreset", prompt: str) -> str:
        """Best-effort rewrite: fall back to the original prompt on failure."""
        try:
            return self.rewrite(preset, prompt)
        except PromptRewriteFailed:
            return prompt
class PromptRewriteFailed(Exception):
    """Raised when a prompt cannot be rewritten with search snippets."""

View File

@ -1,3 +1,4 @@
import os
import time
from typing import List
@ -8,8 +9,13 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
from ..models import Choice, CompletionRequest, CompletionResponse
from .language_presets import LanguagePresets
from .prompt_rewriter import PromptRewriter
from .utils import random_completion_id, trim_with_stop_words
# Experimental feature flag, read once at import time from the environment
# (see development/.env); truthy when the variable is set. When enabled,
# TritonService constructs a PromptRewriter and rewrites incoming prompts.
FLAGS_rewrite_prompt_with_search_snippet = os.environ.get(
    "FLAGS_rewrite_prompt_with_search_snippet", None
)
class TritonService:
def __init__(
@ -24,6 +30,11 @@ class TritonService:
url=f"{host}:{port}", verbose=verbose
)
if FLAGS_rewrite_prompt_with_search_snippet:
self.rewriter = PromptRewriter()
else:
self.rewriter = None
def generate(self, data: CompletionRequest) -> List[Choice]:
n = 1
np_type = np.uint32
@ -31,7 +42,11 @@ class TritonService:
preset = LanguagePresets[data.language]
prompt = data.prompt
if self.rewriter:
prompt = self.rewriter(preset, data.prompt)
else:
prompt = data.prompt
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
prompt_len = input_start_ids.shape[1]

View File

@ -14,7 +14,6 @@ def setup_logging(logdir):
pass
# Remove default handler
logger.remove()
logger.add(
os.path.join(logdir, "events.{time}.log"),
rotation="1 hours",