refactor: Update agent getCompletion interface. (#176)

support-coreml
Zhiming Ma 2023-06-02 11:58:34 +08:00 committed by GitHub
parent b4eaf543b1
commit e3eae370be
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 1718 additions and 3407 deletions

File diff suppressed because one or more lines are too long

View File

@ -64,18 +64,7 @@ type CompletionEvent = {
choices: Array<Choice>;
};
type CompletionRequest = {
/**
* Language for completion request
*/
language?: string;
/**
* The context to generate completions for, encoded as a string.
*/
prompt: string;
};
type CompletionResponse = {
type CompletionResponse$1 = {
id: string;
created: number;
choices: Array<Choice>;
@ -108,6 +97,13 @@ type HTTPValidationError = {
detail?: Array<ValidationError>;
};
type CompletionRequest = {
filepath: string;
language: string;
text: string;
position: number;
};
type CompletionResponse = CompletionResponse$1;
interface AgentFunction {
setServerUrl(url: string): string;
getServerUrl(): string;
@ -120,7 +116,7 @@ type StatusChangedEvent = {
status: "connecting" | "ready" | "disconnected";
};
type AgentEvent = StatusChangedEvent;
declare const agentEventNames: AgentEvent['event'][];
declare const agentEventNames: AgentEvent["event"][];
interface AgentEventEmitter {
on<T extends AgentEvent>(eventName: T["event"], callback: (event: T) => void): this;
}
@ -135,6 +131,7 @@ declare class TabbyAgent extends EventEmitter implements Agent {
private changeStatus;
private ping;
private wrapApiPromise;
private createPrompt;
setServerUrl(serverUrl: string): string;
getServerUrl(): string;
getStatus(): "connecting" | "ready" | "disconnected";

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -59,7 +59,7 @@ module.exports = __toCommonJS(src_exports);
// src/TabbyAgent.ts
var import_axios2 = __toESM(require("axios"));
var import_events = require("events");
var import_assert = require("assert");
var import_uuid = require("uuid");
// src/CompletionCache.ts
var import_lru_cache = require("lru-cache");
@ -505,6 +505,9 @@ function splitLines(input) {
function splitWords(input) {
return input.match(/\w+|\W+/g).filter(Boolean);
}
function isBlank(input) {
return input.trim().length === 0;
}
function cancelable(promise, cancel) {
return new CancelablePromise((resolve2, reject, onCancel) => {
promise.then((resp) => {
@ -576,7 +579,11 @@ var CompletionCache = class {
return grouped;
}, {});
for (const prefix in entries) {
const cacheKey = { ...key, prompt: key.prompt + prefix };
const cacheKey = {
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + prefix.length
};
const cacheValue = {
...value,
choices: entries[prefix].map((choice) => {
@ -642,8 +649,7 @@ var TabbyAgent = class extends import_events.EventEmitter {
}
async ping(tries = 0) {
try {
const response = await import_axios2.default.get(`${this.serverUrl}/`);
(0, import_assert.strict)(response.status == 200);
const response = await import_axios2.default.get(this.serverUrl);
this.changeStatus("ready");
return true;
} catch (e) {
@ -671,6 +677,14 @@ var TabbyAgent = class extends import_events.EventEmitter {
}
);
}
createPrompt(request2) {
const maxLines = 20;
const prefix = request2.text.slice(0, request2.position);
const lines = splitLines(prefix);
const cutoff = Math.max(lines.length - maxLines, 0);
const prompt = lines.slice(cutoff).join("");
return prompt;
}
setServerUrl(serverUrl) {
this.serverUrl = serverUrl.replace(/\/$/, "");
this.ping();
@ -689,7 +703,20 @@ var TabbyAgent = class extends import_events.EventEmitter {
resolve2(this.completionCache.get(request2));
});
}
const promise = this.wrapApiPromise(this.api.default.completionsV1CompletionsPost(request2));
const prompt = this.createPrompt(request2);
if (isBlank(prompt)) {
return new CancelablePromise((resolve2) => {
resolve2({
id: "agent-" + (0, import_uuid.v4)(),
created: (/* @__PURE__ */ new Date()).getTime(),
choices: []
});
});
}
const promise = this.wrapApiPromise(this.api.default.completionsV1CompletionsPost({
prompt,
language: request2.language
}));
return cancelable(
promise.then((response) => {
this.completionCache.set(request2, response);

File diff suppressed because one or more lines are too long

View File

@ -31,6 +31,7 @@
"form-data": "^4.0.0",
"lru-cache": "^9.1.1",
"object-hash": "^3.0.0",
"object-sizeof": "^2.6.1"
"object-sizeof": "^2.6.1",
"uuid": "^9.0.0"
}
}

View File

@ -1,7 +1,7 @@
import { LRUCache } from "lru-cache";
import hashObject from "object-hash";
import sizeOfObject from "object-sizeof";
import { CompletionRequest, CompletionResponse } from "./generated";
import { CompletionRequest, CompletionResponse } from "./types";
import { splitLines, splitWords } from "./utils";
type CompletionCacheKey = CompletionRequest;
@ -76,7 +76,11 @@ export class CompletionCache {
return grouped;
}, {});
for (const prefix in entries) {
const cacheKey = { ...key, prompt: key.prompt + prefix };
const cacheKey = {
...key,
text: key.text.slice(0, key.position) + prefix + key.text.slice(key.position),
position: key.position + prefix.length
};
const cacheValue = {
...value,
choices: entries[prefix].map((choice) => {

View File

@ -1,18 +1,10 @@
import axios from "axios";
import { EventEmitter } from "events";
import { strict as assert } from "assert";
import { v4 as uuid } from "uuid";
import { CompletionCache } from "./CompletionCache";
import { sleep, cancelable } from "./utils";
import { Agent, AgentEvent } from "./types";
import {
TabbyApi,
CancelablePromise,
ApiError,
CompletionRequest,
CompletionResponse,
ChoiceEvent,
CompletionEvent,
} from "./generated";
import { sleep, cancelable, splitLines, isBlank } from "./utils";
import { Agent, AgentEvent, CompletionRequest, CompletionResponse } from "./types";
import { TabbyApi, CancelablePromise, ApiError, ChoiceEvent, CompletionEvent } from "./generated";
export class TabbyAgent extends EventEmitter implements Agent {
private serverUrl: string = "http://127.0.0.1:5000";
@ -69,6 +61,15 @@ export class TabbyAgent extends EventEmitter implements Agent {
);
}
private createPrompt(request: CompletionRequest): string {
const maxLines = 20;
const prefix = request.text.slice(0, request.position);
const lines = splitLines(prefix);
const cutoff = Math.max(lines.length - maxLines, 0);
const prompt = lines.slice(cutoff).join("");
return prompt;
}
public setServerUrl(serverUrl: string): string {
this.serverUrl = serverUrl.replace(/\/$/, ""); // Remove trailing slash
this.ping();
@ -90,7 +91,21 @@ export class TabbyAgent extends EventEmitter implements Agent {
resolve(this.completionCache.get(request));
});
}
const promise = this.wrapApiPromise(this.api.default.completionsV1CompletionsPost(request));
const prompt = this.createPrompt(request);
if (isBlank(prompt)) {
// Create a empty completion response
return new CancelablePromise((resolve) => {
resolve({
id: "agent-" + uuid(),
created: new Date().getTime(),
choices: []
});
});
}
const promise = this.wrapApiPromise(this.api.default.completionsV1CompletionsPost({
prompt,
language: request.language,
}));
return cancelable(
promise.then((response: CompletionResponse) => {
this.completionCache.set(request, response);

View File

@ -1,13 +1,19 @@
export { TabbyAgent } from "./TabbyAgent";
export { Agent, AgentFunction, AgentEvent, StatusChangedEvent, agentEventNames } from "./types";
export {
Agent,
AgentFunction,
AgentEvent,
StatusChangedEvent,
CompletionRequest,
CompletionResponse,
agentEventNames,
} from "./types";
export {
CancelablePromise,
CancelError,
ApiError,
HTTPValidationError,
ValidationError,
CompletionRequest,
CompletionResponse,
Choice,
ChoiceEvent,
CompletionEvent,

View File

@ -1,4 +1,18 @@
import { CancelablePromise, ChoiceEvent, CompletionEvent, CompletionRequest, CompletionResponse } from "./generated";
import {
CancelablePromise,
ChoiceEvent,
CompletionEvent,
CompletionResponse as ApiCompletionResponse,
} from "./generated";
export type CompletionRequest = {
filepath: string;
language: string;
text: string;
position: number;
};
export type CompletionResponse = ApiCompletionResponse;
export interface AgentFunction {
setServerUrl(url: string): string;
@ -11,10 +25,10 @@ export interface AgentFunction {
export type StatusChangedEvent = {
event: "statusChanged";
status: "connecting" | "ready" | "disconnected";
}
};
export type AgentEvent = StatusChangedEvent;
export const agentEventNames: AgentEvent['event'][] = ["statusChanged"];
export const agentEventNames: AgentEvent["event"][] = ["statusChanged"];
export interface AgentEventEmitter {
on<T extends AgentEvent>(eventName: T["event"], callback: (event: T) => void): this;

View File

@ -10,6 +10,10 @@ export function splitWords(input: string) {
return input.match(/\w+|\W+/g).filter(Boolean); // Split consecutive words and non-words
}
export function isBlank(input: string) {
return input.trim().length === 0;
}
import { CancelablePromise } from "./generated";
export function cancelable<T>(promise: Promise<T>, cancel: () => void): CancelablePromise<T> {
return new CancelablePromise((resolve, reject, onCancel) => {

View File

@ -1140,6 +1140,11 @@ universalify@^2.0.0:
resolved "https://registry.yarnpkg.com/universalify/-/universalify-2.0.0.tgz#75a4984efedc4b08975c5aeb73f530d02df25717"
integrity sha512-hAZsKq7Yy11Zu1DE0OzWjw7nnLZmJZYTDZZyEFHZdUhV8FkH5MCfoU1XMaxXovpyW5nq5scPqq0ZDP9Zyl04oQ==
uuid@^9.0.0:
version "9.0.0"
resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5"
integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==
webidl-conversions@^4.0.2:
version "4.0.2"
resolved "https://registry.yarnpkg.com/webidl-conversions/-/webidl-conversions-4.0.2.tgz#a855980b1f0b6b359ba1d5d9fb39ae941faa63ad"

View File

@ -42,13 +42,6 @@ export class TabbyCompletionProvider implements InlineCompletionItemProvider {
return emptyResponse;
}
const promptRange = this.calculatePromptRange(position);
const prompt = document.getText(promptRange);
if (this.isNil(prompt)) {
console.debug("Prompt is empty, skipping");
return emptyResponse;
}
const currentTimestamp = Date.now();
this.latestTimestamp = currentTimestamp;
@ -59,23 +52,18 @@ export class TabbyCompletionProvider implements InlineCompletionItemProvider {
const replaceRange = this.calculateReplaceRange(document, position);
console.debug(
"Requesting: ",
{
uuid: this.uuid,
timestamp: currentTimestamp,
prompt,
language: document.languageId
}
);
if (this.pendingCompletion) {
this.pendingCompletion.cancel();
}
this.pendingCompletion = this.agent.getCompletions({
prompt: prompt as string, // Prompt is already nil-checked
const request = {
filepath: document.uri.fsPath,
language: document.languageId, // https://code.visualstudio.com/docs/languages/identifiers
});
text: document.getText(),
position: document.offsetAt(position),
};
console.debug("Request: ", request)
this.pendingCompletion = this.agent.getCompletions(request);
const completion = await this.pendingCompletion.catch((_: Error) => {
return null;
@ -93,10 +81,6 @@ export class TabbyCompletionProvider implements InlineCompletionItemProvider {
this.suggestionDelay = configuration.get("suggestionDelay", 150);
}
private isNil(value: string | undefined | null): boolean {
return value === undefined || value === null || value.length === 0;
}
private toInlineCompletions(tabbyCompletion: CompletionResponse | null, range: Range): InlineCompletionItem[] {
return (
tabbyCompletion?.choices?.map((choice: any) => {
@ -121,12 +105,6 @@ export class TabbyCompletionProvider implements InlineCompletionItemProvider {
return ")]}".indexOf(suffix) > -1;
}
private calculatePromptRange(position: Position): Range {
const maxLines = 20;
const firstLine = Math.max(position.line - maxLines, 0);
return new Range(firstLine, 0, position.line, position.character);
}
private calculateReplaceRange(document: TextDocument, position: Position): Range {
const hasSuffixParen = this.hasSuffixParen(document, position);
if (hasSuffixParen) {

View File

@ -2533,6 +2533,7 @@ supports-preserve-symlinks-flag@^1.0.0:
lru-cache "^9.1.1"
object-hash "^3.0.0"
object-sizeof "^2.6.1"
uuid "^9.0.0"
tapable@^2.1.1, tapable@^2.2.0:
version "2.2.1"
@ -2716,6 +2717,11 @@ util@^0.12.0:
is-typed-array "^1.1.3"
which-typed-array "^1.1.2"
uuid@^9.0.0:
version "9.0.0"
resolved "https://registry.yarnpkg.com/uuid/-/uuid-9.0.0.tgz#592f550650024a38ceb0c562f2f6aa435761efb5"
integrity sha512-MXcSTerfPa4uqyzStbRoTgt5XIe3x5+42+q1sDuy3R5MDk66URdLMOZe5aPX/SQd+kuYAh0FdP/pO28IkQyTeg==
vsce@^2.15.0:
version "2.15.0"
resolved "https://registry.npmmirror.com/vsce/-/vsce-2.15.0.tgz#4a992e78532092a34a755143c6b6c2cabcb7d729"