update
parent
de96d1b6af
commit
7f6af66d69
|
|
@ -15,5 +15,10 @@ class TextInferenceEngine {
|
||||||
virtual rust::Vec<StepOutput> step() = 0;
|
virtual rust::Vec<StepOutput> step() = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
|
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||||
|
bool use_gpu,
|
||||||
|
rust::Str model_path,
|
||||||
|
uint8_t paralellism,
|
||||||
|
bool enable_prompt_lookup
|
||||||
|
);
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -17,8 +17,8 @@ namespace {
|
||||||
constexpr size_t N_BATCH = 512; // # per batch inference.
|
constexpr size_t N_BATCH = 512; // # per batch inference.
|
||||||
constexpr size_t N_CTX = 4096; // # max kv history.
|
constexpr size_t N_CTX = 4096; // # max kv history.
|
||||||
|
|
||||||
constexpr size_t DRAFT_N_GRAM_SIZE = 3;
|
constexpr int DRAFT_N_GRAM_SIZE = 3;
|
||||||
constexpr size_t DRAFT_N_PRED_TOKENS = 10;
|
constexpr int DRAFT_N_PRED_TOKENS = 10;
|
||||||
|
|
||||||
struct Request {
|
struct Request {
|
||||||
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
||||||
|
|
@ -142,10 +142,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
||||||
|
|
||||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
public:
|
public:
|
||||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
|
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism, bool enable_prompt_lookup) :
|
||||||
model_(std::move(model)),
|
model_(std::move(model)),
|
||||||
ctx_(std::move(ctx)),
|
ctx_(std::move(ctx)),
|
||||||
parallelism_(parallelism) {
|
parallelism_(parallelism),
|
||||||
|
enable_prompt_lookup_(enable_prompt_lookup) {
|
||||||
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
|
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
|
||||||
// warm up
|
// warm up
|
||||||
{
|
{
|
||||||
|
|
@ -231,8 +232,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
const size_t n_tokens = batch_.n_tokens;
|
const size_t n_tokens = batch_.n_tokens;
|
||||||
|
|
||||||
// Ensure the draft logits always fall into the same batch.
|
// Ensure the draft logits always fall into the same batch.
|
||||||
|
if (enable_prompt_lookup_) {
|
||||||
const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
|
const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
|
||||||
request.draft_tokens(n_draft_quota);
|
request.draft_tokens(n_draft_quota);
|
||||||
|
}
|
||||||
|
|
||||||
for (size_t i = 0; i < request.tokens.size(); ++i) {
|
for (size_t i = 0; i < request.tokens.size(); ++i) {
|
||||||
batch_.token[n_tokens + i] = request.tokens[i];
|
batch_.token[n_tokens + i] = request.tokens[i];
|
||||||
|
|
@ -347,6 +350,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||||
std::unordered_set<uint32_t> stopped_requests_;
|
std::unordered_set<uint32_t> stopped_requests_;
|
||||||
|
|
||||||
uint32_t parallelism_;
|
uint32_t parallelism_;
|
||||||
|
bool enable_prompt_lookup_;
|
||||||
};
|
};
|
||||||
|
|
||||||
static int g_llama_cpp_log_level = 0;
|
static int g_llama_cpp_log_level = 0;
|
||||||
|
|
@ -374,7 +378,12 @@ struct BackendInitializer {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
|
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||||
|
bool use_gpu,
|
||||||
|
rust::Str model_path,
|
||||||
|
uint8_t parallelism,
|
||||||
|
bool enable_prompt_lookup
|
||||||
|
) {
|
||||||
static BackendInitializer initializer;
|
static BackendInitializer initializer;
|
||||||
|
|
||||||
llama_model_params model_params = llama_model_default_params();
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
|
@ -397,7 +406,8 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
|
||||||
return std::make_unique<TextInferenceEngineImpl>(
|
return std::make_unique<TextInferenceEngineImpl>(
|
||||||
owned<llama_model>(model, llama_free_model),
|
owned<llama_model>(model, llama_free_model),
|
||||||
owned<llama_context>(ctx, llama_free),
|
owned<llama_context>(ctx, llama_free),
|
||||||
parallelism
|
parallelism,
|
||||||
|
enable_prompt_lookup
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ mod ffi {
|
||||||
use_gpu: bool,
|
use_gpu: bool,
|
||||||
model_path: &str,
|
model_path: &str,
|
||||||
parallelism: u8,
|
parallelism: u8,
|
||||||
|
enable_prompt_lookup: bool,
|
||||||
) -> UniquePtr<TextInferenceEngine>;
|
) -> UniquePtr<TextInferenceEngine>;
|
||||||
|
|
||||||
fn add_request(
|
fn add_request(
|
||||||
|
|
@ -48,6 +49,7 @@ pub struct LlamaTextGenerationOptions {
|
||||||
model_path: String,
|
model_path: String,
|
||||||
use_gpu: bool,
|
use_gpu: bool,
|
||||||
parallelism: u8,
|
parallelism: u8,
|
||||||
|
enable_prompt_lookup: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct LlamaTextGeneration {
|
pub struct LlamaTextGeneration {
|
||||||
|
|
@ -57,7 +59,7 @@ pub struct LlamaTextGeneration {
|
||||||
|
|
||||||
impl LlamaTextGeneration {
|
impl LlamaTextGeneration {
|
||||||
pub fn new(options: LlamaTextGenerationOptions) -> Self {
|
pub fn new(options: LlamaTextGenerationOptions) -> Self {
|
||||||
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
|
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism, options.enable_prompt_lookup);
|
||||||
if engine.is_null() {
|
if engine.is_null() {
|
||||||
fatal!("Unable to load model: {}", options.model_path);
|
fatal!("Unable to load model: {}", options.model_path);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -77,7 +77,7 @@ impl ChatService {
|
||||||
|
|
||||||
pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
|
pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
|
||||||
let (engine, model::PromptInfo { chat_template, .. }) =
|
let (engine, model::PromptInfo { chat_template, .. }) =
|
||||||
model::load_text_generation(model, device, parallelism).await;
|
model::load_text_generation(model, device, parallelism, true).await;
|
||||||
|
|
||||||
let Some(chat_template) = chat_template else {
|
let Some(chat_template) = chat_template else {
|
||||||
fatal!("Chat model requires specifying prompt template");
|
fatal!("Chat model requires specifying prompt template");
|
||||||
|
|
|
||||||
|
|
@ -281,7 +281,7 @@ pub async fn create_completion_service(
|
||||||
model::PromptInfo {
|
model::PromptInfo {
|
||||||
prompt_template, ..
|
prompt_template, ..
|
||||||
},
|
},
|
||||||
) = model::load_text_generation(model, device, parallelism).await;
|
) = model::load_text_generation(model, device, parallelism, false).await;
|
||||||
|
|
||||||
CompletionService::new(engine.clone(), code, logger, prompt_template)
|
CompletionService::new(engine.clone(), code, logger, prompt_template)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,7 @@ pub async fn load_text_generation(
|
||||||
model_id: &str,
|
model_id: &str,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
parallelism: u8,
|
parallelism: u8,
|
||||||
|
enable_prompt_lookup: bool,
|
||||||
) -> (Arc<dyn TextGeneration>, PromptInfo) {
|
) -> (Arc<dyn TextGeneration>, PromptInfo) {
|
||||||
#[cfg(feature = "experimental-http")]
|
#[cfg(feature = "experimental-http")]
|
||||||
if device == &Device::ExperimentalHttp {
|
if device == &Device::ExperimentalHttp {
|
||||||
|
|
@ -28,19 +29,25 @@ pub async fn load_text_generation(
|
||||||
if fs::metadata(model_id).is_ok() {
|
if fs::metadata(model_id).is_ok() {
|
||||||
let path = PathBuf::from(model_id);
|
let path = PathBuf::from(model_id);
|
||||||
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
||||||
|
let engine_info = PromptInfo::read(path.join("tabby.json"));
|
||||||
let engine = create_ggml_engine(
|
let engine = create_ggml_engine(
|
||||||
device,
|
device,
|
||||||
model_path.display().to_string().as_str(),
|
model_path.display().to_string().as_str(),
|
||||||
parallelism,
|
parallelism,
|
||||||
|
enable_prompt_lookup,
|
||||||
);
|
);
|
||||||
let engine_info = PromptInfo::read(path.join("tabby.json"));
|
|
||||||
(Arc::new(engine), engine_info)
|
(Arc::new(engine), engine_info)
|
||||||
} else {
|
} else {
|
||||||
let (registry, name) = parse_model_id(model_id);
|
let (registry, name) = parse_model_id(model_id);
|
||||||
let registry = ModelRegistry::new(registry).await;
|
let registry = ModelRegistry::new(registry).await;
|
||||||
let model_path = registry.get_model_path(name).display().to_string();
|
let model_path = registry.get_model_path(name).display().to_string();
|
||||||
let model_info = registry.get_model_info(name);
|
let model_info = registry.get_model_info(name);
|
||||||
let engine = create_ggml_engine(device, &model_path, parallelism);
|
let engine = create_ggml_engine(
|
||||||
|
device,
|
||||||
|
&model_path,
|
||||||
|
parallelism,
|
||||||
|
enable_prompt_lookup,
|
||||||
|
);
|
||||||
(
|
(
|
||||||
Arc::new(engine),
|
Arc::new(engine),
|
||||||
PromptInfo {
|
PromptInfo {
|
||||||
|
|
@ -64,11 +71,17 @@ impl PromptInfo {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
|
fn create_ggml_engine(
|
||||||
|
device: &Device,
|
||||||
|
model_path: &str,
|
||||||
|
parallelism: u8,
|
||||||
|
enable_prompt_lookup: bool,
|
||||||
|
) -> impl TextGeneration {
|
||||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||||
.model_path(model_path.to_owned())
|
.model_path(model_path.to_owned())
|
||||||
.use_gpu(device.ggml_use_gpu())
|
.use_gpu(device.ggml_use_gpu())
|
||||||
.parallelism(parallelism)
|
.parallelism(parallelism)
|
||||||
|
.enable_prompt_lookup(enable_prompt_lookup)
|
||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue