update
parent
de96d1b6af
commit
7f6af66d69
|
|
@ -15,5 +15,10 @@ class TextInferenceEngine {
|
|||
virtual rust::Vec<StepOutput> step() = 0;
|
||||
};
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t paralellism);
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
bool use_gpu,
|
||||
rust::Str model_path,
|
||||
uint8_t paralellism,
|
||||
bool enable_prompt_lookup
|
||||
);
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -17,8 +17,8 @@ namespace {
|
|||
constexpr size_t N_BATCH = 512; // # per batch inference.
|
||||
constexpr size_t N_CTX = 4096; // # max kv history.
|
||||
|
||||
constexpr size_t DRAFT_N_GRAM_SIZE = 3;
|
||||
constexpr size_t DRAFT_N_PRED_TOKENS = 10;
|
||||
constexpr int DRAFT_N_GRAM_SIZE = 3;
|
||||
constexpr int DRAFT_N_PRED_TOKENS = 10;
|
||||
|
||||
struct Request {
|
||||
Request(size_t request_id, std::vector<llama_token> input_token_ids) :
|
||||
|
|
@ -142,10 +142,11 @@ using owned = std::unique_ptr<T, std::function<void(T*)>>;
|
|||
|
||||
class TextInferenceEngineImpl : public TextInferenceEngine {
|
||||
public:
|
||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism) :
|
||||
TextInferenceEngineImpl(owned<llama_model> model, owned<llama_context> ctx, uint8_t parallelism, bool enable_prompt_lookup) :
|
||||
model_(std::move(model)),
|
||||
ctx_(std::move(ctx)),
|
||||
parallelism_(parallelism) {
|
||||
parallelism_(parallelism),
|
||||
enable_prompt_lookup_(enable_prompt_lookup) {
|
||||
batch_ = llama_batch_init(N_CTX * parallelism, 0, 1);
|
||||
// warm up
|
||||
{
|
||||
|
|
@ -231,8 +232,10 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
const size_t n_tokens = batch_.n_tokens;
|
||||
|
||||
// Ensure the draft logits always fall into the same batch.
|
||||
const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
|
||||
request.draft_tokens(n_draft_quota);
|
||||
if (enable_prompt_lookup_) {
|
||||
const int n_draft_quota = N_BATCH - (n_tokens + request.tokens.size()) % N_BATCH;
|
||||
request.draft_tokens(n_draft_quota);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < request.tokens.size(); ++i) {
|
||||
batch_.token[n_tokens + i] = request.tokens[i];
|
||||
|
|
@ -347,6 +350,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
|
|||
std::unordered_set<uint32_t> stopped_requests_;
|
||||
|
||||
uint32_t parallelism_;
|
||||
bool enable_prompt_lookup_;
|
||||
};
|
||||
|
||||
static int g_llama_cpp_log_level = 0;
|
||||
|
|
@ -374,7 +378,12 @@ struct BackendInitializer {
|
|||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model_path, uint8_t parallelism) {
|
||||
std::unique_ptr<TextInferenceEngine> create_engine(
|
||||
bool use_gpu,
|
||||
rust::Str model_path,
|
||||
uint8_t parallelism,
|
||||
bool enable_prompt_lookup
|
||||
) {
|
||||
static BackendInitializer initializer;
|
||||
|
||||
llama_model_params model_params = llama_model_default_params();
|
||||
|
|
@ -397,7 +406,8 @@ std::unique_ptr<TextInferenceEngine> create_engine(bool use_gpu, rust::Str model
|
|||
return std::make_unique<TextInferenceEngineImpl>(
|
||||
owned<llama_model>(model, llama_free_model),
|
||||
owned<llama_context>(ctx, llama_free),
|
||||
parallelism
|
||||
parallelism,
|
||||
enable_prompt_lookup
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ mod ffi {
|
|||
use_gpu: bool,
|
||||
model_path: &str,
|
||||
parallelism: u8,
|
||||
enable_prompt_lookup: bool,
|
||||
) -> UniquePtr<TextInferenceEngine>;
|
||||
|
||||
fn add_request(
|
||||
|
|
@ -48,6 +49,7 @@ pub struct LlamaTextGenerationOptions {
|
|||
model_path: String,
|
||||
use_gpu: bool,
|
||||
parallelism: u8,
|
||||
enable_prompt_lookup: bool,
|
||||
}
|
||||
|
||||
pub struct LlamaTextGeneration {
|
||||
|
|
@ -57,7 +59,7 @@ pub struct LlamaTextGeneration {
|
|||
|
||||
impl LlamaTextGeneration {
|
||||
pub fn new(options: LlamaTextGenerationOptions) -> Self {
|
||||
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism);
|
||||
let engine = create_engine(options.use_gpu, &options.model_path, options.parallelism, options.enable_prompt_lookup);
|
||||
if engine.is_null() {
|
||||
fatal!("Unable to load model: {}", options.model_path);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,7 +77,7 @@ impl ChatService {
|
|||
|
||||
pub async fn create_chat_service(model: &str, device: &Device, parallelism: u8) -> ChatService {
|
||||
let (engine, model::PromptInfo { chat_template, .. }) =
|
||||
model::load_text_generation(model, device, parallelism).await;
|
||||
model::load_text_generation(model, device, parallelism, true).await;
|
||||
|
||||
let Some(chat_template) = chat_template else {
|
||||
fatal!("Chat model requires specifying prompt template");
|
||||
|
|
|
|||
|
|
@ -281,7 +281,7 @@ pub async fn create_completion_service(
|
|||
model::PromptInfo {
|
||||
prompt_template, ..
|
||||
},
|
||||
) = model::load_text_generation(model, device, parallelism).await;
|
||||
) = model::load_text_generation(model, device, parallelism, false).await;
|
||||
|
||||
CompletionService::new(engine.clone(), code, logger, prompt_template)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ pub async fn load_text_generation(
|
|||
model_id: &str,
|
||||
device: &Device,
|
||||
parallelism: u8,
|
||||
enable_prompt_lookup: bool,
|
||||
) -> (Arc<dyn TextGeneration>, PromptInfo) {
|
||||
#[cfg(feature = "experimental-http")]
|
||||
if device == &Device::ExperimentalHttp {
|
||||
|
|
@ -28,19 +29,25 @@ pub async fn load_text_generation(
|
|||
if fs::metadata(model_id).is_ok() {
|
||||
let path = PathBuf::from(model_id);
|
||||
let model_path = path.join(GGML_MODEL_RELATIVE_PATH);
|
||||
let engine_info = PromptInfo::read(path.join("tabby.json"));
|
||||
let engine = create_ggml_engine(
|
||||
device,
|
||||
model_path.display().to_string().as_str(),
|
||||
parallelism,
|
||||
enable_prompt_lookup,
|
||||
);
|
||||
let engine_info = PromptInfo::read(path.join("tabby.json"));
|
||||
(Arc::new(engine), engine_info)
|
||||
} else {
|
||||
let (registry, name) = parse_model_id(model_id);
|
||||
let registry = ModelRegistry::new(registry).await;
|
||||
let model_path = registry.get_model_path(name).display().to_string();
|
||||
let model_info = registry.get_model_info(name);
|
||||
let engine = create_ggml_engine(device, &model_path, parallelism);
|
||||
let engine = create_ggml_engine(
|
||||
device,
|
||||
&model_path,
|
||||
parallelism,
|
||||
enable_prompt_lookup,
|
||||
);
|
||||
(
|
||||
Arc::new(engine),
|
||||
PromptInfo {
|
||||
|
|
@ -64,11 +71,17 @@ impl PromptInfo {
|
|||
}
|
||||
}
|
||||
|
||||
fn create_ggml_engine(device: &Device, model_path: &str, parallelism: u8) -> impl TextGeneration {
|
||||
fn create_ggml_engine(
|
||||
device: &Device,
|
||||
model_path: &str,
|
||||
parallelism: u8,
|
||||
enable_prompt_lookup: bool,
|
||||
) -> impl TextGeneration {
|
||||
let options = llama_cpp_bindings::LlamaTextGenerationOptionsBuilder::default()
|
||||
.model_path(model_path.to_owned())
|
||||
.use_gpu(device.ggml_use_gpu())
|
||||
.parallelism(parallelism)
|
||||
.enable_prompt_lookup(enable_prompt_lookup)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue