feat: FLAGS_rewrite_prompt_with_search_snippet (#98)
* feat: FLAGS_rewrite_prompt_with_search_snippet * cleanup
add-tracing
parent
394bfd50e0
commit
3a85b94bcc
|
|
@ -1 +1,3 @@
|
||||||
FROM tabbyml/tabby
|
FROM tabbyml/tabby
|
||||||
|
|
||||||
|
ARG PYPI_INDEX_URL=https://pypi.org/simple
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,6 @@ services:
|
||||||
dockerfile: ./development/Dockerfile
|
dockerfile: ./development/Dockerfile
|
||||||
args:
|
args:
|
||||||
PYPI_INDEX_URL: https://mirrors.aliyun.com/pypi/simple
|
PYPI_INDEX_URL: https://mirrors.aliyun.com/pypi/simple
|
||||||
PYTHON_BUILD_MIRROR_URL: https://repo.huaweicloud.com/python
|
|
||||||
environment:
|
environment:
|
||||||
UVICORN_RELOAD: true
|
UVICORN_RELOAD: true
|
||||||
VECTOR_WATCH_CONFIG: true
|
VECTOR_WATCH_CONFIG: true
|
||||||
|
|
|
||||||
|
|
@ -356,6 +356,24 @@ files = [
|
||||||
{file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"},
|
{file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "camel-converter"
|
||||||
|
version = "3.0.0"
|
||||||
|
description = "Converts a string from snake case to camel case or camel case to snake case"
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7,<4.0"
|
||||||
|
files = [
|
||||||
|
{file = "camel_converter-3.0.0-py3-none-any.whl", hash = "sha256:4b01725c8ccf918752436a8aab595fa153c5123c147225434bf1f40041acb3c7"},
|
||||||
|
{file = "camel_converter-3.0.0.tar.gz", hash = "sha256:7f200e1d1067245f39ae2df6c547dc3de8619060012679702f6471187280e6eb"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
pydantic = {version = ">=1.8.2", optional = true, markers = "extra == \"pydantic\""}
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
pydantic = ["pydantic (>=1.8.2)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "certifi"
|
name = "certifi"
|
||||||
version = "2022.12.7"
|
version = "2022.12.7"
|
||||||
|
|
@ -1514,6 +1532,22 @@ files = [
|
||||||
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
{file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "meilisearch"
|
||||||
|
version = "0.26.0"
|
||||||
|
description = "The python client for Meilisearch API."
|
||||||
|
category = "main"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.7"
|
||||||
|
files = [
|
||||||
|
{file = "meilisearch-0.26.0-py3-none-any.whl", hash = "sha256:e35ad840a42554b575ddb72c6aa5d9f7202ab2a6777fada5752269fe453379f4"},
|
||||||
|
{file = "meilisearch-0.26.0.tar.gz", hash = "sha256:597ec5eaf6b726428c0e04ecf05af19379aa097093225dab0c6bfb892d03a90d"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.dependencies]
|
||||||
|
camel-converter = {version = "*", extras = ["pydantic"]}
|
||||||
|
requests = "*"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mpmath"
|
name = "mpmath"
|
||||||
version = "1.3.0"
|
version = "1.3.0"
|
||||||
|
|
@ -3547,4 +3581,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "df2b6abe8684f976034d432760b5475e48695b2d2557ca4ddf32cf591c5d4475"
|
content-hash = "8ec52f1fb00d48606fd7784f29649d7b9288c2637c6abdd222d331ed61440d96"
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ peft = {git = "https://github.com/huggingface/peft.git", rev = "v0.2.0"}
|
||||||
duckdb = "^0.7.1"
|
duckdb = "^0.7.1"
|
||||||
torch = "^2.0.0"
|
torch = "^2.0.0"
|
||||||
bitsandbytes = "^0.37.2"
|
bitsandbytes = "^0.37.2"
|
||||||
|
meilisearch = "^0.26.0"
|
||||||
|
|
||||||
[tool.poetry.group.dev.dependencies]
|
[tool.poetry.group.dev.dependencies]
|
||||||
pre-commit = "^3.1.1"
|
pre-commit = "^3.1.1"
|
||||||
|
|
|
||||||
|
|
@ -1 +1,5 @@
|
||||||
export FLAGS_enable_meilisearch=""
|
### Experimental feature flags ###
|
||||||
|
# export FLAGS_enable_meilisearch="1"
|
||||||
|
# export FLAGS_rewrite_prompt_with_search_snippet="1"
|
||||||
|
|
||||||
|
### Released feature flags ###
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List
|
from typing import List, Optional, Set
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
@ -8,6 +8,7 @@ from ..models import Language
|
||||||
class LanguagePreset(BaseModel):
    """Per-language generation settings.

    Attributes are documented inline; ``reserved_keywords`` is optional and
    is left as ``None`` for languages that do not provide a keyword list.
    """

    # Length limit applied to generated completions for this language.
    max_length: int
    # Strings at which a generated completion is cut off.
    stop_words: List[str]
    # Language keywords to ignore when building a search query from a prompt.
    # The explicit ``None`` default keeps presets that omit the field valid
    # instead of relying on pydantic's implicit handling of ``Optional``.
    reserved_keywords: Optional[Set[str]] = None
|
||||||
|
|
||||||
|
|
||||||
LanguagePresets = {
|
LanguagePresets = {
|
||||||
|
|
@ -18,6 +19,45 @@ LanguagePresets = {
|
||||||
Language.PYTHON: LanguagePreset(
|
Language.PYTHON: LanguagePreset(
|
||||||
max_length=128,
|
max_length=128,
|
||||||
stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"],
|
stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"],
|
||||||
|
reserved_keywords=set(
|
||||||
|
[
|
||||||
|
"False",
|
||||||
|
"class",
|
||||||
|
"from",
|
||||||
|
"or",
|
||||||
|
"None",
|
||||||
|
"continue",
|
||||||
|
"global",
|
||||||
|
"pass",
|
||||||
|
"True",
|
||||||
|
"def",
|
||||||
|
"if",
|
||||||
|
"raise",
|
||||||
|
"and",
|
||||||
|
"del",
|
||||||
|
"import",
|
||||||
|
"return",
|
||||||
|
"as",
|
||||||
|
"elif",
|
||||||
|
"in",
|
||||||
|
"try",
|
||||||
|
"assert",
|
||||||
|
"else",
|
||||||
|
"is",
|
||||||
|
"while",
|
||||||
|
"async",
|
||||||
|
"except",
|
||||||
|
"lambda",
|
||||||
|
"with",
|
||||||
|
"await",
|
||||||
|
"finally",
|
||||||
|
"nonlocal",
|
||||||
|
"yield",
|
||||||
|
"break",
|
||||||
|
"for",
|
||||||
|
"not",
|
||||||
|
]
|
||||||
|
),
|
||||||
),
|
),
|
||||||
Language.JAVASCRIPT: LanguagePreset(
|
Language.JAVASCRIPT: LanguagePreset(
|
||||||
max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"]
|
max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"]
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
|
||||||
|
import meilisearch
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from .language_presets import LanguagePreset
|
||||||
|
|
||||||
|
# Raw value of the FLAGS_enable_meilisearch environment variable, read once
# at import time; None when the variable is not set.
FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch")
|
||||||
|
|
||||||
|
|
||||||
|
class PromptRewriter:
    """Prepend related code snippets, looked up in Meilisearch, to a prompt.

    Calling an instance returns the rewritten prompt, or the original prompt
    unchanged when rewriting raises ``PromptRewriteFailed``.
    """

    def __init__(self, meili_addr: str = "http://localhost:8084"):
        """Create the Meilisearch client used for snippet lookup.

        Args:
            meili_addr: Address passed to ``meilisearch.Client``.
        """
        self.meili_client = meilisearch.Client(meili_addr)

    def create_query(self, preset: LanguagePreset, prompt: str):
        """Build a search query from the distinctive words of ``prompt``.

        The prompt is lowercased, punctuation is replaced by spaces, tokens
        shorter than three characters and language keywords are dropped, and
        duplicates are removed while keeping first-seen order.

        Args:
            preset: Language preset whose ``reserved_keywords`` are excluded.
            prompt: Raw completion prompt.

        Returns:
            The remaining tokens joined by single spaces.

        Raises:
            PromptRewriteFailed: If three or fewer distinct tokens remain.
        """
        # Replace punctuation with spaces and split into lowercase tokens.
        tokens = re.sub(r"[^\w\s]", " ", prompt.lower()).split()

        # Tokens were lowercased above, so keywords must be lowercased as
        # well; otherwise capitalized keywords such as "None" never match.
        keywords = {word.lower() for word in preset.reserved_keywords}

        # dict.fromkeys de-duplicates while preserving insertion order, so the
        # query is the same in every process (set order depends on hashing).
        tokens = list(
            dict.fromkeys(
                x for x in tokens if len(x) >= 3 and x not in keywords
            )
        )

        if len(tokens) > 3:
            return " ".join(tokens)
        raise PromptRewriteFailed("Too few tokens extracted from prompt")

    def rewrite(self, preset: LanguagePreset, prompt: str) -> str:
        """Return ``prompt`` prefixed with up to three related snippets.

        Args:
            preset: Language preset; must provide ``reserved_keywords``.
            prompt: Raw completion prompt.

        Returns:
            A new prompt containing an instruction line, the retrieved
            snippets, and the original prompt under a ``== context ==`` header.

        Raises:
            PromptRewriteFailed: If the preset has no keyword list, too few
                query tokens can be extracted, or the search returns no hits.
        """
        if preset.reserved_keywords is None:
            raise PromptRewriteFailed("Rewrite requires language keywords list")

        index = self.meili_client.index("dataset")
        query = self.create_query(preset, prompt)
        logger.debug("query: {}", query)
        search_results = index.search(
            query,
            {
                "limit": 3,
                "attributesToCrop": ["content"],
                "cropLength": 32,
                "cropMarker": "",
                "attributesToRetrieve": ["content"],
            },
        )

        hits = search_results["hits"]
        if len(hits) == 0:
            raise PromptRewriteFailed("No related snippets")

        def make_snippet(i, content):
            # Use the "_formatted" variant of the hit rather than the raw
            # document field.
            content = content["_formatted"]["content"]
            return f"== snippet {i+1} ==\n{content}"

        snippets = "\n".join(make_snippet(i, x) for i, x in enumerate(hits))
        prompt = f"""Given following relevant code snippet, generate code completion based on context.
{snippets}
== context ==
{prompt}"""
        logger.debug("prompt: {}", prompt)
        return prompt

    def __call__(self, preset: LanguagePreset, prompt: str) -> str:
        """Rewrite ``prompt`` if possible, otherwise return it unchanged.

        Only ``PromptRewriteFailed`` is treated as a soft failure; any other
        exception raised during the rewrite propagates to the caller.
        """
        try:
            return self.rewrite(preset, prompt)
        except PromptRewriteFailed:
            return prompt
|
||||||
|
|
||||||
|
|
||||||
|
class PromptRewriteFailed(Exception):
    """Signal that a prompt could not be rewritten with search snippets."""
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
@ -8,8 +9,13 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
||||||
|
|
||||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||||
from .language_presets import LanguagePresets
|
from .language_presets import LanguagePresets
|
||||||
|
from .prompt_rewriter import PromptRewriter
|
||||||
from .utils import random_completion_id, trim_with_stop_words
|
from .utils import random_completion_id, trim_with_stop_words
|
||||||
|
|
||||||
|
# Raw value of the FLAGS_rewrite_prompt_with_search_snippet environment
# variable, read once at import time; None when the variable is not set.
_FLAG_NAME = "FLAGS_rewrite_prompt_with_search_snippet"
FLAGS_rewrite_prompt_with_search_snippet = os.environ.get(_FLAG_NAME)
|
||||||
|
|
||||||
|
|
||||||
class TritonService:
|
class TritonService:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -24,6 +30,11 @@ class TritonService:
|
||||||
url=f"{host}:{port}", verbose=verbose
|
url=f"{host}:{port}", verbose=verbose
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if FLAGS_rewrite_prompt_with_search_snippet:
|
||||||
|
self.rewriter = PromptRewriter()
|
||||||
|
else:
|
||||||
|
self.rewriter = None
|
||||||
|
|
||||||
def generate(self, data: CompletionRequest) -> List[Choice]:
|
def generate(self, data: CompletionRequest) -> List[Choice]:
|
||||||
n = 1
|
n = 1
|
||||||
np_type = np.uint32
|
np_type = np.uint32
|
||||||
|
|
@ -31,7 +42,11 @@ class TritonService:
|
||||||
|
|
||||||
preset = LanguagePresets[data.language]
|
preset = LanguagePresets[data.language]
|
||||||
|
|
||||||
prompt = data.prompt
|
if self.rewriter:
|
||||||
|
prompt = self.rewriter(preset, data.prompt)
|
||||||
|
else:
|
||||||
|
prompt = data.prompt
|
||||||
|
|
||||||
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
||||||
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
|
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
|
||||||
prompt_len = input_start_ids.shape[1]
|
prompt_len = input_start_ids.shape[1]
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,6 @@ def setup_logging(logdir):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Remove default handler
|
# Remove default handler
|
||||||
logger.remove()
|
|
||||||
logger.add(
|
logger.add(
|
||||||
os.path.join(logdir, "events.{time}.log"),
|
os.path.join(logdir, "events.{time}.log"),
|
||||||
rotation="1 hours",
|
rotation="1 hours",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue