refactor(agent): refactor completion postprocess and caching. (#576)

dedup-snippet-at-index
Zhiming Ma 2023-10-18 00:04:04 +08:00 committed by GitHub
parent 2060d47a95
commit be5e76650f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 418 additions and 339 deletions

View File

@ -10,6 +10,6 @@
"devDependencies": {
"cpy-cli": "^4.2.0",
"rimraf": "^5.0.1",
"tabby-agent": "0.3.1"
"tabby-agent": "0.4.0-dev"
}
}

View File

@ -1,6 +1,6 @@
{
"name": "tabby-agent",
"version": "0.3.1",
"version": "0.4.0-dev",
"description": "Generic client agent for Tabby AI coding assistant IDE extensions.",
"repository": "https://github.com/TabbyML/tabby",
"main": "./dist/index.js",
@ -41,7 +41,6 @@
"jwt-decode": "^3.1.2",
"lru-cache": "^9.1.1",
"object-hash": "^3.0.0",
"object-sizeof": "^2.6.1",
"openapi-fetch": "^0.7.6",
"pino": "^8.14.1",
"rotating-file-stream": "^3.1.0",

View File

@ -1,5 +1,8 @@
import type { components as ApiComponents } from "./types/tabbyApi";
import { AgentConfig, PartialAgentConfig } from "./AgentConfig";
import { CompletionRequest, CompletionResponse, CompletionContext } from "./CompletionContext";
export { CompletionRequest, CompletionResponse, CompletionContext };
export type ClientProperties = Partial<{
user: Record<string, any>;
@ -13,16 +16,6 @@ export type AgentInitOptions = Partial<{
export type ServerHealthState = ApiComponents["schemas"]["HealthState"];
export type CompletionRequest = {
filepath: string;
language: string;
text: string;
position: number;
manually?: boolean;
};
export type CompletionResponse = ApiComponents["schemas"]["CompletionResponse"];
export type LogEventRequest = ApiComponents["schemas"]["LogEventRequest"] & {
select_kind?: "line";
};

View File

@ -1,144 +1,204 @@
import { LRUCache } from "lru-cache";
import hashObject from "object-hash";
import sizeOfObject from "object-sizeof";
import { CompletionRequest, CompletionResponse } from "./Agent";
import { CompletionContext, CompletionResponse } from "./Agent";
import { rootLogger } from "./logger";
import { splitLines, splitWords } from "./utils";
import { splitLines, autoClosingPairOpenings, autoClosingPairClosings, findUnpairedAutoClosingChars } from "./utils";
type CompletionCacheKey = CompletionRequest;
type CompletionCacheKey = CompletionContext;
type CompletionCacheValue = CompletionResponse;
export class CompletionCache {
private readonly logger = rootLogger.child({ component: "CompletionCache" });
private cache: LRUCache<string, CompletionCacheValue>;
private cache: LRUCache<string, { value: CompletionCacheValue; rebuildFlag: boolean }>;
private options = {
maxSize: 1 * 1024 * 1024, // 1MB
partiallyAcceptedCacheGeneration: {
maxCount: 10000,
prebuildCache: {
enabled: true,
perCharacter: {
lines: 1,
words: 10,
max: 30,
},
perWord: {
lines: 1,
max: 20,
max: 50,
},
perLine: {
max: 10,
},
autoClosingPairCheck: {
max: 3,
},
},
};
constructor() {
this.cache = new LRUCache<string, CompletionCacheValue>({
maxSize: this.options.maxSize,
sizeCalculation: sizeOfObject,
this.cache = new LRUCache<string, { value: CompletionCacheValue; rebuildFlag: boolean }>({
max: this.options.maxCount,
});
}
has(key: CompletionCacheKey): boolean {
return this.cache.has(this.hash(key));
return this.cache.has(key.hash);
}
set(key: CompletionCacheKey, value: CompletionCacheValue): void {
for (const entry of this.createCacheEntries(key, value)) {
this.logger.debug({ entry }, "Setting cache entry");
this.cache.set(this.hash(entry.key), entry.value);
}
this.logger.debug({ size: this.cache.calculatedSize }, "Cache size");
buildCache(key: CompletionCacheKey, value: CompletionCacheValue): void {
this.logger.debug({ key, value }, "Starting to build cache");
const entries = this.createCacheEntries(key, value);
entries.forEach((entry) => {
this.cache.set(entry.key.hash, { value: entry.value, rebuildFlag: entry.rebuildFlag });
});
this.logger.debug({ newEntries: entries.length, cacheSize: this.cache.size }, "Cache updated");
}
get(key: CompletionCacheKey): CompletionCacheValue | undefined {
return this.cache.get(this.hash(key));
}
private hash(key: CompletionCacheKey): string {
return hashObject(key);
const entry = this.cache.get(key.hash);
if (entry?.rebuildFlag) {
this.buildCache(key, entry?.value);
}
return entry?.value;
}
private createCacheEntries(
key: CompletionCacheKey,
value: CompletionCacheValue,
): { key: CompletionCacheKey; value: CompletionCacheValue }[] {
const list = [{ key, value }];
if (this.options.partiallyAcceptedCacheGeneration.enabled) {
const entries = value.choices
.map((choice) => {
return this.calculatePartiallyAcceptedPositions(choice.text).map((position) => {
return {
prefix: choice.text.slice(0, position),
suffix: choice.text.slice(position),
choiceIndex: choice.index,
): { key: CompletionCacheKey; value: CompletionCacheValue; rebuildFlag: boolean }[] {
const list = [{ key, value, rebuildFlag: false }];
if (this.options.prebuildCache.enabled) {
for (const choice of value.choices) {
const completionText = choice.text.slice(key.position - choice.replaceRange.start);
const perLinePositions = this.getPerLinePositions(completionText);
this.logger.trace({ completionText, perLinePositions }, "Calculate per-line cache positions");
for (const position of perLinePositions) {
const completionTextPrefix = completionText.slice(0, position);
const completionTextPrefixWithAutoClosedChars = this.generateAutoClosedPrefixes(completionTextPrefix);
for (const prefix of [completionTextPrefix, ...completionTextPrefixWithAutoClosedChars]) {
const entry = {
key: new CompletionContext({
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + position,
}),
value: {
...value,
choices: [
{
index: choice.index,
text: completionText.slice(position),
replaceRange: {
start: key.position + position,
end: key.position + position,
},
},
],
},
rebuildFlag: true,
};
});
})
.flat()
.reduce((grouped: { [key: string]: { suffix: string; choiceIndex: number }[] }, entry) => {
grouped[entry.prefix] = grouped[entry.prefix] || [];
grouped[entry.prefix].push({ suffix: entry.suffix, choiceIndex: entry.choiceIndex });
return grouped;
}, {});
for (const prefix in entries) {
const cacheKey = {
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + prefix.length,
};
const cacheValue = {
...value,
choices: entries[prefix].map((choice) => {
return {
index: choice.choiceIndex,
text: choice.suffix,
this.logger.trace({ prefix, entry }, "Build per-line cache entry");
list.push(entry);
}
}
const perCharacterPositions = this.getPerCharacterPositions(completionText);
this.logger.trace({ completionText, perCharacterPositions }, "Calculate per-character cache positions");
for (const position of perCharacterPositions) {
let lineStart = position;
while (lineStart > 0 && completionText[lineStart - 1] !== "\n") {
lineStart--;
}
const completionTextPrefix = completionText.slice(0, position);
const completionTextPrefixWithAutoClosedChars = this.generateAutoClosedPrefixes(completionTextPrefix);
for (const prefix of [completionTextPrefix, ...completionTextPrefixWithAutoClosedChars]) {
const entry = {
key: new CompletionContext({
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + position,
}),
value: {
...value,
choices: [
{
index: choice.index,
text: completionText.slice(lineStart),
replaceRange: {
start: key.position + lineStart,
end: key.position + position,
},
},
],
},
rebuildFlag: false,
};
}),
};
list.push({
key: cacheKey,
value: cacheValue,
});
this.logger.trace({ prefix, entry }, "Build per-character cache entry");
list.push(entry);
}
}
}
}
return list;
const result = list.reduce((prev, curr) => {
const found = prev.find((entry) => entry.key.hash === curr.key.hash);
if (found) {
found.value.choices.push(...curr.value.choices);
found.rebuildFlag = found.rebuildFlag || curr.rebuildFlag;
} else {
prev.push(curr);
}
return prev;
}, []);
return result;
}
private calculatePartiallyAcceptedPositions(completion: string): number[] {
const positions: number[] = [];
const option = this.options.partiallyAcceptedCacheGeneration;
// positions for every line end (before newline character) and line begin (after indent)
private getPerLinePositions(completion: string): number[] {
const result: number[] = [];
const option = this.options.prebuildCache;
const lines = splitLines(completion);
let index = 0;
let offset = 0;
// `index < lines.length - 1` to exclude the last line
while (index < lines.length - 1 && index < option.perLine.max) {
offset += lines[index].length;
positions.push(offset);
result.push(offset - 1); // cache at the end of the line (before newline character)
result.push(offset); // cache at the beginning of the next line (after newline character)
let offsetNextLineBegin = offset;
while (offsetNextLineBegin < completion.length && completion[offsetNextLineBegin].match(/\s/)) {
offsetNextLineBegin++;
}
result.push(offsetNextLineBegin); // cache at the beginning of the next line (after indent)
index++;
}
return result;
}
const words = lines.slice(0, option.perWord.lines).map(splitWords).flat();
index = 0;
offset = 0;
while (index < words.length && index < option.perWord.max) {
offset += words[index].length;
positions.push(offset);
index++;
}
const characters = lines
.slice(0, option.perCharacter.lines)
.map(splitWords)
.flat()
.slice(0, option.perCharacter.words)
.join("");
offset = 1;
while (offset < characters.length && offset < option.perCharacter.max) {
positions.push(offset);
// positions for every character in the leading lines
private getPerCharacterPositions(completion: string): number[] {
const result: number[] = [];
const option = this.options.prebuildCache;
const text = splitLines(completion).slice(0, option.perCharacter.lines).join("");
let offset = 0;
while (offset < text.length && offset < option.perCharacter.max) {
result.push(offset);
offset++;
}
return result;
}
// distinct and sort ascending
return positions.filter((v, i, arr) => arr.indexOf(v) === i).sort((a, b) => a - b);
// "function(" => ["function()"]
// "call([" => ["call([]", "call([])" ]
// "function(arg" => ["function(arg)" ]
private generateAutoClosedPrefixes(prefix: string): string[] {
const result: string[] = [];
const unpaired = findUnpairedAutoClosingChars(prefix);
for (
let checkIndex = 0, autoClosing = "";
checkIndex < this.options.prebuildCache.autoClosingPairCheck.max;
checkIndex++
) {
if (unpaired.length > checkIndex) {
const found = autoClosingPairOpenings.indexOf(unpaired[unpaired.length - 1 - checkIndex]);
if (found < 0) {
break;
}
autoClosing = autoClosing + autoClosingPairClosings[found];
result.push(prefix + autoClosing);
} else {
break;
}
}
return result;
}
}

View File

@ -0,0 +1,72 @@
import { splitLines, autoClosingPairOpenings, autoClosingPairClosings } from "./utils";
import hashObject from "object-hash";
export type CompletionRequest = {
filepath: string;
language: string;
text: string;
position: number;
manually?: boolean;
};
export type CompletionResponseChoice = {
index: number;
text: string;
// Range of the text to be replaced when applying the completion.
// The range should be limited to the current line.
replaceRange: {
start: number;
end: number;
};
};
export type CompletionResponse = {
id: string;
choices: CompletionResponseChoice[];
};
function isAtLineEndExcludingAutoClosedChar(suffix: string) {
return suffix
.trimEnd()
.split("")
.every((char) => autoClosingPairClosings.includes(char));
}
export class CompletionContext {
filepath: string;
language: string;
text: string;
position: number;
prefix: string;
suffix: string;
prefixLines: string[];
suffixLines: string[];
// "default": the cursor is at the end of the line
// "fill-in-line": the cursor is not at the end of the line, except auto closed characters
// In this case, we assume the completion should be a single line, so multiple lines completion will be dropped.
mode: "default" | "fill-in-line";
hash: string;
constructor(request: CompletionRequest) {
this.filepath = request.filepath;
this.language = request.language;
this.text = request.text;
this.position = request.position;
this.prefix = request.text.slice(0, request.position);
this.suffix = request.text.slice(request.position);
this.prefixLines = splitLines(this.prefix);
this.suffixLines = splitLines(this.suffix);
const lineEnd = isAtLineEndExcludingAutoClosedChar(this.suffixLines[0] ?? "");
this.mode = lineEnd ? "default" : "fill-in-line";
this.hash = hashObject({
filepath: request.filepath,
language: request.language,
text: request.text,
position: request.position,
});
}
}

View File

@ -5,7 +5,14 @@ import { deepmerge } from "deepmerge-ts";
import { getProperty, setProperty, deleteProperty } from "dot-prop";
import createClient from "openapi-fetch";
import { paths as TabbyApi } from "./types/tabbyApi";
import { splitLines, isBlank, abortSignalFromAnyOf, HttpError, isTimeoutError, isCanceledError } from "./utils";
import {
isBlank,
abortSignalFromAnyOf,
findUnpairedAutoClosingChars,
HttpError,
isTimeoutError,
isCanceledError,
} from "./utils";
import type {
Agent,
AgentStatus,
@ -23,8 +30,9 @@ import { Auth } from "./Auth";
import { AgentConfig, PartialAgentConfig, defaultAgentConfig, userAgentConfig } from "./AgentConfig";
import { CompletionCache } from "./CompletionCache";
import { CompletionDebounce } from "./CompletionDebounce";
import { CompletionContext } from "./CompletionContext";
import { DataStore } from "./dataStore";
import { postprocess, preCacheProcess } from "./postprocess";
import { preCacheProcess, postCacheProcess } from "./postprocess";
import { rootLogger, allLoggers } from "./logger";
import { AnonymousUsageLogger } from "./AnonymousUsageLogger";
import { CompletionProviderStats, CompletionProviderStatsEntry } from "./CompletionProviderStats";
@ -250,20 +258,46 @@ export class TabbyAgent extends EventEmitter implements Agent {
}
}
private createSegments(request: CompletionRequest): { prefix: string; suffix: string } {
private createSegments(context: CompletionContext): { prefix: string; suffix: string } {
// max lines in prefix and suffix configurable
const maxPrefixLines = this.config.completion.prompt.maxPrefixLines;
const maxSuffixLines = this.config.completion.prompt.maxSuffixLines;
const prefix = request.text.slice(0, request.position);
const prefixLines = splitLines(prefix);
const suffix = request.text.slice(request.position);
const suffixLines = splitLines(suffix);
const { prefixLines, suffixLines } = context;
return {
prefix: prefixLines.slice(Math.max(prefixLines.length - maxPrefixLines, 0)).join(""),
suffix: suffixLines.slice(0, maxSuffixLines).join(""),
};
}
private calculateReplaceRange(response: CompletionResponse, context: CompletionContext): CompletionResponse {
const { suffixLines } = context;
const suffixText = suffixLines[0]?.trimEnd() || "";
if (isBlank(suffixText)) {
return response;
}
for (const choice of response.choices) {
const completionText = choice.text.slice(context.position - choice.replaceRange.start);
const unpaired = findUnpairedAutoClosingChars(completionText);
if (isBlank(unpaired)) {
continue;
}
if (suffixText.startsWith(unpaired)) {
choice.replaceRange.end = context.position + unpaired.length;
this.logger.trace(
{ context, completion: choice.text, range: choice.replaceRange, unpaired },
"Adjust replace range",
);
} else if (unpaired.startsWith(suffixText)) {
choice.replaceRange.end = context.position + suffixText.length;
this.logger.trace(
{ context, completion: choice.text, range: choice.replaceRange, unpaired },
"Adjust replace range",
);
}
}
return response;
}
public async initialize(options: AgentInitOptions): Promise<boolean> {
if (options.clientProperties) {
const { user: userProp, session: sessionProp } = options.clientProperties;
@ -397,6 +431,7 @@ export class TabbyAgent extends EventEmitter implements Agent {
if (this.status === "notInitialized") {
throw new Error("Agent is not initialized");
}
this.logger.trace({ request }, "Call provideCompletions");
if (this.nonParallelProvideCompletionAbortController) {
this.nonParallelProvideCompletionAbortController.abort();
}
@ -415,11 +450,12 @@ export class TabbyAgent extends EventEmitter implements Agent {
};
let requestStartedAt: number | null = null;
const context = new CompletionContext(request);
try {
if (this.completionCache.has(request)) {
if (this.completionCache.has(context)) {
// Cache hit
stats.cacheHit = true;
this.logger.debug({ request }, "Completion cache hit");
this.logger.debug({ context }, "Completion cache hit");
// Debounce before returning cached response
await this.completionDebounce.debounce(
{
@ -429,11 +465,11 @@ export class TabbyAgent extends EventEmitter implements Agent {
},
{ signal },
);
completionResponse = this.completionCache.get(request);
completionResponse = this.completionCache.get(context);
} else {
// Cache miss
stats.cacheHit = false;
const segments = this.createSegments(request);
const segments = this.createSegments(context);
if (isBlank(segments.prefix)) {
// Empty prompt
stats = null; // no need to record stats for empty prompt
@ -457,7 +493,7 @@ export class TabbyAgent extends EventEmitter implements Agent {
stats.requestSent = true;
requestStartedAt = performance.now();
try {
completionResponse = await this.post(
const response = await this.post(
"/v1/completions",
{
body: {
@ -474,6 +510,19 @@ export class TabbyAgent extends EventEmitter implements Agent {
},
);
stats.requestLatency = performance.now() - requestStartedAt;
completionResponse = {
id: response.id,
choices: response.choices.map((choice) => {
return {
index: choice.index,
text: choice.text,
replaceRange: {
start: request.position,
end: request.position,
},
};
}),
};
} catch (error) {
if (isCanceledError(error)) {
stats.requestCanceled = true;
@ -487,19 +536,21 @@ export class TabbyAgent extends EventEmitter implements Agent {
throw error;
}
// Postprocess (pre-cache)
completionResponse = await preCacheProcess(request, completionResponse);
completionResponse = await preCacheProcess(context, completionResponse);
if (options?.signal?.aborted) {
throw options.signal.reason;
}
// Build cache
this.completionCache.set(request, completionResponse);
this.completionCache.buildCache(context, completionResponse);
}
}
// Postprocess (post-cache)
completionResponse = await postprocess(request, completionResponse);
completionResponse = await postCacheProcess(context, completionResponse);
if (options?.signal?.aborted) {
throw options.signal.reason;
}
// Calculate replace range
completionResponse = this.calculateReplaceRange(completionResponse, context);
} catch (error) {
if (isCanceledError(error) || isTimeoutError(error)) {
if (stats) {
@ -535,7 +586,7 @@ export class TabbyAgent extends EventEmitter implements Agent {
}
}
}
this.logger.trace({ context, completionResponse }, "Return from provideCompletions");
return completionResponse;
}

View File

@ -1,32 +1,10 @@
import { CompletionRequest, CompletionResponse } from "../Agent";
import { splitLines } from "../utils";
import { CompletionResponse, CompletionContext } from "../CompletionContext";
import { rootLogger } from "../logger";
export type PostprocessContext = {
request: CompletionRequest; // request contains full context, others are for easy access
prefix: string;
suffix: string;
prefixLines: string[];
suffixLines: string[];
};
export type PostprocessFilter = (item: string) => string | null | Promise<string | null>;
export const logger = rootLogger.child({ component: "Postprocess" });
export function buildContext(request: CompletionRequest): PostprocessContext {
const prefix = request.text.slice(0, request.position);
const suffix = request.text.slice(request.position);
const prefixLines = splitLines(prefix);
const suffixLines = splitLines(suffix);
return {
request,
prefix,
suffix,
prefixLines,
suffixLines,
};
}
declare global {
interface Array<T> {
distinct(identity?: (x: T) => any): Array<T>;
@ -39,12 +17,17 @@ if (!Array.prototype.distinct) {
};
}
export function applyFilter(filter: PostprocessFilter): (response: CompletionResponse) => Promise<CompletionResponse> {
export function applyFilter(
filter: PostprocessFilter,
context: CompletionContext,
): (response: CompletionResponse) => Promise<CompletionResponse> {
return async (response: CompletionResponse) => {
response.choices = (
await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter(choice.text);
const replaceLength = context.position - choice.replaceRange.start;
const filtered = await filter(choice.text.slice(replaceLength));
choice.text = choice.text.slice(0, replaceLength) + (filtered ?? "");
return choice;
}),
)

View File

@ -36,27 +36,5 @@ describe("postprocess", () => {
`;
expect(dropDuplicated(context)(completion)).to.be.null;
});
it("should drop completion that first 3 lines are similar to suffix", () => {
const context = {
...documentContext`
var a, b;
// swap a and b║
let z = a;
a = b;
b = z;
// something else
`,
language: "javascript",
};
const completion = inline`
let c = a;
a = b;
b = c;
console.log({a, b});
`;
expect(dropDuplicated(context)(completion)).to.be.null;
});
});
});

View File

@ -1,7 +1,8 @@
import { PostprocessFilter, PostprocessContext, logger } from "./base";
import { CompletionContext } from "../Agent";
import { PostprocessFilter, logger } from "./base";
import { splitLines, isBlank, calcDistance } from "../utils";
export const dropDuplicated: (context: PostprocessContext) => PostprocessFilter = (context) => {
export const dropDuplicated: (context: CompletionContext) => PostprocessFilter = (context) => {
return (input) => {
// get first n (n <= 3) lines of input and suffix, ignore blank lines
const { suffixLines } = context;
@ -24,9 +25,9 @@ export const dropDuplicated: (context: PostprocessContext) => PostprocessFilter
.slice(suffixIndex, suffixIndex + lineCount)
.join("")
.trim();
// if string distance is less than threshold (threshold = 3, or 5% of string length)
// if string distance is less than threshold (threshold = 1, or 5% of string length)
// drop this completion due to duplicated
const threshold = Math.max(3, 0.05 * inputToCompare.length, 0.05 * suffixToCompare.length);
const threshold = Math.max(1, 0.05 * inputToCompare.length, 0.05 * suffixToCompare.length);
const distance = calcDistance(inputToCompare, suffixToCompare);
if (distance <= threshold) {
logger.debug(

View File

@ -1,37 +1,33 @@
import { CompletionRequest, CompletionResponse } from "../Agent";
import { buildContext, applyFilter } from "./base";
import { CompletionContext, CompletionResponse } from "../Agent";
import { applyFilter } from "./base";
import { removeRepetitiveBlocks } from "./removeRepetitiveBlocks";
import { removeRepetitiveLines } from "./removeRepetitiveLines";
import { removeLineEndsWithRepetition } from "./removeLineEndsWithRepetition";
import { limitScopeByIndentation } from "./limitScopeByIndentation";
import { trimSpace } from "./trimSpace";
import { removeOverlapping } from "./removeOverlapping";
import { dropDuplicated } from "./dropDuplicated";
import { dropBlank } from "./dropBlank";
export async function preCacheProcess(
request: CompletionRequest,
context: CompletionContext,
response: CompletionResponse,
): Promise<CompletionResponse> {
const context = buildContext(request);
return Promise.resolve(response)
.then(applyFilter(removeLineEndsWithRepetition(context)))
.then(applyFilter(dropDuplicated(context)))
.then(applyFilter(trimSpace(context)))
.then(applyFilter(removeOverlapping(context)))
.then(applyFilter(dropBlank()));
.then(applyFilter(removeLineEndsWithRepetition(context), context))
.then(applyFilter(dropDuplicated(context), context))
.then(applyFilter(trimSpace(context), context))
.then(applyFilter(dropBlank(), context));
}
export async function postprocess(
request: CompletionRequest,
export async function postCacheProcess(
context: CompletionContext,
response: CompletionResponse,
): Promise<CompletionResponse> {
const context = buildContext(request);
return Promise.resolve(response)
.then(applyFilter(removeRepetitiveBlocks(context)))
.then(applyFilter(removeRepetitiveLines(context)))
.then(applyFilter(limitScopeByIndentation(context)))
.then(applyFilter(trimSpace(context)))
.then(applyFilter(removeOverlapping(context)))
.then(applyFilter(dropBlank()));
.then(applyFilter(removeRepetitiveBlocks(context), context))
.then(applyFilter(removeRepetitiveLines(context), context))
.then(applyFilter(limitScopeByIndentation(context), context))
.then(applyFilter(dropDuplicated(context), context))
.then(applyFilter(trimSpace(context), context))
.then(applyFilter(dropBlank(), context));
}

View File

@ -1,4 +1,5 @@
import { PostprocessFilter, PostprocessContext, logger } from "./base";
import { CompletionContext } from "../Agent";
import { PostprocessFilter, logger } from "./base";
import { isBlank, splitLines } from "../utils";
function calcIndentLevel(line: string): number {
@ -12,11 +13,6 @@ function isOpeningIndentBlock(lines, index) {
return calcIndentLevel(lines[index]) < calcIndentLevel(lines[index + 1]);
}
function shouldOnlyAllowSingleLine(suffixLines: string[]): boolean {
let currentLineInSuffix = suffixLines[0] ?? "";
return !isBlank(currentLineInSuffix.replace(/[\)\}\]]/g, ""));
}
function processContext(
lines: string[],
prefixLines: string[],
@ -81,7 +77,7 @@ function processContext(
}
// check if suffix context allows closing line
// skip 0 that is current line in suffix, it is processed in `shouldOnlyAllowSingleLine`
// skip 0 that is current line in suffix
let firstNonBlankLineInSuffix = 1;
while (firstNonBlankLineInSuffix < suffixLines.length && isBlank(suffixLines[firstNonBlankLineInSuffix])) {
firstNonBlankLineInSuffix++;
@ -92,11 +88,11 @@ function processContext(
return result;
}
export const limitScopeByIndentation: (context: PostprocessContext) => PostprocessFilter = (context) => {
export const limitScopeByIndentation: (context: CompletionContext) => PostprocessFilter = (context) => {
return (input) => {
const { prefix, suffix, prefixLines, suffixLines } = context;
const inputLines = splitLines(input);
if (shouldOnlyAllowSingleLine(suffixLines)) {
if (context.mode === "fill-in-line") {
if (inputLines.length > 1) {
logger.debug({ input, prefix, suffix }, "Drop content with multiple lines");
return null;
@ -116,8 +112,11 @@ export const limitScopeByIndentation: (context: PostprocessContext) => Postproce
continue;
}
// We include this closing line here if context allows
// Python does not have closing bracket, so we always include closing line
if (indentContext.allowClosingLine && context.request.language !== "python") {
// For python, if previous line is blank, we don't include this line
if (
(context.language !== "python" && indentContext.allowClosingLine) ||
(context.language === "python" && indentContext.allowClosingLine && !isBlank(inputLines[index - 1]))
) {
index++;
}
break;

View File

@ -1,4 +1,5 @@
import { PostprocessFilter, PostprocessContext, logger } from "./base";
import { CompletionContext } from "../Agent";
import { PostprocessFilter, logger } from "./base";
import { splitLines, isBlank } from "../utils";
const repetitionTests = [
@ -6,7 +7,7 @@ const repetitionTests = [
/(.{10,}?)\1{3,}$/g, // match a 10+ characters pattern repeating 3+ times
];
export const removeLineEndsWithRepetition: (context: PostprocessContext) => PostprocessFilter = () => {
export const removeLineEndsWithRepetition: (context: CompletionContext) => PostprocessFilter = () => {
return (input) => {
// only test last non-blank line
const inputLines = splitLines(input);

View File

@ -1,65 +0,0 @@
import { expect } from "chai";
import { documentContext, inline } from "./testUtils";
import { removeOverlapping } from "./removeOverlapping";
describe("postprocess", () => {
describe("removeOverlapping", () => {
it("should remove content overlapped between completion and suffix", () => {
const context = {
...documentContext`
function sum(a, b) {
return value;
}
`,
language: "javascript",
};
const completion = inline`
let value = a + b;
return value;
}
`;
const expected = inline`
let value = a + b;
`;
expect(removeOverlapping(context)(completion)).to.eq(expected);
});
// Bad case
it("can not remove text that suffix not exactly starts with", () => {
const context = {
...documentContext`
let sum = (a, b) => {
return a + b;
}
`,
language: "javascript",
};
// completion give a `;` at end but context have not
const completion = inline`
return a + b;
};
`;
expect(removeOverlapping(context)(completion)).to.eq(completion);
});
// Bad case
it("can not remove text that suffix not exactly starts with", () => {
const context = {
...documentContext`
let sum = (a, b) => {
return a + b;
}
`,
language: "javascript",
};
// the difference is a `\n`
const completion = inline`
}
`;
expect(removeOverlapping(context)(completion)).to.eq(completion);
});
});
});

View File

@ -1,15 +0,0 @@
import { PostprocessFilter, PostprocessContext, logger } from "./base";
export const removeOverlapping: (context: PostprocessContext) => PostprocessFilter = (context) => {
return (input) => {
const request = context.request;
const suffix = request.text.slice(request.position);
for (let index = Math.max(0, input.length - suffix.length); index < input.length; index++) {
if (input.slice(index) === suffix.slice(0, input.length - index)) {
logger.debug({ input, suffix, overlappedAt: index }, "Remove overlapped content");
return input.slice(0, index);
}
}
return input;
};
};

View File

@ -1,4 +1,5 @@
import { PostprocessFilter, PostprocessContext, logger } from "./base";
import { CompletionContext } from "../Agent";
import { PostprocessFilter, logger } from "./base";
import { isBlank, calcDistance } from "../utils";
function blockSplitter(language) {
@ -8,9 +9,9 @@ function blockSplitter(language) {
}
// FIXME: refactor this because it is very similar to `removeRepetitiveLines`
export const removeRepetitiveBlocks: (context: PostprocessContext) => PostprocessFilter = (context) => {
export const removeRepetitiveBlocks: (context: CompletionContext) => PostprocessFilter = (context) => {
return (input) => {
const inputBlocks = input.split(blockSplitter(context.request.language));
const inputBlocks = input.split(blockSplitter(context.language));
let repetitionCount = 0;
const repetitionThreshold = 2;
// skip last block, it maybe cut
@ -25,10 +26,10 @@ export const removeRepetitiveBlocks: (context: PostprocessContext) => Postproces
prev--;
}
if (prev < 0) break;
// if distance between current and previous block is less than threshold (threshold = 3, or 10% of string length)
// if distance between current and previous block is less than threshold (threshold = or 10% of string length)
const currentBlock = inputBlocks[index].trim();
const previousBlock = inputBlocks[prev].trim();
const threshold = Math.max(3, 0.1 * currentBlock.length, 0.1 * previousBlock.length);
const threshold = Math.max(0.1 * currentBlock.length, 0.1 * previousBlock.length);
const distance = calcDistance(currentBlock, previousBlock);
if (distance <= threshold) {
repetitionCount++;

View File

@ -1,7 +1,8 @@
import { PostprocessFilter, PostprocessContext, logger } from "./base";
import { CompletionContext } from "../Agent";
import { PostprocessFilter, logger } from "./base";
import { splitLines, isBlank, calcDistance } from "../utils";
export const removeRepetitiveLines: (context: PostprocessContext) => PostprocessFilter = () => {
export const removeRepetitiveLines: (context: CompletionContext) => PostprocessFilter = () => {
return (input) => {
const inputLines = splitLines(input);
let repetitionCount = 0;
@ -18,10 +19,10 @@ export const removeRepetitiveLines: (context: PostprocessContext) => Postprocess
prev--;
}
if (prev < 0) break;
// if distance between current and previous line is less than threshold (threshold = 3, or 10% of string length)
// if distance between current and previous line is less than threshold (threshold = or 10% of string length)
const currentLine = inputLines[index].trim();
const previousLine = inputLines[prev].trim();
const threshold = Math.max(3, 0.1 * currentLine.length, 0.1 * previousLine.length);
const threshold = Math.max(0.1 * currentLine.length, 0.1 * previousLine.length);
const distance = calcDistance(currentLine, previousLine);
if (distance <= threshold) {
repetitionCount++;

View File

@ -1,11 +1,10 @@
import dedent from "dedent";
import { buildContext, PostprocessContext } from "./base";
import { CompletionContext } from "../CompletionContext";
// `║` is the cursor position
export function documentContext(strings): PostprocessContext {
export function documentContext(strings): CompletionContext {
const doc = dedent(strings);
return buildContext({
return new CompletionContext({
filepath: null,
language: null,
text: doc.replace(/║/, ""),

View File

@ -1,7 +1,8 @@
import { PostprocessFilter, PostprocessContext } from "./base";
import { CompletionContext } from "../Agent";
import { PostprocessFilter, logger } from "./base";
import { splitLines, isBlank } from "../utils";
export const trimSpace: (context: PostprocessContext) => PostprocessFilter = (context) => {
export const trimSpace: (context: CompletionContext) => PostprocessFilter = (context) => {
return (input) => {
const { prefixLines, suffixLines } = context;
const inputLines = splitLines(input);

View File

@ -15,6 +15,52 @@ export function isBlank(input: string) {
return input.trim().length === 0;
}
export const autoClosingPairs = [
["(", ")"],
["[", "]"],
["{", "}"],
["'", "'"],
['"', '"'],
["`", "`"],
];
export const autoClosingPairOpenings = autoClosingPairs.map((pair) => pair[0]);
export const autoClosingPairClosings = autoClosingPairs.map((pair) => pair[1]);
// FIXME: This function is not good enough, it can not handle escaped characters.
export function findUnpairedAutoClosingChars(input: string): string {
const stack: string[] = [];
for (const char of input) {
[
["(", ")"],
["[", "]"],
["{", "}"],
].forEach((pair) => {
if (char === pair[1]) {
if (stack.length > 0 && stack[stack.length - 1] === pair[0]) {
stack.pop();
} else {
stack.push(char);
}
}
});
if ("([{".includes(char)) {
stack.push(char);
}
["'", '"', "`"].forEach((quote) => {
if (char === quote) {
if (stack.length > 0 && stack.includes(quote)) {
stack.splice(stack.lastIndexOf(quote), stack.length - stack.lastIndexOf(quote));
} else {
stack.push(char);
}
}
});
}
return stack.join("");
}
// Using string levenshtein distance is not good, because variable name may create a large distance.
// Such as distance is 9 between `const fooFooFoo = 1;` and `const barBarBar = 1;`, but maybe 1 is enough.
// May be better to count distance based on words instead of characters.

View File

@ -10,6 +10,6 @@
"devDependencies": {
"cpy-cli": "^4.2.0",
"rimraf": "^5.0.1",
"tabby-agent": "0.3.1"
"tabby-agent": "0.4.0-dev"
}
}

View File

@ -217,6 +217,6 @@
},
"dependencies": {
"@xstate/fsm": "^2.0.1",
"tabby-agent": "0.3.1"
"tabby-agent": "0.4.0-dev"
}
}

View File

@ -51,11 +51,10 @@ export class TabbyCompletionProvider extends EventEmitter implements InlineCompl
context: InlineCompletionContext,
token: CancellationToken,
): Promise<InlineCompletionItem[] | null> {
if (context.triggerKind === InlineCompletionTriggerKind.Automatic && this.triggerMode === "manual") {
return null;
}
console.debug("Call provideInlineCompletionItems.");
if (context.triggerKind === InlineCompletionTriggerKind.Invoke && this.triggerMode === "automatic") {
if (context.triggerKind === InlineCompletionTriggerKind.Automatic && this.triggerMode === "manual") {
console.debug("Skip automatic trigger when triggerMode is manual.");
return null;
}
@ -70,8 +69,6 @@ export class TabbyCompletionProvider extends EventEmitter implements InlineCompl
return null;
}
const replaceRange = this.calculateReplaceRange(document, position);
const request: CompletionRequest = {
filepath: document.uri.fsPath,
language: document.languageId, // https://code.visualstudio.com/docs/languages/identifiers
@ -104,19 +101,24 @@ export class TabbyCompletionProvider extends EventEmitter implements InlineCompl
// Assume only one choice is provided, do not support multiple choices for now
if (result.choices.length > 0) {
this.latestCompletions = result;
const choice = result.choices[0];
this.postEvent("show");
return [
new InlineCompletionItem(result.choices[0].text, replaceRange, {
title: "",
command: "tabby.applyCallback",
arguments: [
() => {
this.postEvent("accept");
},
],
}),
new InlineCompletionItem(
choice.text,
new Range(document.positionAt(choice.replaceRange.start), document.positionAt(choice.replaceRange.end)),
{
title: "",
command: "tabby.applyCallback",
arguments: [
() => {
this.postEvent("accept");
},
],
},
),
];
}
} catch (error: any) {
@ -171,21 +173,4 @@ export class TabbyCompletionProvider extends EventEmitter implements InlineCompl
this.emit("triggerModeUpdated");
}
}
private hasSuffixParen(document: TextDocument, position: Position) {
const suffix = document.getText(
new Range(position.line, position.character, position.line, position.character + 1),
);
return ")]}".indexOf(suffix) > -1;
}
// FIXME: move replace range calculation to tabby-agent
private calculateReplaceRange(document: TextDocument, position: Position): Range {
const hasSuffixParen = this.hasSuffixParen(document, position);
if (hasSuffixParen) {
return new Range(position.line, position.character, position.line, position.character + 1);
} else {
return new Range(position, position);
}
}
}

View File

@ -2944,13 +2944,6 @@ object-keys@^1.1.1:
resolved "https://registry.yarnpkg.com/object-keys/-/object-keys-1.1.1.tgz#1c47f272df277f3b1daf061677d9c82e2322c60e"
integrity sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==
object-sizeof@^2.6.1:
version "2.6.3"
resolved "https://registry.yarnpkg.com/object-sizeof/-/object-sizeof-2.6.3.tgz#3e106c15d90b13664cb8f387c66eb162fcbef1d8"
integrity sha512-GNkVRrLh11Qr5BGr73dwwPE200/78QG2rbx30cnXPnMvt7UuttH4Dup5t+LtcQhARkg8Hbr0c8Kiz52+CFxYmw==
dependencies:
buffer "^6.0.3"
object.assign@^4.1.4:
version "4.1.4"
resolved "https://registry.yarnpkg.com/object.assign/-/object.assign-4.1.4.tgz#9673c7c7c351ab8c4d0b516f4343ebf4dfb7799f"