add-token-draft
Meng Zhang 2023-11-30 15:44:50 +08:00
parent de96d1b6af
commit 7f6af66d69
6 changed files with 45 additions and 15 deletions

View File

@@ -15,5 +15,10 @@ class TextInferenceEngine {
   virtual rust::Vec<StepOutput> step() = 0;
 };
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
+std::unique_ptr<TextInferenceEngine> create_engine(
+  bool use_gpu,
+  rust::Str model_path,
+  uint8_t parallelism,
+  bool enable_prompt_lookup
+);
 } // namespace

View File

@@ -17,8 +17,8 @@ namespace {
 constexpr size_t N_BATCH = 512;  // # per batch inference.
 constexpr size_t N_CTX = 4096;   // # max kv history.
-constexpr size_t DRAFT_N_GRAM_SIZE = 3;
-constexpr size_t DRAFT_N_PRED_TOKENS = 10;
+constexpr int DRAFT_N_GRAM_SIZE = 3;
+constexpr int DRAFT_N_PRED_TOKENS = 10;
 
 struct Request {
   Request(size_t request_id, std::vector<llama_token> input_token_ids) :
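The switch from `size_t` to `int` plausibly avoids signed/unsigned comparison warnings against `int` counters such as `n_draft_quota` below. These constants parameterize prompt-lookup decoding: match the last `DRAFT_N_GRAM_SIZE` tokens against earlier context and, on a hit, propose up to `DRAFT_N_PRED_TOKENS` of the tokens that followed as a cheap draft for the model to verify. A minimal Rust sketch of that lookup; `find_draft` is a hypothetical helper, not the engine's actual `Request::draft_tokens`:

```rust
// Sketch of prompt-lookup drafting, parameterized by the constants above.
// `find_draft` is hypothetical; the commit's real logic lives in C++.
const DRAFT_N_GRAM_SIZE: usize = 3;
const DRAFT_N_PRED_TOKENS: usize = 10;

/// Find the most recent earlier occurrence of the trailing n-gram and
/// propose the tokens that followed it as a draft, capped by `quota`.
fn find_draft(history: &[u32], quota: usize) -> Vec<u32> {
    if history.len() <= DRAFT_N_GRAM_SIZE {
        return Vec::new();
    }
    let ngram = &history[history.len() - DRAFT_N_GRAM_SIZE..];
    // Scan right-to-left so the most recent earlier match wins.
    for start in (0..history.len() - DRAFT_N_GRAM_SIZE).rev() {
        if &history[start..start + DRAFT_N_GRAM_SIZE] == ngram {
            let from = start + DRAFT_N_GRAM_SIZE;
            let to = (from + DRAFT_N_PRED_TOKENS.min(quota)).min(history.len());
            return history[from..to].to_vec();
        }
    }
    Vec::new()
}
```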
@@ -142,10 +142,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
 class TextInferenceEngineImpl : public TextInferenceEngine {
  public:
-  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
+  TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism, bool enable_prompt_lookup) :
     model_(std::move(model)),
     ctx_(std::move(ctx)),
-    parallelism_(parallelism) {
+    parallelism_(parallelism),
+    enable_prompt_lookup_(enable_prompt_lookup) {
     batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
 
     // warm up
     {
@@ -231,8 +232,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
       const size_t n_tokens = batch_.n_tokens;
       // Ensure the draft logits always fall into the same batch.
-      const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
-      request.draft_tokens(n_draft_quota);
+      if (enable_prompt_lookup_) {
+        const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
+        request.draft_tokens(n_draft_quota);
+      }
 
       for (size_t i = 0; i < request.tokens.size(); ++i) {
         batch_.token[n_tokens + i] = request.tokens[i];
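Gating on `enable_prompt_lookup_` means non-drafting requests reserve no draft slots at all. The quota itself is the room left before the next `N_BATCH` boundary, which is what keeps a request's draft logits inside a single decode batch; a small Rust sketch with worked values:

```rust
// Sketch of the n_draft_quota arithmetic above, assuming N_BATCH = 512.
const N_BATCH: usize = 512;

fn draft_quota(batch_tokens: usize, request_tokens: usize) -> usize {
    // Slots remaining before the next N_BATCH boundary.
    N_BATCH - (batch_tokens + request_tokens) % N_BATCH
}

fn main() {
    // 300 tokens already batched + 200 in this request = 500,
    // so at most 12 draft tokens still fit in the same batch.
    assert_eq!(draft_quota(300, 200), 12);
    // Exactly on a boundary: a full batch worth of room remains.
    assert_eq!(draft_quota(512, 0), 512);
}
```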
@@ -347,6 +350,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
   std::unordered_set<uint32_t> stopped_requests_;
 
   uint32_t parallelism_;
+  bool enable_prompt_lookup_;
 };
 
 static int g_llama_cpp_log_level = 0;
@@ -374,7 +378,12 @@ struct BackendInitializer {
 } // namespace
 
-std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
+std::unique_ptr<TextInferenceEngine> create_engine(
+  bool use_gpu,
+  rust::Str model_path,
+  uint8_t parallelism,
+  bool enable_prompt_lookup
+) {
   static BackendInitializer initializer;
 
   llama_model_params model_params = llama_model_default_params();
@@ -397,7 +406,8 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
   return std::make_unique<TextInferenceEngineImpl>(
       owned<llama_model>(model, llama_free_model),
       owned<llama_context>(ctx, llama_free),
-      parallelism
+      parallelism,
+      enable_prompt_lookup
   );
 }
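For orientation (this loop is not part of the diff), prompt-lookup drafts are consumed like any speculative-decoding draft: the model scores the drafted positions in the same batch, and only the longest prefix that agrees with what the model would have sampled anyway is kept. A hedged sketch of that acceptance rule:

```rust
// General speculative-decoding acceptance rule; the engine's actual
// step() loop is not shown in this diff.
fn accept_prefix(draft: &[u32], sampled: &[u32]) -> usize {
    // Accept drafted tokens up to the first disagreement with what the
    // model actually sampled at each drafted position.
    draft.iter().zip(sampled).take_while(|(d, s)| d == s).count()
}
```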

View File

@@ -27,6 +27,7 @@ mod ffi {
             use_gpu: bool,
             model_path: &str,
             parallelism: u8,
+            enable_prompt_lookup: bool,
         ) -> UniquePtr<TextInferenceEngine>;
 
         fn add_request(
@@ -48,6 +49,7 @@ pub struct LlamaTextGenerationOptions {
     model_path: String,
     use_gpu: bool,
     parallelism: u8,
+    enable_prompt_lookup: bool,
 }
 
 pub struct LlamaTextGeneration {
@@ -57,7 +59,7 @@ pub struct LlamaTextGeneration {
 impl LlamaTextGeneration {
     pub fn new(options: LlamaTextGenerationOptions) -> Self {
-        let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
+        let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism, options.enable_prompt_lookup);
         if engine.is_null() {
            fatal!("Unable to load model: {}", options.model_path);
         }
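On the Rust side the new field rides through the derived options builder. A usage sketch, assuming `LlamaTextGeneration` and its builder are exported as they are used in `create_ggml_engine` below; the model path is a placeholder:

```rust
// Usage sketch: enabling prompt lookup through the derived options builder.
use llama_cpp_bindings::{LlamaTextGeneration, LlamaTextGenerationOptionsBuilder};

fn main() {
    let options = LlamaTextGenerationOptionsBuilder::default()
        .model_path("/path/to/ggml/model".to_owned()) // placeholder path
        .use_gpu(false)
        .parallelism(1)
        .enable_prompt_lookup(true)
        .build()
        .unwrap();
    let _engine = LlamaTextGeneration::new(options);
}
```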

View File

@@ -77,7 +77,7 @@ impl ChatService {
 pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
     let (engine, model::PromptInfo { chat_template, .. }) =
-        model::load_text_generation(model, device, parallelism).await;
+        model::load_text_generation(model, device, parallelism, true).await;
 
     let Some(chat_template) = chat_template else {
         fatal!("Chat model requires specifying prompt template");

View File

@@ -281,7 +281,7 @@ pub async fn create_completion_service(
         model::PromptInfo {
             prompt_template, ..
         },
-    ) = model::load_text_generation(model, device, parallelism).await;
+    ) = model::load_text_generation(model, device, parallelism, false).await;
 
     CompletionService::new(engine.clone(), code, logger, prompt_template)
 }

View File

@@ -12,6 +12,7 @@ pub async fn load_text_generation(
     model_id: &str,
     device: &Device,
     parallelism: u8,
+    enable_prompt_lookup: bool,
 ) -> (Arc<dyn TextGeneration>, PromptInfo) {
     #[cfg(feature = "experimental-http")]
     if device == &Device::ExperimentalHttp {
@@ -28,19 +29,25 @@
     if fs::metadata(model_id).is_ok() {
         let path = PathBuf::from(model_id);
         let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
-        let engine_info = PromptInfo::read(path.join("tabby.json"));
         let engine = create_ggml_engine(
             device,
             model_path.display().to_string().as_str(),
             parallelism,
+            enable_prompt_lookup,
         );
+        let engine_info = PromptInfo::read(path.join("tabby.json"));
         (Arc::new(engine), engine_info)
     } else {
         let (registry, name) = parse_model_id(model_id);
         let registry = ModelRegistry::new(registry).await;
         let model_path = registry.get_model_path(name).display().to_string();
         let model_info = registry.get_model_info(name);
-        let engine = create_ggml_engine(device, &model_path, parallelism);
+        let engine = create_ggml_engine(
+            device,
+            &model_path,
+            parallelism,
+            enable_prompt_lookup,
+        );
         (
             Arc::new(engine),
             PromptInfo {
@@ -64,11 +71,17 @@ impl PromptInfo {
     }
 }
 
-fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
+fn create_ggml_engine(
+    device: &Device,
+    model_path: &str,
+    parallelism: u8,
+    enable_prompt_lookup: bool,
+) -> impl TextGeneration {
     let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
         .model_path(model_path.to_owned())
         .use_gpu(device.ggml_use_gpu())
         .parallelism(parallelism)
+        .enable_prompt_lookup(enable_prompt_lookup)
         .build()
         .unwrap();
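End to end, the flag travels from the service call sites through `load_text_generation` and `create_ggml_engine` into the FFI `create_engine`. A snippet recapping the two patched call sites above, with `model`, `device`, and `parallelism` assumed from the surrounding service setup: chat opts in with `true`, completion opts out with `false`:

```rust
// Recap of the two patched call sites; `model`, `device`, and `parallelism`
// come from the surrounding service code.
let (chat_engine, _) = model::load_text_generation(model, device, parallelism, true).await;
let (completion_engine, _) = model::load_text_generation(model, device, parallelism, false).await;
```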