Remove FLAGS_enable_meilisearch and FLAGS_rewrite_prompt_with_search_snippet (#122)
parent
5710994bca
commit
d452488c4b
|
|
@ -61,12 +61,6 @@ RUN <<EOF
|
|||
rm caddy.tar.gz README.md LICENSE
|
||||
EOF
|
||||
|
||||
# Install meilisearch
|
||||
RUN <<EOF
|
||||
curl -L https://install.meilisearch.com | bash
|
||||
mv meilisearch ~/.bin/
|
||||
EOF
|
||||
|
||||
# Setup file permissions
|
||||
USER root
|
||||
RUN mkdir -p /var/lib/vector
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
import os
|
||||
|
||||
import streamlit as st
|
||||
from utils.service_info import ServiceInfo
|
||||
from utils.streamlit import set_page_config
|
||||
|
|
@ -11,11 +9,6 @@ SERVICES = [
|
|||
ServiceInfo(label="dagu", health_url="http://localhost:8083"),
|
||||
]
|
||||
|
||||
if os.environ.get("FLAGS_enable_meilisearch", False):
|
||||
SERVICES.append(
|
||||
ServiceInfo(label="meilisearch", health_url="http://localhost:8084")
|
||||
)
|
||||
|
||||
|
||||
def make_badge_markdown(x: ServiceInfo):
|
||||
return f""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,3 @@
|
|||
### Experimental feature flags ###
|
||||
# export FLAGS_enable_meilisearch="1"
|
||||
# export FLAGS_rewrite_prompt_with_search_snippet="1"
|
||||
|
||||
### Released feature flags ###
|
||||
|
|
|
|||
|
|
@ -59,17 +59,6 @@ command=caddy run --config tabby/config/Caddyfile $CADDY_ARGS
|
|||
EOF
|
||||
}
|
||||
|
||||
# Print the "[program:meilisearch]" supervisor section on stdout.
#
# Emits nothing unless FLAGS_enable_meilisearch is set to a non-empty value.
# Reads DATA_DIR; the meilisearch database and dumps live under
# "${DATA_DIR}/meili".
program:meilisearch() {
  local MEILI_DIR="${DATA_DIR}/meili"

  # ":-" keeps the test safe when the flag is unset and the calling script
  # runs with `set -u`; -n is the direct form of the former "! -z" check.
  if [[ -n "${FLAGS_enable_meilisearch:-}" ]]; then
    # The heredoc body and terminator stay at column 0 so the emitted
    # config lines carry no leading whitespace.
    cat <<EOF
[program:meilisearch]
command=meilisearch --http-addr 0.0.0.0:8084 --db-path ${MEILI_DIR}/data.ms --dump-dir ${MEILI_DIR}/dumps/
EOF
  fi
}
|
||||
|
||||
supervisor() {
|
||||
# Create logs dir if not exists.
|
||||
mkdir -p ${LOGS_DIR}
|
||||
|
|
@ -97,8 +86,6 @@ command=dagu server --host 0.0.0.0 --port 8083
|
|||
$(program:triton)
|
||||
|
||||
$(program:caddy)
|
||||
|
||||
$(program:meilisearch)
|
||||
EOF
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,74 +0,0 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
import meilisearch
|
||||
from loguru import logger
|
||||
|
||||
from .language_presets import LanguagePreset
|
||||
|
||||
FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch", None)
|
||||
|
||||
|
||||
class PromptRewriter:
    """Prefix a completion prompt with related code snippets.

    Snippets are fetched from the ``dataset`` index of a Meilisearch server,
    using keywords extracted from the prompt as the search query.
    """

    def __init__(self, meili_addr: str = "http://localhost:8084"):
        """Create a rewriter that queries the Meilisearch server at *meili_addr*."""
        self.meili_client = meilisearch.Client(meili_addr)

    def create_query(self, preset: LanguagePreset, prompt: str):
        """Build a space-separated keyword query from *prompt*.

        The prompt is lower-cased, punctuation is replaced by spaces, and
        tokens shorter than three characters or listed in
        ``preset.reserved_keywords`` are dropped. Duplicates are removed.

        Raises:
            PromptRewriteFailed: if three or fewer distinct tokens remain.
        """
        words = re.sub(r"[^\w\s]", " ", prompt.lower()).split()

        # dict.fromkeys de-duplicates while keeping first-seen order. A set
        # would make the token order (and therefore the query) vary between
        # interpreter runs because of string hash randomization.
        keywords = list(
            dict.fromkeys(
                word
                for word in words
                if len(word) >= 3 and word not in preset.reserved_keywords
            )
        )

        if len(keywords) <= 3:
            raise PromptRewriteFailed("Too few tokens extracted from prompt")
        return " ".join(keywords)

    def rewrite(self, preset: LanguagePreset, prompt: str) -> str:
        """Return *prompt* prefixed with up to three related snippets.

        Raises:
            PromptRewriteFailed: if the preset has no reserved keyword list,
                too few query tokens can be extracted, or the search returns
                no hits.
        """
        if preset.reserved_keywords is None:
            raise PromptRewriteFailed("Rewrite requires language keywords list")

        index = self.meili_client.index("dataset")
        query = self.create_query(preset, prompt)
        logger.debug("query: {}", query)

        search_results = index.search(
            query,
            {
                "limit": 3,
                "attributesToCrop": ["content"],
                "cropLength": 32,
                "cropMarker": "",
                "attributesToRetrieve": ["content"],
            },
        )

        hits = search_results["hits"]
        if not hits:
            raise PromptRewriteFailed("No related snippets")

        # Each hit's cropped text is read from its "_formatted" entry.
        snippets = "\n".join(
            f"== snippet {number} ==\n{hit['_formatted']['content']}"
            for number, hit in enumerate(hits, start=1)
        )

        rewritten = (
            "Given following relevant code snippet, generate code completion based on context.\n"
            f"{snippets}\n"
            "== context ==\n"
            f"{prompt}"
        )
        logger.debug("prompt: {}", rewritten)
        return rewritten

    def __call__(self, preset: LanguagePreset, prompt: str) -> str:
        """Rewrite *prompt*, falling back to the original prompt on failure.

        Only PromptRewriteFailed is handled; any other exception propagates.
        """
        try:
            return self.rewrite(preset, prompt)
        except PromptRewriteFailed:
            return prompt
|
||||
|
||||
|
||||
class PromptRewriteFailed(Exception):
    """Signals that a prompt could not be rewritten with search snippets."""
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
|
|
@ -9,13 +8,8 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
|||
|
||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||
from .language_presets import LanguagePresets
|
||||
from .prompt_rewriter import PromptRewriter
|
||||
from .utils import random_completion_id, trim_with_stop_words
|
||||
|
||||
FLAGS_rewrite_prompt_with_search_snippet = os.environ.get(
|
||||
"FLAGS_rewrite_prompt_with_search_snippet", None
|
||||
)
|
||||
|
||||
|
||||
class TritonService:
|
||||
def __init__(
|
||||
|
|
@ -30,11 +24,6 @@ class TritonService:
|
|||
url=f"{host}:{port}", verbose=verbose
|
||||
)
|
||||
|
||||
if FLAGS_rewrite_prompt_with_search_snippet:
|
||||
self.rewriter = PromptRewriter()
|
||||
else:
|
||||
self.rewriter = None
|
||||
|
||||
def generate(self, data: CompletionRequest) -> List[Choice]:
|
||||
n = 1
|
||||
np_type = np.uint32
|
||||
|
|
@ -44,10 +33,7 @@ class TritonService:
|
|||
if preset is None:
|
||||
return []
|
||||
|
||||
if self.rewriter:
|
||||
prompt = self.rewriter(preset, data.prompt)
|
||||
else:
|
||||
prompt = data.prompt
|
||||
prompt = data.prompt
|
||||
|
||||
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
||||
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ env:
|
|||
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
|
||||
- DATASET_DIR: "$DATASET_DIR"
|
||||
- HOME: "$HOME"
|
||||
- FLAGS_enable_meilisearch: "$FLAGS_enable_meilisearch"
|
||||
steps:
|
||||
- name: update repositories
|
||||
dir: $APP_DIR
|
||||
|
|
@ -18,14 +17,3 @@ steps:
|
|||
command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
|
||||
depends:
|
||||
- update repositories
|
||||
|
||||
- name: refresh index
|
||||
dir: $APP_DIR
|
||||
preconditions:
|
||||
- condition: "$FLAGS_enable_meilisearch"
|
||||
expected: "1"
|
||||
depends:
|
||||
- generate dataset
|
||||
command: |
|
||||
curl -X DELETE 'http://localhost:8084/indexes/dataset/documents'
|
||||
curl -X POST 'http://localhost:8084/indexes/dataset/documents?primaryKey=id' -H 'Content-Type: application/x-ndjson' --data-binary @$DATASET_DIR/dumps.json
|
||||
|
|
|
|||
Loading…
Reference in New Issue