tabby/clients/tabby-agent/src/CompletionCache.ts

145 lines
4.3 KiB
TypeScript
Raw Normal View History

2023-05-29 02:09:44 +00:00
import { LRUCache } from "lru-cache";
import hashObject from "object-hash";
import sizeOfObject from "object-sizeof";
import { CompletionRequest, CompletionResponse } from "./Agent";
2023-06-06 14:25:31 +00:00
import { rootLogger } from "./logger";
2023-05-29 02:09:44 +00:00
import { splitLines, splitWords } from "./utils";
/** Cache lookups are keyed by the completion request; the request is hashed to a string before it is stored. */
type CompletionCacheKey = CompletionRequest;
/** Cached payload: the completion response stored for a given request. */
type CompletionCacheValue = CompletionResponse;
/**
 * Size-bounded LRU cache of completion responses, keyed by a hash of the
 * completion request.
 *
 * When `partiallyAcceptedCacheGeneration` is enabled, storing one response
 * also stores derived entries for requests in which a leading part of a
 * completion has already been inserted into the text, so that such follow-up
 * requests can be answered from the cache.
 */
export class CompletionCache {
  private readonly logger = rootLogger.child({ component: "CompletionCache" });
  private readonly cache: LRUCache<string, CompletionCacheValue>;
  private options = {
    maxSize: 1 * 1024 * 1024, // 1MB
    partiallyAcceptedCacheGeneration: {
      enabled: true,
      // Character-granularity positions: taken from the first `words` words of
      // the first `lines` lines, capped at `max` characters.
      perCharacter: {
        lines: 1,
        words: 10,
        max: 30,
      },
      // Word-granularity positions: taken from the first `lines` lines, capped
      // at `max` words.
      perWord: {
        lines: 1,
        max: 20,
      },
      // Line-granularity positions: capped at `max` lines.
      perLine: {
        max: 3,
      },
    },
  };

  constructor() {
    this.cache = new LRUCache<string, CompletionCacheValue>({
      maxSize: this.options.maxSize,
      sizeCalculation: sizeOfObject,
    });
  }

  /**
   * @param key - The completion request to look up.
   * @returns Whether an entry is stored under the hash of `key`.
   */
  has(key: CompletionCacheKey): boolean {
    return this.cache.has(this.hash(key));
  }

  /**
   * Stores `value` for `key`, together with every derived partially-accepted
   * entry produced by {@link createCacheEntries}.
   *
   * @param key - The completion request the response belongs to.
   * @param value - The completion response to cache.
   */
  set(key: CompletionCacheKey, value: CompletionCacheValue): void {
    for (const entry of this.createCacheEntries(key, value)) {
      this.logger.debug({ entry }, "Setting cache entry");
      this.cache.set(this.hash(entry.key), entry.value);
    }
    this.logger.debug({ size: this.cache.calculatedSize }, "Cache size");
  }

  /**
   * @param key - The completion request to look up.
   * @returns The cached response, or `undefined` when nothing is stored for `key`.
   */
  get(key: CompletionCacheKey): CompletionCacheValue | undefined {
    return this.cache.get(this.hash(key));
  }

  /** Reduces a request to the string under which it is stored in the LRU cache. */
  private hash(key: CompletionCacheKey): string {
    return hashObject(key);
  }

  /**
   * Builds the list of cache entries to store for one request/response pair.
   *
   * The first entry is always the original pair. When partially-accepted
   * generation is enabled, one extra entry is added per distinct completion
   * prefix: its key is the request with that prefix inserted at the cursor
   * (and the cursor advanced past it), and its value is the response with
   * each contributing choice reduced to the remaining suffix.
   *
   * @param key - The original completion request.
   * @param value - The response received for `key`.
   * @returns The original entry followed by the derived entries.
   */
  private createCacheEntries(
    key: CompletionCacheKey,
    value: CompletionCacheValue,
  ): { key: CompletionCacheKey; value: CompletionCacheValue }[] {
    const list = [{ key, value }];
    if (!this.options.partiallyAcceptedCacheGeneration.enabled) {
      return list;
    }

    // A Map is used instead of a plain object because prefixes are arbitrary
    // completion text: a prefix such as "constructor", "toString" or
    // "__proto__" would resolve to an inherited Object.prototype member on a
    // plain object and make the subsequent `.push` throw.
    const remaindersByPrefix = new Map<string, { suffix: string; choiceIndex: number }[]>();
    for (const choice of value.choices) {
      for (const position of this.calculatePartiallyAcceptedPositions(choice.text)) {
        const prefix = choice.text.slice(0, position);
        const remainder = { suffix: choice.text.slice(position), choiceIndex: choice.index };
        const group = remaindersByPrefix.get(prefix);
        if (group) {
          group.push(remainder);
        } else {
          remaindersByPrefix.set(prefix, [remainder]);
        }
      }
    }

    for (const [prefix, remainders] of remaindersByPrefix) {
      list.push({
        key: {
          ...key,
          // Simulate the prefix having been accepted at the cursor position.
          text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
          position: key.position + prefix.length,
        },
        value: {
          ...value,
          choices: remainders.map(({ choiceIndex, suffix }) => ({
            index: choiceIndex,
            text: suffix,
          })),
        },
      });
    }
    return list;
  }

  /**
   * Computes the offsets into `completion` at which a partial acceptance is
   * anticipated: ends of leading lines, ends of leading words, and individual
   * leading characters, each bounded by the corresponding option.
   *
   * NOTE(review): the offsets are only meaningful if `splitLines` and
   * `splitWords` return consecutive pieces that concatenate back to their
   * input (i.e. keep separators) — confirm against `./utils`.
   *
   * @param completion - The text of one completion choice.
   * @returns Distinct offsets in ascending order.
   */
  private calculatePartiallyAcceptedPositions(completion: string): number[] {
    const option = this.options.partiallyAcceptedCacheGeneration;
    const positions: number[] = [];
    const lines = splitLines(completion);

    // Line ends. The last line is excluded (`lines.length - 1`).
    const lineCount = Math.min(lines.length - 1, option.perLine.max);
    let offset = 0;
    for (let i = 0; i < lineCount; i++) {
      offset += lines[i].length;
      positions.push(offset);
    }

    // Word ends within the leading lines.
    const words = lines.slice(0, option.perWord.lines).flatMap(splitWords);
    const wordCount = Math.min(words.length, option.perWord.max);
    offset = 0;
    for (let i = 0; i < wordCount; i++) {
      offset += words[i].length;
      positions.push(offset);
    }

    // Every character boundary strictly inside the leading words.
    const characters = lines
      .slice(0, option.perCharacter.lines)
      .flatMap(splitWords)
      .slice(0, option.perCharacter.words)
      .join("");
    const characterLimit = Math.min(characters.length, option.perCharacter.max);
    for (let position = 1; position < characterLimit; position++) {
      positions.push(position);
    }

    // distinct and sort ascending
    return [...new Set(positions)].sort((a, b) => a - b);
  }
}