Make language field optional in rust implementation (#164)
* Remove download_model.py, as we have `tabby serve` now
* Make language field optional
add-prefix-suffix
parent
a9d74f7a35
commit
0d11b0e832
|
|
@ -1,31 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from transformers import HfArgumentParser
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Arguments:
    """Command-line options for the model download script.

    Every field is required (none declares a default) and carries its
    ``--help`` text in ``metadata["help"]``.
    """

    # Repository to download from.
    repo_id: str = field(
        metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"}
    )
    # Device name; the script uses it to build the ``ctranslate2/<device>/*``
    # download pattern.
    device: str = field(metadata={"help": "Device type for inference: cpu / cuda"})
    # Directory the downloaded files are written to.
    output_dir: str = field(metadata={"help": "Output directory"})
|
|
||||||
|
|
||||||
|
|
||||||
def parse_args():
    """Parse the process command line using the options declared on ``Arguments``.

    Returns:
        The object produced by ``HfArgumentParser.parse_args()``; the script
        reads ``repo_id``, ``device`` and ``output_dir`` from it.
    """
    parser = HfArgumentParser(Arguments)
    return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Script entry point: read the CLI options, then download the snapshot.
    args = parse_args()
    print(f"Loading {args.repo_id}, this will take a while...")
    # Only files matching the device-specific ctranslate2 directory and the
    # tokenizer file are requested, via ``allow_patterns``.
    snapshot_download(
        local_dir=args.output_dir,
        repo_id=args.repo_id,
        allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"],
    )
    print(f"Loaded {args.repo_id} !")
|
|
||||||
|
|
@ -13,7 +13,7 @@ mod languages;
|
||||||
pub struct CompletionRequest {
|
pub struct CompletionRequest {
|
||||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||||
#[schema(example = "python")]
|
#[schema(example = "python")]
|
||||||
language: String,
|
language: Option<String>,
|
||||||
|
|
||||||
#[schema(example = "def fib(n):")]
|
#[schema(example = "def fib(n):")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
@ -47,7 +47,8 @@ pub async fn completion(
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let text = state.engine.inference(&request.prompt, options);
|
let text = state.engine.inference(&request.prompt, options);
|
||||||
let filtered_text = languages::remove_stop_words(&request.language, &text);
|
let language = request.language.unwrap_or("unknown".into());
|
||||||
|
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||||
|
|
||||||
Json(CompletionResponse {
|
Json(CompletionResponse {
|
||||||
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
|
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
|
||||||
|
|
|
||||||
|
|
@ -6,6 +6,10 @@ lazy_static! {
|
||||||
static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
|
static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
|
||||||
static ref LANGUAGES: HashMap<&'static str, Regex> = {
|
static ref LANGUAGES: HashMap<&'static str, Regex> = {
|
||||||
let mut map = HashMap::new();
|
let mut map = HashMap::new();
|
||||||
|
map.insert(
|
||||||
|
"unknown",
|
||||||
|
Regex::new(r"(?m)^(\n\n)").unwrap(),
|
||||||
|
);
|
||||||
map.insert(
|
map.insert(
|
||||||
"python",
|
"python",
|
||||||
Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),
|
Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue