diff --git a/development/Dockerfile b/development/Dockerfile index 8f26cb8..cefa0ef 100644 --- a/development/Dockerfile +++ b/development/Dockerfile @@ -1 +1,3 @@ FROM tabbyml/tabby + +ARG PYPI_INDEX_URL=https://pypi.org/simple diff --git a/development/docker-compose.dev.yml b/development/docker-compose.dev.yml index 2ca9722..ed9d046 100644 --- a/development/docker-compose.dev.yml +++ b/development/docker-compose.dev.yml @@ -8,7 +8,6 @@ services: dockerfile: ./development/Dockerfile args: PYPI_INDEX_URL: https://mirrors.aliyun.com/pypi/simple - PYTHON_BUILD_MIRROR_URL: https://repo.huaweicloud.com/python environment: UVICORN_RELOAD: true VECTOR_WATCH_CONFIG: true diff --git a/poetry.lock b/poetry.lock index 6750990..ab3cb34 100644 --- a/poetry.lock +++ b/poetry.lock @@ -356,6 +356,24 @@ files = [ {file = "cachetools-5.3.0.tar.gz", hash = "sha256:13dfddc7b8df938c21a940dfa6557ce6e94a2f1cdfa58eb90c805721d58f2c14"}, ] +[[package]] +name = "camel-converter" +version = "3.0.0" +description = "Converts a string from snake case to camel case or camel case to snake case" +category = "main" +optional = false +python-versions = ">=3.7,<4.0" +files = [ + {file = "camel_converter-3.0.0-py3-none-any.whl", hash = "sha256:4b01725c8ccf918752436a8aab595fa153c5123c147225434bf1f40041acb3c7"}, + {file = "camel_converter-3.0.0.tar.gz", hash = "sha256:7f200e1d1067245f39ae2df6c547dc3de8619060012679702f6471187280e6eb"}, +] + +[package.dependencies] +pydantic = {version = ">=1.8.2", optional = true, markers = "extra == \"pydantic\""} + +[package.extras] +pydantic = ["pydantic (>=1.8.2)"] + [[package]] name = "certifi" version = "2022.12.7" @@ -1514,6 +1532,22 @@ files = [ {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, ] +[[package]] +name = "meilisearch" +version = "0.26.0" +description = "The python client for Meilisearch API." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "meilisearch-0.26.0-py3-none-any.whl", hash = "sha256:e35ad840a42554b575ddb72c6aa5d9f7202ab2a6777fada5752269fe453379f4"}, + {file = "meilisearch-0.26.0.tar.gz", hash = "sha256:597ec5eaf6b726428c0e04ecf05af19379aa097093225dab0c6bfb892d03a90d"}, +] + +[package.dependencies] +camel-converter = {version = "*", extras = ["pydantic"]} +requests = "*" + [[package]] name = "mpmath" version = "1.3.0" @@ -3547,4 +3581,4 @@ testing = ["coverage (>=5.0.3)", "zope.event", "zope.testing"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "df2b6abe8684f976034d432760b5475e48695b2d2557ca4ddf32cf591c5d4475" +content-hash = "8ec52f1fb00d48606fd7784f29649d7b9288c2637c6abdd222d331ed61440d96" diff --git a/pyproject.toml b/pyproject.toml index 271d63c..f9ac3f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ peft = {git = "https://github.com/huggingface/peft.git", rev = "v0.2.0"} duckdb = "^0.7.1" torch = "^2.0.0" bitsandbytes = "^0.37.2" +meilisearch = "^0.26.0" [tool.poetry.group.dev.dependencies] pre-commit = "^3.1.1" diff --git a/tabby/scripts/flags.sh b/tabby/scripts/flags.sh index e638ca5..fb724fe 100644 --- a/tabby/scripts/flags.sh +++ b/tabby/scripts/flags.sh @@ -1 +1,5 @@ -export FLAGS_enable_meilisearch="" +### Experimental feature flags ### +# export FLAGS_enable_meilisearch="1" +# export FLAGS_rewrite_prompt_with_search_snippet="1" + +### Released feature flags ### diff --git a/tabby/server/backend/language_presets.py b/tabby/server/backend/language_presets.py index 04fa844..b6f9408 100644 --- a/tabby/server/backend/language_presets.py +++ b/tabby/server/backend/language_presets.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional, Set from pydantic import BaseModel, Field @@ -8,6 +8,7 @@ from ..models import Language class LanguagePreset(BaseModel): max_length: int stop_words: List[str] + reserved_keywords: Optional[Set] LanguagePresets = { @@ -18,6 +19,45 @@ LanguagePresets = { Language.PYTHON: LanguagePreset( max_length=128, stop_words=["\n\n", "\ndef", "\n#", "\nimport", "\nfrom", "\nclass"], + reserved_keywords=set( + [ + "False", + "class", + "from", + "or", + "None", + "continue", + "global", + "pass", + "True", + "def", + "if", + "raise", + "and", + "del", + "import", + "return", + "as", + "elif", + "in", + "try", + "assert", + "else", + "is", + "while", + "async", + "except", + "lambda", + "with", + "await", + "finally", + "nonlocal", + "yield", + "break", + "for", + "not", + ] + ), ), Language.JAVASCRIPT: LanguagePreset( max_length=128, stop_words=["\n\n", "\nfunction", "\n//", "\nimport", "\nclass"] diff --git a/tabby/server/backend/prompt_rewriter.py b/tabby/server/backend/prompt_rewriter.py new file mode 100644 index 0000000..f36653b --- /dev/null +++ b/tabby/server/backend/prompt_rewriter.py @@ -0,0 +1,74 @@ +import os +import re + +import meilisearch +from loguru import logger + +from .language_presets import LanguagePreset + +FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch", None) + + +class PromptRewriter: + def __init__(self, meili_addr: str = "http://localhost:8084"): + self.meili_client = meilisearch.Client(meili_addr) + + def create_query(self, preset: LanguagePreset, prompt: str): + # Remove all punctuations and create tokens. + tokens = re.sub(r"[^\w\s]", " ", prompt.lower()).split() + + # Remove short tokens. + tokens = [x for x in tokens if len(x) >= 3] + + # Remove tokens in language reserved_keywords. + tokens = set([x for x in tokens if x not in preset.reserved_keywords]) + + if len(tokens) > 3: + return " ".join(tokens) + else: + raise PromptRewriteFailed("Too few tokens extracted from prompt") + + def rewrite(self, preset: LanguagePreset, prompt: str) -> str: + if preset.reserved_keywords is None: + raise PromptRewriteFailed("Rewrite requires language keywords list") + + index = self.meili_client.index("dataset") + query = self.create_query(preset, prompt) + logger.debug("query: {}", query) + search_results = index.search( + query, + { + "limit": 3, + "attributesToCrop": ["content"], + "cropLength": 32, + "cropMarker": "", + "attributesToRetrieve": ["content"], + }, + ) + + if len(search_results["hits"]) == 0: + raise PromptRewriteFailed("No related snippets") + + def make_snippet(i, content): + content = content["_formatted"]["content"] + return f"== snippet {i+1} ==\n{content}" + + snippets = "\n".join( + [make_snippet(i, x) for i, x in enumerate(search_results["hits"])] + ) + prompt = f"""Given following relevant code snippet, generate code completion based on context. +{snippets} +== context == +{prompt}""" + logger.debug("prompt: {}", prompt) + return prompt + + def __call__(self, preset: LanguagePreset, prompt: str) -> str: + try: + return self.rewrite(preset, prompt) + except PromptRewriteFailed: + return prompt + + +class PromptRewriteFailed(Exception): + pass diff --git a/tabby/server/backend/triton.py b/tabby/server/backend/triton.py index 8a4c0fb..fb104b6 100644 --- a/tabby/server/backend/triton.py +++ b/tabby/server/backend/triton.py @@ -1,3 +1,4 @@ +import os import time from typing import List @@ -8,8 +9,13 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype from ..models import Choice, CompletionRequest, CompletionResponse from .language_presets import LanguagePresets +from .prompt_rewriter import PromptRewriter from .utils import random_completion_id, trim_with_stop_words +FLAGS_rewrite_prompt_with_search_snippet = os.environ.get( + "FLAGS_rewrite_prompt_with_search_snippet", None +) + class TritonService: def __init__( @@ -24,6 +30,11 @@ class TritonService: url=f"{host}:{port}", verbose=verbose ) + if FLAGS_rewrite_prompt_with_search_snippet: + self.rewriter = PromptRewriter() + else: + self.rewriter = None + def generate(self, data: CompletionRequest) -> List[Choice]: n = 1 np_type = np.uint32 @@ -31,7 +42,11 @@ class TritonService: preset = LanguagePresets[data.language] - prompt = data.prompt + if self.rewriter: + prompt = self.rewriter(preset, data.prompt) + else: + prompt = data.prompt + input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0) input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type) prompt_len = input_start_ids.shape[1] diff --git a/tabby/server/events.py b/tabby/server/events.py index 2076994..20bde4a 100644 --- a/tabby/server/events.py +++ b/tabby/server/events.py @@ -14,7 +14,6 @@ def setup_logging(logdir): pass # Remove default handler - logger.remove() logger.add( os.path.join(logdir, "events.{time}.log"), rotation="1 hours",