From 0d11b0e8327f0af19070ddd89f296c053395b8b6 Mon Sep 17 00:00:00 2001 From: Meng Zhang Date: Mon, 29 May 2023 16:58:02 -0700 Subject: [PATCH] Make language field optional in rust implementation (#164) * remove download_model.py as we have tabby serve now * Make language field optional --- crates/tabby/python/download_model.py | 31 ------------------- crates/tabby/src/serve/completions.rs | 5 +-- .../tabby/src/serve/completions/languages.rs | 4 +++ 3 files changed, 7 insertions(+), 33 deletions(-) delete mode 100755 crates/tabby/python/download_model.py diff --git a/crates/tabby/python/download_model.py b/crates/tabby/python/download_model.py deleted file mode 100755 index f8ca744..0000000 --- a/crates/tabby/python/download_model.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python3 - -from dataclasses import dataclass, field - -from huggingface_hub import snapshot_download -from transformers import HfArgumentParser - - -@dataclass -class Arguments: - repo_id: str = field( - metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"} - ) - device: str = field(metadata={"help": "Device type for inference: cpu / cuda"}) - output_dir: str = field(metadata={"help": "Output directory"}) - - -def parse_args(): - parser = HfArgumentParser(Arguments) - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - print(f"Loading {args.repo_id}, this will take a while...") - snapshot_download( - local_dir=args.output_dir, - repo_id=args.repo_id, - allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"], - ) - print(f"Loaded {args.repo_id} !") diff --git a/crates/tabby/src/serve/completions.rs b/crates/tabby/src/serve/completions.rs index 3f973d9..f634923 100644 --- a/crates/tabby/src/serve/completions.rs +++ b/crates/tabby/src/serve/completions.rs @@ -13,7 +13,7 @@ mod languages; pub struct CompletionRequest { /// https://code.visualstudio.com/docs/languages/identifiers #[schema(example = "python")] - language: String, + language: Option, #[schema(example = "def fib(n):")] prompt: String, @@ -47,7 +47,8 @@ pub async fn completion( .build() .unwrap(); let text = state.engine.inference(&request.prompt, options); - let filtered_text = languages::remove_stop_words(&request.language, &text); + let language = request.language.unwrap_or("unknown".into()); + let filtered_text = languages::remove_stop_words(&language, &text); Json(CompletionResponse { id: format!("cmpl-{}", uuid::Uuid::new_v4()), diff --git a/crates/tabby/src/serve/completions/languages.rs b/crates/tabby/src/serve/completions/languages.rs index f9cab10..48628da 100644 --- a/crates/tabby/src/serve/completions/languages.rs +++ b/crates/tabby/src/serve/completions/languages.rs @@ -6,6 +6,10 @@ lazy_static! { static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap(); static ref LANGUAGES: HashMap<&'static str, Regex> = { let mut map = HashMap::new(); + map.insert( + "unknown", + Regex::new(r"(?m)^(\n\n)").unwrap(), + ); map.insert( "python", Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),