feat: FLAGS_rewrite_prompt_with_search_snippet (#98)
* feat: FLAGS_rewrite_prompt_with_search_snippet * cleanup
add-tracing
parent
394bfd50e0
commit
3a85b94bcc
|
|
@ -1 +1,3 @@
|
|||
FROM tabbyml/tabby
|
||||
|
||||
ARG PYPI_INDEX_URL=https://pypi.org/simple
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ services:
|
|||
dockerfile: ./development/Dockerfile
|
||||
args:
|
||||
PYPI_INDEX_URL: https://mirrors.aliyun.com/pypi/simple
|
||||
PYTHON_BUILD_MIRROR_URL: https://repo.huaweicloud.com/python
|
||||
environment:
|
||||
UVICORN_RELOAD: true
|
||||
VECTOR_WATCH_CONFIG: true
|
||||
|
|
|
|||
|
|
@ -356,6 +356,24 @@ files = [
|
|||
{file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "camel-converter"
|
||||
version = "3.0.0"
|
||||
description = "Converts a string from snake case to camel case or camel case to snake case"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<4.0"
|
||||
files = [
|
||||
{file = "camel_converter-3.0.0-py3-none-any.whl", hash = "sha256:4b01725c8ccf918752436a8aab595fa153c5123c147225434bf1f40041acb3c7"},
|
||||
{file = "camel_converter-3.0.0.tar.gz", hash = "sha256:7f200e1d1067245f39ae2df6c547dc3de8619060012679702f6471187280e6eb"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
pydantic = {version = ">=1.8.2", optional = true, markers = "extra == \"pydantic\""}
|
||||
|
||||
[package.extras]
|
||||
pydantic = ["pydantic (>=1.8.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "certifi"
|
||||
version = "2022.12.7"
|
||||
|
|
@ -1514,6 +1532,22 @@ files = [
|
|||
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "meilisearch"
|
||||
version = "0.26.0"
|
||||
description = "The python client for Meilisearch API."
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "meilisearch-0.26.0-py3-none-any.whl", hash = "sha256:e35ad840a42554b575ddb72c6aa5d9f7202ab2a6777fada5752269fe453379f4"},
|
||||
{file = "meilisearch-0.26.0.tar.gz", hash = "sha256:597ec5eaf6b726428c0e04ecf05af19379aa097093225dab0c6bfb892d03a90d"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
camel-converter = {version = "*", extras = ["pydantic"]}
|
||||
requests = "*"
|
||||
|
||||
[[package]]
|
||||
name = "mpmath"
|
||||
version = "1.3.0"
|
||||
|
|
@ -3547,4 +3581,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.10"
|
||||
content-hash = "df2b6abe8684f976034d432760b5475e48695b2d2557ca4ddf32cf591c5d4475"
|
||||
content-hash = "8ec52f1fb00d48606fd7784f29649d7b9288c2637c6abdd222d331ed61440d96"
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ peft = {git = "https://github.com/huggingface/peft.git", rev = "v0.2.0"}
|
|||
duckdb = "^0.7.1"
|
||||
torch = "^2.0.0"
|
||||
bitsandbytes = "^0.37.2"
|
||||
meilisearch = "^0.26.0"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pre-commit = "^3.1.1"
|
||||
|
|
|
|||
|
|
@ -1 +1,5 @@
|
|||
export FLAGS_enable_meilisearch=""
|
||||
### Experimental feature flags ###
|
||||
# export FLAGS_enable_meilisearch="1"
|
||||
# export FLAGS_rewrite_prompt_with_search_snippet="1"
|
||||
|
||||
### Released feature flags ###
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import List
|
||||
from typing import List, Optional, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
|
@ -8,6 +8,7 @@ from ..models import Language
|
|||
class LanguagePreset(BaseModel):
    """Per-language completion settings consumed by the Triton service."""

    # Maximum completion length — presumably in tokens; confirm against the
    # Triton generate call.
    max_length: int
    # Sequences that terminate generation for this language.
    stop_words: List[str]
    # Language reserved keywords, filtered out of search queries when the
    # prompt is rewritten with search snippets. When None, the prompt
    # rewriter raises PromptRewriteFailed and the prompt is left unchanged.
    reserved_keywords: Optional[Set]
|
||||
|
||||
|
||||
LanguagePresets = {
|
||||
|
|
@ -18,6 +19,45 @@ LanguagePresets = {
|
|||
Language.PYTHON: LanguagePreset(
|
||||
max_length=128,
|
||||
stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"],
|
||||
reserved_keywords=set(
|
||||
[
|
||||
"False",
|
||||
"class",
|
||||
"from",
|
||||
"or",
|
||||
"None",
|
||||
"continue",
|
||||
"global",
|
||||
"pass",
|
||||
"True",
|
||||
"def",
|
||||
"if",
|
||||
"raise",
|
||||
"and",
|
||||
"del",
|
||||
"import",
|
||||
"return",
|
||||
"as",
|
||||
"elif",
|
||||
"in",
|
||||
"try",
|
||||
"assert",
|
||||
"else",
|
||||
"is",
|
||||
"while",
|
||||
"async",
|
||||
"except",
|
||||
"lambda",
|
||||
"with",
|
||||
"await",
|
||||
"finally",
|
||||
"nonlocal",
|
||||
"yield",
|
||||
"break",
|
||||
"for",
|
||||
"not",
|
||||
]
|
||||
),
|
||||
),
|
||||
Language.JAVASCRIPT: LanguagePreset(
|
||||
max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,74 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import meilisearch
|
||||
from loguru import logger
|
||||
|
||||
from .language_presets import LanguagePreset
|
||||
|
||||
# Raw env-var string; truthy when set to any non-empty value, None when unset.
# NOTE(review): read here but not referenced elsewhere in this diff —
# presumably consumed by the indexing side; confirm before removing.
FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch", None)
|
||||
|
||||
|
||||
class PromptRewriter:
    """Rewrite a completion prompt by prepending related code snippets
    fetched from a Meilisearch index.

    Use the instance as a callable: on any rewrite failure (missing
    keyword list, too few query tokens, no search hits) the original
    prompt is returned unchanged.
    """

    def __init__(
        self,
        meili_addr: str = "http://localhost:8084",
        index_name: str = "dataset",
    ):
        """
        Args:
            meili_addr: Base URL of the Meilisearch server.
            index_name: Name of the index holding code snippets. Previously
                hard-coded inside ``rewrite``; parameterized for reuse while
                keeping the same default.
        """
        self.meili_client = meilisearch.Client(meili_addr)
        self.index_name = index_name

    def create_query(self, preset: LanguagePreset, prompt: str) -> str:
        """Build a search query string from *prompt*.

        Tokenizes the prompt, drops short tokens and language reserved
        keywords, deduplicates, and joins the survivors.

        Raises:
            PromptRewriteFailed: if fewer than 4 usable tokens remain.
        """
        # Replace all punctuation with spaces, lowercase, split on whitespace.
        tokens = re.sub(r"[^\w\s]", " ", prompt.lower()).split()

        # Remove short tokens (< 3 chars) — they only add noise to the query.
        tokens = [x for x in tokens if len(x) >= 3]

        # Remove tokens in the language's reserved_keywords and deduplicate.
        tokens = {x for x in tokens if x not in preset.reserved_keywords}

        if len(tokens) > 3:
            # sorted() makes the query deterministic across interpreter runs
            # (set iteration order depends on randomized string hashing).
            return " ".join(sorted(tokens))
        else:
            raise PromptRewriteFailed("Too few tokens extracted from prompt")

    def rewrite(self, preset: LanguagePreset, prompt: str) -> str:
        """Return *prompt* prefixed with up to 3 related snippets.

        Raises:
            PromptRewriteFailed: when the preset has no keyword list, the
                extracted query is too short, or the search has no hits.
        """
        if preset.reserved_keywords is None:
            raise PromptRewriteFailed("Rewrite requires language keywords list")

        index = self.meili_client.index(self.index_name)
        query = self.create_query(preset, prompt)
        logger.debug("query: {}", query)
        search_results = index.search(
            query,
            {
                "limit": 3,
                "attributesToCrop": ["content"],
                "cropLength": 32,
                "cropMarker": "",
                "attributesToRetrieve": ["content"],
            },
        )

        if len(search_results["hits"]) == 0:
            raise PromptRewriteFailed("No related snippets")

        def make_snippet(i, hit):
            # Cropped text comes back under "_formatted" when
            # attributesToCrop is requested.
            content = hit["_formatted"]["content"]
            return f"== snippet {i+1} ==\n{content}"

        snippets = "\n".join(
            [make_snippet(i, x) for i, x in enumerate(search_results["hits"])]
        )
        prompt = f"""Given following relevant code snippet, generate code completion based on context.
{snippets}
== context ==
{prompt}"""
        logger.debug("prompt: {}", prompt)
        return prompt

    def __call__(self, preset: LanguagePreset, prompt: str) -> str:
        """Best-effort rewrite: fall back to the original prompt on failure."""
        try:
            return self.rewrite(preset, prompt)
        except PromptRewriteFailed:
            return prompt
|
||||
|
||||
|
||||
class PromptRewriteFailed(Exception):
    """Raised when a prompt cannot be rewritten with search snippets."""
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
|
|
@ -8,8 +9,13 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
|||
|
||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||
from .language_presets import LanguagePresets
|
||||
from .prompt_rewriter import PromptRewriter
|
||||
from .utils import random_completion_id, trim_with_stop_words
|
||||
|
||||
FLAGS_rewrite_prompt_with_search_snippet = os.environ.get(
|
||||
"FLAGS_rewrite_prompt_with_search_snippet", None
|
||||
)
|
||||
|
||||
|
||||
class TritonService:
|
||||
def __init__(
|
||||
|
|
@ -24,6 +30,11 @@ class TritonService:
|
|||
url=f"{host}:{port}", verbose=verbose
|
||||
)
|
||||
|
||||
if FLAGS_rewrite_prompt_with_search_snippet:
|
||||
self.rewriter = PromptRewriter()
|
||||
else:
|
||||
self.rewriter = None
|
||||
|
||||
def generate(self, data: CompletionRequest) -> List[Choice]:
|
||||
n = 1
|
||||
np_type = np.uint32
|
||||
|
|
@ -31,7 +42,11 @@ class TritonService:
|
|||
|
||||
preset = LanguagePresets[data.language]
|
||||
|
||||
if self.rewriter:
|
||||
prompt = self.rewriter(preset, data.prompt)
|
||||
else:
|
||||
prompt = data.prompt
|
||||
|
||||
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
||||
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
|
||||
prompt_len = input_start_ids.shape[1]
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ def setup_logging(logdir):
|
|||
pass
|
||||
|
||||
# Remove default handler
|
||||
logger.remove()
|
||||
logger.add(
|
||||
os.path.join(logdir, "events.{time}.log"),
|
||||
rotation="1 hours",
|
||||
|
|
|
|||
Loading…
Reference in New Issue