add-token-draft
Meng Zhang 2023-11-30 15:44:50 +08:00
parent de96d1b6af
commit 7f6af66d69
6 changed files with 45 additions and 15 deletions

View File

@ -15,5 +15,10 @@ class TextInferenceEngine {
virtual rust::Vec<StepOutput> step() = 0; virtual rust::Vec<StepOutput> step() = 0;
}; };
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism); std::unique_ptr<TextInferenceEngine> create_engine(
bool use_gpu,
rust::Str model_path,
uint8_t paralellism,
bool enable_prompt_lookup
);
} // namespace } // namespace

View File

@ -17,8 +17,8 @@ namespace {
constexpr size_t N_BATCH = 512; // # per batch inference. constexpr size_t N_BATCH = 512; // # per batch inference.
constexpr size_t N_CTX = 4096; // # max kv history. constexpr size_t N_CTX = 4096; // # max kv history.
constexpr size_t DRAFT_N_GRAM_SIZE = 3; constexpr int DRAFT_N_GRAM_SIZE = 3;
constexpr size_t DRAFT_N_PRED_TOKENS = 10; constexpr int DRAFT_N_PRED_TOKENS = 10;
struct Request { struct Request {
Request(size_t request_id, std::vector<llama_token> input_token_ids) : Request(size_t request_id, std::vector<llama_token> input_token_ids) :
@ -142,10 +142,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
class TextInferenceEngineImpl : public TextInferenceEngine { class TextInferenceEngineImpl : public TextInferenceEngine {
public: public:
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) : TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism, bool enable_prompt_lookup) :
model_(std::move(model)), model_(std::move(model)),
ctx_(std::move(ctx)), ctx_(std::move(ctx)),
parallelism_(parallelism) { parallelism_(parallelism),
enable_prompt_lookup_(enable_prompt_lookup) {
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1); batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
// warm up // warm up
{ {
@ -231,8 +232,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
const size_t n_tokens = batch_.n_tokens; const size_t n_tokens = batch_.n_tokens;
// Ensure the draft logits always fall into the same batch. // Ensure the draft logits always fall into the same batch.
const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH; if (enable_prompt_lookup_) {
request.draft_tokens(n_draft_quota); const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
request.draft_tokens(n_draft_quota);
}
for (size_t i = 0; i < request.tokens.size(); ++i) { for (size_t i = 0; i < request.tokens.size(); ++i) {
batch_.token[n_tokens + i] = request.tokens[i]; batch_.token[n_tokens + i] = request.tokens[i];
@ -347,6 +350,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
std::unordered_set<uint32_t> stopped_requests_; std::unordered_set<uint32_t> stopped_requests_;
uint32_t parallelism_; uint32_t parallelism_;
bool enable_prompt_lookup_;
}; };
static int g_llama_cpp_log_level = 0; static int g_llama_cpp_log_level = 0;
@ -374,7 +378,12 @@ struct BackendInitializer {
} // namespace } // namespace
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) { std::unique_ptr<TextInferenceEngine> create_engine(
bool use_gpu,
rust::Str model_path,
uint8_t parallelism,
bool enable_prompt_lookup
) {
static BackendInitializer initializer; static BackendInitializer initializer;
llama_model_params model_params = llama_model_default_params(); llama_model_params model_params = llama_model_default_params();
@ -397,7 +406,8 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
return std::make_unique<TextInferenceEngineImpl>( return std::make_unique<TextInferenceEngineImpl>(
owned<llama_model>(model, llama_free_model), owned<llama_model>(model, llama_free_model),
owned<llama_context>(ctx, llama_free), owned<llama_context>(ctx, llama_free),
parallelism parallelism,
enable_prompt_lookup
); );
} }

View File

@ -27,6 +27,7 @@ mod ffi {
use_gpu: bool, use_gpu: bool,
model_path: &str, model_path: &str,
parallelism: u8, parallelism: u8,
enable_prompt_lookup: bool,
) -> UniquePtr<TextInferenceEngine>; ) -> UniquePtr<TextInferenceEngine>;
fn add_request( fn add_request(
@ -48,6 +49,7 @@ pub struct LlamaTextGenerationOptions {
model_path: String, model_path: String,
use_gpu: bool, use_gpu: bool,
parallelism: u8, parallelism: u8,
enable_prompt_lookup: bool,
} }
pub struct LlamaTextGeneration { pub struct LlamaTextGeneration {
@ -57,7 +59,7 @@ pub struct LlamaTextGeneration {
impl LlamaTextGeneration { impl LlamaTextGeneration {
pub fn new(options: LlamaTextGenerationOptions) -> Self { pub fn new(options: LlamaTextGenerationOptions) -> Self {
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism); let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism, options.enable_prompt_lookup);
if engine.is_null() { if engine.is_null() {
fatal!("Unable to load model: {}", options.model_path); fatal!("Unable to load model: {}", options.model_path);
} }

View File

@ -77,7 +77,7 @@ impl ChatService {
pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService { pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
let (engine, model::PromptInfo { chat_template, .. }) = let (engine, model::PromptInfo { chat_template, .. }) =
model::load_text_generation(model, device, parallelism).await; model::load_text_generation(model, device, parallelism, true).await;
let Some(chat_template) = chat_template else { let Some(chat_template) = chat_template else {
fatal!("Chat model requires specifying prompt template"); fatal!("Chat model requires specifying prompt template");

View File

@ -281,7 +281,7 @@ pub async fn create_completion_service(
model::PromptInfo { model::PromptInfo {
prompt_template, .. prompt_template, ..
}, },
) = model::load_text_generation(model, device, parallelism).await; ) = model::load_text_generation(model, device, parallelism, false).await;
CompletionService::new(engine.clone(), code, logger, prompt_template) CompletionService::new(engine.clone(), code, logger, prompt_template)
} }

View File

@ -12,6 +12,7 @@ pub async fn load_text_generation(
model_id: &str, model_id: &str,
device: &Device, device: &Device,
parallelism: u8, parallelism: u8,
enable_prompt_lookup: bool,
) -> (Arc<dyn TextGeneration>, PromptInfo) { ) -> (Arc<dyn TextGeneration>, PromptInfo) {
#[cfg(feature = "experimental-http")] #[cfg(feature = "experimental-http")]
if device == &Device::ExperimentalHttp { if device == &Device::ExperimentalHttp {
@ -28,19 +29,25 @@ pub async fn load_text_generation(
if fs::metadata(model_id).is_ok() { if fs::metadata(model_id).is_ok() {
let path = PathBuf::from(model_id); let path = PathBuf::from(model_id);
let model_path = path.join(GGML_MODEL_RELATIVE_PATH); let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
let engine_info = PromptInfo::read(path.join("tabby.json"));
let engine = create_ggml_engine( let engine = create_ggml_engine(
device, device,
model_path.display().to_string().as_str(), model_path.display().to_string().as_str(),
parallelism, parallelism,
enable_prompt_lookup,
); );
let engine_info = PromptInfo::read(path.join("tabby.json"));
(Arc::new(engine), engine_info) (Arc::new(engine), engine_info)
} else { } else {
let (registry, name) = parse_model_id(model_id); let (registry, name) = parse_model_id(model_id);
let registry = ModelRegistry::new(registry).await; let registry = ModelRegistry::new(registry).await;
let model_path = registry.get_model_path(name).display().to_string(); let model_path = registry.get_model_path(name).display().to_string();
let model_info = registry.get_model_info(name); let model_info = registry.get_model_info(name);
let engine = create_ggml_engine(device, &model_path, parallelism); let engine = create_ggml_engine(
device,
&model_path,
parallelism,
enable_prompt_lookup,
);
( (
Arc::new(engine), Arc::new(engine),
PromptInfo { PromptInfo {
@ -64,11 +71,17 @@ impl PromptInfo {
} }
} }
fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration { fn create_ggml_engine(
device: &Device,
model_path: &str,
parallelism: u8,
enable_prompt_lookup: bool,
) -> impl TextGeneration {
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default() let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
.model_path(model_path.to_owned()) .model_path(model_path.to_owned())
.use_gpu(device.ggml_use_gpu()) .use_gpu(device.ggml_use_gpu())
.parallelism(parallelism) .parallelism(parallelism)
.enable_prompt_lookup(enable_prompt_lookup)
.build() .build()
.unwrap(); .unwrap();