Make language field optional in rust implementation (#164)

* remove download_model.py as we have tabby serve now

* Make language field optional
add-prefix-suffix
Meng Zhang 2023-05-29 16:58:02 -07:00 committed by GitHub
parent a9d74f7a35
commit 0d11b0e832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 33 deletions

View File

@ -1,31 +0,0 @@
#!/usr/bin/env python3
from dataclasses import dataclass, field
from huggingface_hub import snapshot_download
from transformers import HfArgumentParser
@dataclass
class Arguments:
repo_id: str = field(
metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"}
)
device: str = field(metadata={"help": "Device type for inference: cpu / cuda"})
output_dir: str = field(metadata={"help": "Output directory"})
def parse_args():
parser = HfArgumentParser(Arguments)
return parser.parse_args()
if __name__ == "__main__":
args = parse_args()
print(f"Loading {args.repo_id}, this will take a while...")
snapshot_download(
local_dir=args.output_dir,
repo_id=args.repo_id,
allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"],
)
print(f"Loaded {args.repo_id} !")

View File

@ -13,7 +13,7 @@ mod languages;
pub struct CompletionRequest {
/// https://code.visualstudio.com/docs/languages/identifiers
#[schema(example = "python")]
language: String,
language: Option<String>,
#[schema(example = "def fib(n):")]
prompt: String,
@ -47,7 +47,8 @@ pub async fn completion(
.build()
.unwrap();
let text = state.engine.inference(&request.prompt, options);
let filtered_text = languages::remove_stop_words(&request.language, &text);
let language = request.language.unwrap_or("unknown".into());
let filtered_text = languages::remove_stop_words(&language, &text);
Json(CompletionResponse {
id: format!("cmpl-{}", uuid::Uuid::new_v4()),

View File

@ -6,6 +6,10 @@ lazy_static! {
static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
static ref LANGUAGES: HashMap<&'static str, Regex> = {
let mut map = HashMap::new();
map.insert(
"unknown",
Regex::new(r"(?m)^(\n\n)").unwrap(),
);
map.insert(
"python",
Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),