Make language field optional in rust implementation (#164)
* remove download_model.py as we have tabby serve now * Make language field optionaladd-prefix-suffix
parent
a9d74f7a35
commit
0d11b0e832
|
|
@ -1,31 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import HfArgumentParser
|
||||
|
||||
|
||||
@dataclass
|
||||
class Arguments:
|
||||
repo_id: str = field(
|
||||
metadata={"help": "Huggingface model repository id, e.g TabbyML/NeoX-160M"}
|
||||
)
|
||||
device: str = field(metadata={"help": "Device type for inference: cpu / cuda"})
|
||||
output_dir: str = field(metadata={"help": "Output directory"})
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = HfArgumentParser(Arguments)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
print(f"Loading {args.repo_id}, this will take a while...")
|
||||
snapshot_download(
|
||||
local_dir=args.output_dir,
|
||||
repo_id=args.repo_id,
|
||||
allow_patterns=[f"ctranslate2/{args.device}/*", "tokenizer.json"],
|
||||
)
|
||||
print(f"Loaded {args.repo_id} !")
|
||||
|
|
@ -13,7 +13,7 @@ mod languages;
|
|||
pub struct CompletionRequest {
|
||||
/// https://code.visualstudio.com/docs/languages/identifiers
|
||||
#[schema(example = "python")]
|
||||
language: String,
|
||||
language: Option<String>,
|
||||
|
||||
#[schema(example = "def fib(n):")]
|
||||
prompt: String,
|
||||
|
|
@ -47,7 +47,8 @@ pub async fn completion(
|
|||
.build()
|
||||
.unwrap();
|
||||
let text = state.engine.inference(&request.prompt, options);
|
||||
let filtered_text = languages::remove_stop_words(&request.language, &text);
|
||||
let language = request.language.unwrap_or("unknown".into());
|
||||
let filtered_text = languages::remove_stop_words(&language, &text);
|
||||
|
||||
Json(CompletionResponse {
|
||||
id: format!("cmpl-{}", uuid::Uuid::new_v4()),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,10 @@ lazy_static! {
|
|||
static ref DEFAULT: Regex = Regex::new(r"(?m)^\n\n").unwrap();
|
||||
static ref LANGUAGES: HashMap<&'static str, Regex> = {
|
||||
let mut map = HashMap::new();
|
||||
map.insert(
|
||||
"unknown",
|
||||
Regex::new(r"(?m)^(\n\n)").unwrap(),
|
||||
);
|
||||
map.insert(
|
||||
"python",
|
||||
Regex::new(r"(?m)^(\n\n|def|#|from|class)").unwrap(),
|
||||
|
|
|
|||
Loading…
Reference in New Issue