Remove FLAGS_enable_meilisearch and FLAGS_rewrite_prompt_with_search_snippet (#122)

add-tracing
Meng Zhang 2023-05-01 15:06:06 +08:00 committed by GitHub
parent 5710994bca
commit d452488c4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 1 additions and 129 deletions

View File

@ -61,12 +61,6 @@ RUN <<EOF
rm caddy.tar.gz README.md LICENSE
EOF
# Install meilisearch
RUN <<EOF
curl -L https://install.meilisearch.com | bash
mv meilisearch ~/.bin/
EOF
# Setup file permissions
USER root
RUN mkdir -p /var/lib/vector

View File

@ -1,5 +1,3 @@
import os
import streamlit as st
from utils.service_info import ServiceInfo
from utils.streamlit import set_page_config
@ -11,11 +9,6 @@ SERVICES = [
ServiceInfo(label="dagu", health_url="http://localhost:8083"),
]
if os.environ.get("FLAGS_enable_meilisearch", False):
SERVICES.append(
ServiceInfo(label="meilisearch", health_url="http://localhost:8084")
)
def make_badge_markdown(x: ServiceInfo):
return f"![{x.label}]({x.badge_url})"

View File

@ -1,5 +1,3 @@
### Experimental feature flags ###
# export FLAGS_enable_meilisearch="1"
# export FLAGS_rewrite_prompt_with_search_snippet="1"
### Released feature flags ###

View File

@ -59,17 +59,6 @@ command=caddy run --config tabby/config/Caddyfile $CADDY_ARGS
EOF
}
program:meilisearch() {
local MEILI_DIR="$DATA_DIR/meili"
if [[ ! -z ${FLAGS_enable_meilisearch} ]]; then
cat <<EOF
[program:meilisearch]
command=meilisearch --http-addr 0.0.0.0:8084 --db-path ${MEILI_DIR}/data.ms --dump-dir ${MEILI_DIR}/dumps/
EOF
fi
}
supervisor() {
# Create logs dir if not exists.
mkdir -p ${LOGS_DIR}
@ -97,8 +86,6 @@ command=dagu server --host 0.0.0.0 --port 8083
$(program:triton)
$(program:caddy)
$(program:meilisearch)
EOF
)
}

View File

@ -1,74 +0,0 @@
import os
import re
import meilisearch
from loguru import logger
from .language_presets import LanguagePreset
FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch", None)
class PromptRewriter:
def __init__(self, meili_addr: str = "http://localhost:8084"):
self.meili_client = meilisearch.Client(meili_addr)
def create_query(self, preset: LanguagePreset, prompt: str):
# Remove all punctuations and create tokens.
tokens = re.sub(r"[^\w\s]", " ", prompt.lower()).split()
# Remove short tokens.
tokens = [x for x in tokens if len(x) >= 3]
# Remove tokens in language reserved_keywords.
tokens = set([x for x in tokens if x not in preset.reserved_keywords])
if len(tokens) > 3:
return " ".join(tokens)
else:
raise PromptRewriteFailed("Too few tokens extracted from prompt")
def rewrite(self, preset: LanguagePreset, prompt: str) -> str:
if preset.reserved_keywords is None:
raise PromptRewriteFailed("Rewrite requires language keywords list")
index = self.meili_client.index("dataset")
query = self.create_query(preset, prompt)
logger.debug("query: {}", query)
search_results = index.search(
query,
{
"limit": 3,
"attributesToCrop": ["content"],
"cropLength": 32,
"cropMarker": "",
"attributesToRetrieve": ["content"],
},
)
if len(search_results["hits"]) == 0:
raise PromptRewriteFailed("No related snippets")
def make_snippet(i, content):
content = content["_formatted"]["content"]
return f"== snippet {i+1} ==\n{content}"
snippets = "\n".join(
[make_snippet(i, x) for i, x in enumerate(search_results["hits"])]
)
prompt = f"""Given following relevant code snippet, generate code completion based on context.
{snippets}
== context ==
{prompt}"""
logger.debug("prompt: {}", prompt)
return prompt
def __call__(self, preset: LanguagePreset, prompt: str) -> str:
try:
return self.rewrite(preset, prompt)
except PromptRewriteFailed:
return prompt
class PromptRewriteFailed(Exception):
pass

View File

@ -1,4 +1,3 @@
import os
import time
from typing import List
@ -9,13 +8,8 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
from ..models import Choice, CompletionRequest, CompletionResponse
from .language_presets import LanguagePresets
from .prompt_rewriter import PromptRewriter
from .utils import random_completion_id, trim_with_stop_words
FLAGS_rewrite_prompt_with_search_snippet = os.environ.get(
"FLAGS_rewrite_prompt_with_search_snippet", None
)
class TritonService:
def __init__(
@ -30,11 +24,6 @@ class TritonService:
url=f"{host}:{port}", verbose=verbose
)
if FLAGS_rewrite_prompt_with_search_snippet:
self.rewriter = PromptRewriter()
else:
self.rewriter = None
def generate(self, data: CompletionRequest) -> List[Choice]:
n = 1
np_type = np.uint32
@ -44,10 +33,7 @@ class TritonService:
if preset is None:
return []
if self.rewriter:
prompt = self.rewriter(preset, data.prompt)
else:
prompt = data.prompt
prompt = data.prompt
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)

View File

@ -7,7 +7,6 @@ env:
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
- DATASET_DIR: "$DATASET_DIR"
- HOME: "$HOME"
- FLAGS_enable_meilisearch: "$FLAGS_enable_meilisearch"
steps:
- name: update repositories
dir: $APP_DIR
@ -18,14 +17,3 @@ steps:
command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
depends:
- update repositories
- name: refresh index
dir: $APP_DIR
preconditions:
- condition: "$FLAGS_enable_meilisearch"
expected: "1"
depends:
- generate dataset
command: |
curl -X DELETE 'http://localhost:8084/indexes/dataset/documents'
curl -X POST 'http://localhost:8084/indexes/dataset/documents?primaryKey=id' -H 'Content-Type: application/x-ndjson' --data-binary @$DATASET_DIR/dumps.json