refactor

parent 30b42bf186
commit d1c4db52a8
@@ -36,32 +36,16 @@ struct Request {
   std::vector<llama_token> all_tokens;
 
-  std::vector<llama_token> find_candidate_pred_tokens(size_t max_ngram_size = 3, size_t n_pred_tokens = 10) {
-    for (size_t ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) {
-      if (all_tokens.size() < ngram_size) continue;
-      std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
+  std::vector<llama_token> find_candidate_pred_tokens(size_t ngram_size = 3, size_t n_pred_tokens = 10) {
+    if (all_tokens.size() < ngram_size) return {};
+    std::vector<llama_token> ngram(all_tokens.begin() + all_tokens.size() - ngram_size, all_tokens.end());
 
-      const int matched = find_ngram(ngram, n_pred_tokens);
-      if (matched < 0) continue;
+    const auto end = all_tokens.end() - ngram_size - n_pred_tokens;
+    const auto matched = std::search(all_tokens.begin(), end, ngram.begin(), ngram.end());
+    if (matched == end) return {};
 
-      const int offset = matched + ngram_size;
-      return std::vector<llama_token>(all_tokens.begin() + offset, all_tokens.begin() + offset + n_pred_tokens);
-    }
-
-    return std::vector<llama_token>();
-  }
-
- private:
-  int find_ngram(const std::vector<llama_token> & ngram, size_t n_pred_tokens) {
-    const int max = static_cast<int>(all_tokens.size()) - ngram.size() - n_pred_tokens;
-    for (int i = 0; i < max; ++i) {
-      const auto mismatch = std::mismatch(all_tokens.begin() + i, all_tokens.begin() + i + ngram.size(), ngram.begin());
-      if (mismatch.second == ngram.end()) {
-        // Matched
-        return i;
-      }
-    }
-
-    return -1;
+    const auto begin = matched + ngram_size;
+    return std::vector<llama_token>(begin, begin + n_pred_tokens);
   }
 };
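The refactor replaces the shrinking-n-gram loop and the hand-rolled find_ngram scan with a single std::search over the token history: take the last ngram_size tokens generated, find where that n-gram occurred earlier in the sequence, and return the n_pred_tokens that followed it as draft predictions (prompt-lookup style speculation). A minimal standalone sketch of the same lookup, compilable outside the engine; llama_token is int32_t as in llama.cpp, but the function name, the extra length guard, and the test data here are invented for illustration:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

using llama_token = int32_t;

std::vector<llama_token> lookup_draft(const std::vector<llama_token> & all_tokens,
                                      size_t ngram_size = 3, size_t n_pred_tokens = 10) {
  // Extra guard (not in the diff): the `end` iterator below would underflow
  // when the history is shorter than ngram_size + n_pred_tokens.
  if (all_tokens.size() < ngram_size + n_pred_tokens) return {};

  // The n-gram to look up is simply the last ngram_size tokens generated.
  std::vector<llama_token> ngram(all_tokens.end() - ngram_size, all_tokens.end());

  // Only search positions that leave room for n_pred_tokens after the match.
  const auto end = all_tokens.end() - ngram_size - n_pred_tokens;
  const auto matched = std::search(all_tokens.begin(), end, ngram.begin(), ngram.end());
  if (matched == end) return {};  // std::search returns `end` on no match

  // The tokens that followed the earlier occurrence become the draft.
  const auto begin = matched + ngram_size;
  return std::vector<llama_token>(begin, begin + n_pred_tokens);
}

int main() {
  // The history ends in the n-gram {1, 2, 3}, which also opened the sequence,
  // so the draft is the four tokens that followed it the first time around.
  std::vector<llama_token> history = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1, 2, 3};
  for (llama_token t : lookup_draft(history, 3, 4)) std::cout << t << ' ';
  std::cout << '\n';  // prints: 4 5 6 7
}

One trade-off visible in the diff: the old code retried with progressively shorter n-grams before giving up, while the new version performs a single fixed-size lookup and returns empty on a miss.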
@@ -259,6 +243,7 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
+      request.all_tokens.push_back(next_token);
       const auto token_str = llama_token_to_piece(ctx, next_token);
       request.generated_text += token_str;
       request.n_past += 1;
 
       // FIXME: Hack for codellama to simplify tabby's implementation.
       const bool is_eos = next_token == eos_id || token_str == " <EOT>";
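This hunk records every sampled token in request.all_tokens, the history that find_candidate_pred_tokens searches. A hedged sketch of how the bookkeeping could sit in a decode step; sample_next_token and the surrounding scaffolding are invented stand-ins for the engine's actual sampling code, while llama_token_to_piece is the same helper the diff uses:

// Hypothetical decode-step bookkeeping built around the names in the diff.
const llama_token next_token = sample_next_token(ctx);  // assumed sampling helper

request.all_tokens.push_back(next_token);  // full history for the n-gram lookup
const auto token_str = llama_token_to_piece(ctx, next_token);
request.generated_text += token_str;
request.n_past += 1;

// With the history recorded, draft tokens for the next step are one call away.
const auto draft = request.find_candidate_pred_tokens();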
@@ -299,8 +284,6 @@ class TextInferenceEngineImpl : public TextInferenceEngine {
         llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
         break;
       }
 
-      request.n_past += 1;
-    }
 
     request.tokens.clear();
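llama_kv_cache_seq_rm(ctx, seq_id, p0, p1) is llama.cpp's call for removing a sequence's cached entries in the position range [p0, p1), with -1 meaning open-ended; here it discards everything from request.n_past onward. A sketch of the kind of verification loop such a call typically anchors in speculative decoding; drafted and sampled are hypothetical vectors, not names from this codebase:

// Hypothetical verification step: advance n_past for every draft token the
// model agrees with, then truncate the KV cache at the first mismatch so
// generation resumes from the last verified position.
for (size_t i = 0; i < drafted.size(); ++i) {
  if (drafted[i] != sampled[i]) {
    // Drop this request's cache entries from position n_past to the end.
    llama_kv_cache_seq_rm(ctx_.get(), request.id, request.n_past, -1);
    break;
  }
  request.n_past += 1;  // accepted: the cache entry at this position stays valid
}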