Remove FLAGS_enable_meilisearch and FLAGS_rewrite_prompt_with_search_snippet (#122)
parent
5710994bca
commit
d452488c4b
|
|
@ -61,12 +61,6 @@ RUN <<EOF
|
||||||
rm caddy.tar.gz README.md LICENSE
|
rm caddy.tar.gz README.md LICENSE
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
# Install meilisearch
|
|
||||||
RUN <<EOF
|
|
||||||
curl -L https://install.meilisearch.com | bash
|
|
||||||
mv meilisearch ~/.bin/
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# Setup file permissions
|
# Setup file permissions
|
||||||
USER root
|
USER root
|
||||||
RUN mkdir -p /var/lib/vector
|
RUN mkdir -p /var/lib/vector
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
import os
|
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from utils.service_info import ServiceInfo
|
from utils.service_info import ServiceInfo
|
||||||
from utils.streamlit import set_page_config
|
from utils.streamlit import set_page_config
|
||||||
|
|
@ -11,11 +9,6 @@ SERVICES = [
|
||||||
ServiceInfo(label="dagu", health_url="http://localhost:8083"),
|
ServiceInfo(label="dagu", health_url="http://localhost:8083"),
|
||||||
]
|
]
|
||||||
|
|
||||||
if os.environ.get("FLAGS_enable_meilisearch", False):
|
|
||||||
SERVICES.append(
|
|
||||||
ServiceInfo(label="meilisearch", health_url="http://localhost:8084")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def make_badge_markdown(x: ServiceInfo):
|
def make_badge_markdown(x: ServiceInfo):
|
||||||
return f""
|
return f""
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
### Experimental feature flags ###
|
### Experimental feature flags ###
|
||||||
# export FLAGS_enable_meilisearch="1"
|
|
||||||
# export FLAGS_rewrite_prompt_with_search_snippet="1"
|
|
||||||
|
|
||||||
### Released feature flags ###
|
### Released feature flags ###
|
||||||
|
|
|
||||||
|
|
@ -59,17 +59,6 @@ command=caddy run --config tabby/config/Caddyfile $CADDY_ARGS
|
||||||
EOF
|
EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
program:meilisearch() {
|
|
||||||
local MEILI_DIR="$DATA_DIR/meili"
|
|
||||||
|
|
||||||
if [[ ! -z ${FLAGS_enable_meilisearch} ]]; then
|
|
||||||
cat <<EOF
|
|
||||||
[program:meilisearch]
|
|
||||||
command=meilisearch --http-addr 0.0.0.0:8084 --db-path ${MEILI_DIR}/data.ms --dump-dir ${MEILI_DIR}/dumps/
|
|
||||||
EOF
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
supervisor() {
|
supervisor() {
|
||||||
# Create logs dir if not exists.
|
# Create logs dir if not exists.
|
||||||
mkdir -p ${LOGS_DIR}
|
mkdir -p ${LOGS_DIR}
|
||||||
|
|
@ -97,8 +86,6 @@ command=dagu server --host 0.0.0.0 --port 8083
|
||||||
$(program:triton)
|
$(program:triton)
|
||||||
|
|
||||||
$(program:caddy)
|
$(program:caddy)
|
||||||
|
|
||||||
$(program:meilisearch)
|
|
||||||
EOF
|
EOF
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,74 +0,0 @@
|
||||||
import os
|
|
||||||
import re
|
|
||||||
|
|
||||||
import meilisearch
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
from .language_presets import LanguagePreset
|
|
||||||
|
|
||||||
FLAGS_enable_meilisearch = os.environ.get("FLAGS_enable_meilisearch", None)
|
|
||||||
|
|
||||||
|
|
||||||
class PromptRewriter:
|
|
||||||
def __init__(self, meili_addr: str = "http://localhost:8084"):
|
|
||||||
self.meili_client = meilisearch.Client(meili_addr)
|
|
||||||
|
|
||||||
def create_query(self, preset: LanguagePreset, prompt: str):
|
|
||||||
# Remove all punctuations and create tokens.
|
|
||||||
tokens = re.sub(r"[^\w\s]", " ", prompt.lower()).split()
|
|
||||||
|
|
||||||
# Remove short tokens.
|
|
||||||
tokens = [x for x in tokens if len(x) >= 3]
|
|
||||||
|
|
||||||
# Remove tokens in language reserved_keywords.
|
|
||||||
tokens = set([x for x in tokens if x not in preset.reserved_keywords])
|
|
||||||
|
|
||||||
if len(tokens) > 3:
|
|
||||||
return " ".join(tokens)
|
|
||||||
else:
|
|
||||||
raise PromptRewriteFailed("Too few tokens extracted from prompt")
|
|
||||||
|
|
||||||
def rewrite(self, preset: LanguagePreset, prompt: str) -> str:
|
|
||||||
if preset.reserved_keywords is None:
|
|
||||||
raise PromptRewriteFailed("Rewrite requires language keywords list")
|
|
||||||
|
|
||||||
index = self.meili_client.index("dataset")
|
|
||||||
query = self.create_query(preset, prompt)
|
|
||||||
logger.debug("query: {}", query)
|
|
||||||
search_results = index.search(
|
|
||||||
query,
|
|
||||||
{
|
|
||||||
"limit": 3,
|
|
||||||
"attributesToCrop": ["content"],
|
|
||||||
"cropLength": 32,
|
|
||||||
"cropMarker": "",
|
|
||||||
"attributesToRetrieve": ["content"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(search_results["hits"]) == 0:
|
|
||||||
raise PromptRewriteFailed("No related snippets")
|
|
||||||
|
|
||||||
def make_snippet(i, content):
|
|
||||||
content = content["_formatted"]["content"]
|
|
||||||
return f"== snippet {i+1} ==\n{content}"
|
|
||||||
|
|
||||||
snippets = "\n".join(
|
|
||||||
[make_snippet(i, x) for i, x in enumerate(search_results["hits"])]
|
|
||||||
)
|
|
||||||
prompt = f"""Given following relevant code snippet, generate code completion based on context.
|
|
||||||
{snippets}
|
|
||||||
== context ==
|
|
||||||
{prompt}"""
|
|
||||||
logger.debug("prompt: {}", prompt)
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
def __call__(self, preset: LanguagePreset, prompt: str) -> str:
|
|
||||||
try:
|
|
||||||
return self.rewrite(preset, prompt)
|
|
||||||
except PromptRewriteFailed:
|
|
||||||
return prompt
|
|
||||||
|
|
||||||
|
|
||||||
class PromptRewriteFailed(Exception):
|
|
||||||
pass
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
|
@ -9,13 +8,8 @@ from tritonclient.utils import InferenceServerException, np_to_triton_dtype
|
||||||
|
|
||||||
from ..models import Choice, CompletionRequest, CompletionResponse
|
from ..models import Choice, CompletionRequest, CompletionResponse
|
||||||
from .language_presets import LanguagePresets
|
from .language_presets import LanguagePresets
|
||||||
from .prompt_rewriter import PromptRewriter
|
|
||||||
from .utils import random_completion_id, trim_with_stop_words
|
from .utils import random_completion_id, trim_with_stop_words
|
||||||
|
|
||||||
FLAGS_rewrite_prompt_with_search_snippet = os.environ.get(
|
|
||||||
"FLAGS_rewrite_prompt_with_search_snippet", None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TritonService:
|
class TritonService:
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -30,11 +24,6 @@ class TritonService:
|
||||||
url=f"{host}:{port}", verbose=verbose
|
url=f"{host}:{port}", verbose=verbose
|
||||||
)
|
)
|
||||||
|
|
||||||
if FLAGS_rewrite_prompt_with_search_snippet:
|
|
||||||
self.rewriter = PromptRewriter()
|
|
||||||
else:
|
|
||||||
self.rewriter = None
|
|
||||||
|
|
||||||
def generate(self, data: CompletionRequest) -> List[Choice]:
|
def generate(self, data: CompletionRequest) -> List[Choice]:
|
||||||
n = 1
|
n = 1
|
||||||
np_type = np.uint32
|
np_type = np.uint32
|
||||||
|
|
@ -44,10 +33,7 @@ class TritonService:
|
||||||
if preset is None:
|
if preset is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
if self.rewriter:
|
prompt = data.prompt
|
||||||
prompt = self.rewriter(preset, data.prompt)
|
|
||||||
else:
|
|
||||||
prompt = data.prompt
|
|
||||||
|
|
||||||
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
input_start_ids = np.expand_dims(self.tokenizer.encode(prompt), 0)
|
||||||
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
|
input_start_ids = np.repeat(input_start_ids, n, axis=0).astype(np_type)
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ env:
|
||||||
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
|
- GIT_REPOSITORIES_DIR: "$GIT_REPOSITORIES_DIR"
|
||||||
- DATASET_DIR: "$DATASET_DIR"
|
- DATASET_DIR: "$DATASET_DIR"
|
||||||
- HOME: "$HOME"
|
- HOME: "$HOME"
|
||||||
- FLAGS_enable_meilisearch: "$FLAGS_enable_meilisearch"
|
|
||||||
steps:
|
steps:
|
||||||
- name: update repositories
|
- name: update repositories
|
||||||
dir: $APP_DIR
|
dir: $APP_DIR
|
||||||
|
|
@ -18,14 +17,3 @@ steps:
|
||||||
command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
|
command: python -m tabby.tools.build_dataset --project_dir=$GIT_REPOSITORIES_DIR --output_dir=$DATASET_DIR
|
||||||
depends:
|
depends:
|
||||||
- update repositories
|
- update repositories
|
||||||
|
|
||||||
- name: refresh index
|
|
||||||
dir: $APP_DIR
|
|
||||||
preconditions:
|
|
||||||
- condition: "$FLAGS_enable_meilisearch"
|
|
||||||
expected: "1"
|
|
||||||
depends:
|
|
||||||
- generate dataset
|
|
||||||
command: |
|
|
||||||
curl -X DELETE 'http://localhost:8084/indexes/dataset/documents'
|
|
||||||
curl -X POST 'http://localhost:8084/indexes/dataset/documents?primaryKey=id' -H 'Content-Type: application/x-ndjson' --data-binary @$DATASET_DIR/dumps.json
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue