diff --git a/crates/llama-cpp-bindings/include/engine.h b/crates/llama-cpp-bindings/include/engine.h
index fa8743c..7754430 100644
--- a/crates/llama-cpp-bindings/include/engine.h
+++ b/crates/llama-cpp-bindings/include/engine.h
@@ -15,5 +15,10 @@ class TextInferenceEngine {
   virtual rust::Vec<StepOutput> step() = 0;
 };
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
+std::unique_ptr<TextInferenceEngine> create_engine(
+    bool use_gpu,
+    rust::Str model_path,
+    uint8_t parallelism,
+    bool enable_prompt_lookup
+);
 } // namespace
diff --git a/crates/llama-cpp-bindings/src/engine.cc b/crates/llama-cpp-bindings/src/engine.cc
index 9a9b7b6..d32fc8a 100644
--- a/crates/llama-cpp-bindings/src/engine.cc
+++ b/crates/llama-cpp-bindings/src/engine.cc
@@ -17,8 +17,8 @@ namespace {
 constexpr size_t N_BATCH = 512;  // # per batch inference.
 constexpr size_t N_CTX = 4096;   // # max kv history.
 
-constexpr size_t DRAFT_N_GRAM_SIZE = 3;
-constexpr size_t DRAFT_N_PRED_TOKENS = 10;
+constexpr int DRAFT_N_GRAM_SIZE = 3;
+constexpr int DRAFT_N_PRED_TOKENS = 10;
 
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) :
@@ -142,10 +142,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
 
 class TextInferenceEngineImpl : public TextInferenceEngine {
  public:
-  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
+  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism, bool enable_prompt_lookup) :
     model_(std::move(model)),
     ctx_(std::move(ctx)),
-    parallelism_(parallelism) {
+    parallelism_(parallelism),
+    enable_prompt_lookup_(enable_prompt_lookup) {
     batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
     // warm up
     {
@@ -231,8 +232,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       const size_t n_tokens = batch_.n_tokens;
 
       // Ensure the draft logits always fall into the same batch.
-      const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
-      request.draft_tokens(n_draft_quota);
+      if (enable_prompt_lookup_) {
+        const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
+        request.draft_tokens(n_draft_quota);
+      }
 
       for (size_t i = 0; i < request.tokens.size(); ++i) {
         batch_.token[n_tokens + i] = request.tokens[i];
@@ -347,6 +350,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   std::unordered_set<uint32_t> stopped_requests_;
 
   uint32_t parallelism_;
+  bool enable_prompt_lookup_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -374,7 +378,12 @@ struct BackendInitializer {
 
 } // namespace
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
+std::unique_ptr<TextInferenceEngine> create_engine(
+    bool use_gpu,
+    rust::Str model_path,
+    uint8_t parallelism,
+    bool enable_prompt_lookup
+) {
   static BackendInitializer initializer;
 
   llama_model_params model_params = llama_model_default_params();
@@ -397,7 +406,8 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
   return std::make_unique<TextInferenceEngineImpl>(
       owned<llama_model>(model, llama_free_model),
       owned<llama_context>(ctx, llama_free),
-      parallelism
+      parallelism,
+      enable_prompt_lookup
   );
 }
 
diff --git a/crates/llama-cpp-bindings/src/lib.rs b/crates/llama-cpp-bindings/src/lib.rs
index 0dbdaf9..ac6a3e8 100644
--- a/crates/llama-cpp-bindings/src/lib.rs
+++ b/crates/llama-cpp-bindings/src/lib.rs
@@ -27,6 +27,7 @@ mod ffi {
             use_gpu: bool,
             model_path: &str,
             parallelism: u8,
+            enable_prompt_lookup: bool,
         ) -> UniquePtr<TextInferenceEngine>;
 
         fn add_request(
@@ -48,6 +49,7 @@ pub struct LlamaTextGenerationOptions {
     model_path: String,
     use_gpu: bool,
     parallelism: u8,
+    enable_prompt_lookup: bool,
 }
 
 pub struct LlamaTextGeneration {
@@ -57,7 +59,7 @@
 
 impl LlamaTextGeneration {
     pub fn new(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
+        let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism, options.enable_prompt_lookup);
         if engine.is_null() {
             fatal!("Unable to load model: {}", options.model_path);
         }
diff --git a/crates/tabby/src/services/chat.rs b/crates/tabby/src/services/chat.rs
index e81096d..0d4003d 100644
--- a/crates/tabby/src/services/chat.rs
+++ b/crates/tabby/src/services/chat.rs
@@ -77,7 +77,7 @@ impl ChatService {
 
 pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
     let (engine, model::PromptInfo { chat_template, .. }) =
-        model::load_text_generation(model, device, parallelism).await;
+        model::load_text_generation(model, device, parallelism, true).await;
 
     let Some(chat_template) = chat_template else {
         fatal!("Chat model requires specifying prompt template");
diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs
index ecdc2df..72f4e30 100644
--- a/crates/tabby/src/services/completion.rs
+++ b/crates/tabby/src/services/completion.rs
@@ -281,7 +281,7 @@ pub async fn create_completion_service(
         model::PromptInfo {
             prompt_template, ..
         },
-    ) = model::load_text_generation(model, device, parallelism).await;
+    ) = model::load_text_generation(model, device, parallelism, false).await;
 
     CompletionService::new(engine.clone(), code, logger, prompt_template)
 }
diff --git a/crates/tabby/src/services/model.rs b/crates/tabby/src/services/model.rs
index 8c5bd48..9766e23 100644
--- a/crates/tabby/src/services/model.rs
+++ b/crates/tabby/src/services/model.rs
@@ -12,6 +12,7 @@ pub async fn load_text_generation(
     model_id: &str,
     device: &Device,
     parallelism: u8,
+    enable_prompt_lookup: bool,
 ) -> (Arc<dyn TextGeneration>, PromptInfo) {
     #[cfg(feature = "experimental-http")]
     if device == &Device::ExperimentalHttp {
@@ -28,19 +29,25 @@ pub async fn load_text_generation(
     if fs::metadata(model_id).is_ok() {
         let path = PathBuf::from(model_id);
         let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
+        let engine_info = PromptInfo::read(path.join("tabby.json"));
         let engine = create_ggml_engine(
             device,
             model_path.display().to_string().as_str(),
             parallelism,
+            enable_prompt_lookup,
         );
-        let engine_info = PromptInfo::read(path.join("tabby.json"));
         (Arc::new(engine), engine_info)
     } else {
         let (registry, name) = parse_model_id(model_id);
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(device, &model_path, parallelism);
+        let engine = create_ggml_engine(
+            device,
+            &model_path,
+            parallelism,
+            enable_prompt_lookup,
+        );
         (
             Arc::new(engine),
             PromptInfo {
@@ -64,11 +71,17 @@ impl PromptInfo {
     }
 }
 
-fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
+fn create_ggml_engine(
+    device: &Device,
+    model_path: &str,
+    parallelism: u8,
+    enable_prompt_lookup: bool,
+) -> impl TextGeneration {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
         .use_gpu(device.ggml_use_gpu())
         .parallelism(parallelism)
+        .enable_prompt_lookup(enable_prompt_lookup)
         .build()
         .unwrap();