feat: Add agent completion post-process. (#221)

improve-workflow
Zhiming Ma 2023-06-09 01:19:10 +08:00 committed by GitHub
parent 5870f8e868
commit 4ea3298bc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 262 additions and 127 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -59,7 +59,7 @@ module.exports = __toCommonJS(src_exports);
var import_axios2 = __toESM(require("axios"));
var import_events = require("events");
var import_uuid = require("uuid");
var import_deep_equal = __toESM(require("deep-equal"));
var import_deep_equal2 = __toESM(require("deep-equal"));
var import_deepmerge = __toESM(require("deepmerge"));
// src/generated/core/BaseHttpRequest.ts
@ -659,6 +659,41 @@ var CompletionCache = class {
}
};
// src/postprocess.ts
var import_deep_equal = __toESM(require("deep-equal"));
var logger = rootLogger.child({ component: "Postprocess" });
var removeDuplicateLines = (context) => {
return (input) => {
const suffix = context.text.slice(context.position);
const suffixLines = splitLines(suffix);
const inputLines = splitLines(input);
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
if ((0, import_deep_equal.default)(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
return input.slice(0, index);
}
}
return input;
};
};
var dropBlank = (input) => {
return isBlank(input) ? null : input;
};
var applyFilter = (filter) => {
return async (response) => {
response.choices = (await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter(choice.text);
return choice;
})
)).filter(Boolean);
return response;
};
};
async function postprocess(request2, response) {
return new Promise((resolve2) => resolve2(response)).then(applyFilter(removeDuplicateLines(request2))).then(applyFilter(dropBlank));
}
// src/TabbyAgent.ts
var TabbyAgent = class extends import_events.EventEmitter {
constructor() {
@ -670,7 +705,7 @@ var TabbyAgent = class extends import_events.EventEmitter {
this.onConfigUpdated();
}
onConfigUpdated() {
allLoggers.forEach((logger) => logger.level = this.config.logs.level);
allLoggers.forEach((logger2) => logger2.level = this.config.logs.level);
this.api = new TabbyApi({ BASE: this.config.server.endpoint });
this.ping();
}
@ -735,14 +770,14 @@ var TabbyAgent = class extends import_events.EventEmitter {
this.updateConfig(params.config);
}
if (params.client) {
allLoggers.forEach((logger) => logger.setBindings && logger.setBindings({ client: params.client }));
allLoggers.forEach((logger2) => logger2.setBindings && logger2.setBindings({ client: params.client }));
}
this.logger.debug({ params }, "Initialized");
return true;
}
updateConfig(config) {
const mergedConfig = (0, import_deepmerge.default)(this.config, config);
if (!(0, import_deep_equal.default)(this.config, mergedConfig)) {
if (!(0, import_deep_equal2.default)(this.config, mergedConfig)) {
this.config = mergedConfig;
this.onConfigUpdated();
const event = { event: "configUpdated", config: this.config };
@ -780,6 +815,8 @@ var TabbyAgent = class extends import_events.EventEmitter {
});
return cancelable(
promise.then((response) => {
return postprocess(request2, response);
}).then((response) => {
this.completionCache.set(request2, response);
return response;
}),

File diff suppressed because one or more lines are too long

View File

@ -5950,7 +5950,7 @@ var require_deep_equal = __commonJS({
}
return true;
}
module.exports = function deepEqual3(a7, b5, opts) {
module.exports = function deepEqual4(a7, b5, opts) {
return internalDeepEqual(a7, b5, opts, getSideChannel());
};
}
@ -9189,13 +9189,13 @@ var require_browser2 = __commonJS({
if (opts.enabled === false || opts.browser.disabled)
opts.level = "silent";
const level = opts.level || "info";
const logger = Object.create(proto);
if (!logger.log)
logger.log = noop3;
Object.defineProperty(logger, "levelVal", {
const logger2 = Object.create(proto);
if (!logger2.log)
logger2.log = noop3;
Object.defineProperty(logger2, "levelVal", {
get: getLevelVal
});
Object.defineProperty(logger, "level", {
Object.defineProperty(logger2, "level", {
get: getLevel,
set: setLevel
});
@ -9206,15 +9206,15 @@ var require_browser2 = __commonJS({
levels,
timestamp: getTimeFunction(opts)
};
logger.levels = getLevels(opts);
logger.level = level;
logger.setMaxListeners = logger.getMaxListeners = logger.emit = logger.addListener = logger.on = logger.prependListener = logger.once = logger.prependOnceListener = logger.removeListener = logger.removeAllListeners = logger.listeners = logger.listenerCount = logger.eventNames = logger.write = logger.flush = noop3;
logger.serializers = serializers;
logger._serialize = serialize;
logger._stdErrSerialize = stdErrSerialize;
logger.child = child;
logger2.levels = getLevels(opts);
logger2.level = level;
logger2.setMaxListeners = logger2.getMaxListeners = logger2.emit = logger2.addListener = logger2.on = logger2.prependListener = logger2.once = logger2.prependOnceListener = logger2.removeListener = logger2.removeAllListeners = logger2.listeners = logger2.listenerCount = logger2.eventNames = logger2.write = logger2.flush = noop3;
logger2.serializers = serializers;
logger2._serialize = serialize;
logger2._stdErrSerialize = stdErrSerialize;
logger2.child = child;
if (transmit2)
logger._logEvent = createLogEventShape();
logger2._logEvent = createLogEventShape();
function getLevelVal() {
return this.level === "silent" ? Infinity : this.levels.values[this.level];
}
@ -9226,14 +9226,14 @@ var require_browser2 = __commonJS({
throw Error("unknown level " + level2);
}
this._level = level2;
set(setOpts, logger, "error", "log");
set(setOpts, logger, "fatal", "error");
set(setOpts, logger, "warn", "error");
set(setOpts, logger, "info", "log");
set(setOpts, logger, "debug", "log");
set(setOpts, logger, "trace", "log");
set(setOpts, logger2, "error", "log");
set(setOpts, logger2, "fatal", "error");
set(setOpts, logger2, "warn", "error");
set(setOpts, logger2, "info", "log");
set(setOpts, logger2, "debug", "log");
set(setOpts, logger2, "trace", "log");
customLevels.forEach(function(level3) {
set(setOpts, logger, level3, "log");
set(setOpts, logger2, level3, "log");
});
}
function child(bindings, childOptions) {
@ -9272,7 +9272,7 @@ var require_browser2 = __commonJS({
Child.prototype = this;
return new Child(this);
}
return logger;
return logger2;
}
function getLevels(opts) {
const customLevels = opts.customLevels || {};
@ -9310,15 +9310,15 @@ var require_browser2 = __commonJS({
};
pino2.stdSerializers = stdSerializers;
pino2.stdTimeFunctions = Object.assign({}, { nullTime, epochTime, unixTime, isoTime });
function set(opts, logger, level, fallback) {
const proto = Object.getPrototypeOf(logger);
logger[level] = logger.levelVal > logger.levels.values[level] ? noop3 : proto[level] ? proto[level] : _console[level] || _console[fallback] || noop3;
wrap(opts, logger, level);
function set(opts, logger2, level, fallback) {
const proto = Object.getPrototypeOf(logger2);
logger2[level] = logger2.levelVal > logger2.levels.values[level] ? noop3 : proto[level] ? proto[level] : _console[level] || _console[fallback] || noop3;
wrap(opts, logger2, level);
}
function wrap(opts, logger, level) {
if (!opts.transmit && logger[level] === noop3)
function wrap(opts, logger2, level) {
if (!opts.transmit && logger2[level] === noop3)
return;
logger[level] = function(write2) {
logger2[level] = function(write2) {
return function LOG() {
const ts = opts.timestamp();
const args = new Array(arguments.length);
@ -9333,9 +9333,9 @@ var require_browser2 = __commonJS({
else
write2.apply(proto, args);
if (opts.transmit) {
const transmitLevel = opts.transmit.level || logger.level;
const transmitValue = logger.levels.values[transmitLevel];
const methodValue = logger.levels.values[level];
const transmitLevel = opts.transmit.level || logger2.level;
const transmitValue = logger2.levels.values[transmitLevel];
const methodValue = logger2.levels.values[level];
if (methodValue < transmitValue)
return;
transmit(this, {
@ -9343,25 +9343,25 @@ var require_browser2 = __commonJS({
methodLevel: level,
methodValue,
transmitLevel,
transmitValue: logger.levels.values[opts.transmit.level || logger.level],
transmitValue: logger2.levels.values[opts.transmit.level || logger2.level],
send: opts.transmit.send,
val: logger.levelVal
val: logger2.levelVal
}, args);
}
};
}(logger[level]);
}(logger2[level]);
}
function asObject(logger, level, args, ts) {
if (logger._serialize)
applySerializers(args, logger._serialize, logger.serializers, logger._stdErrSerialize);
function asObject(logger2, level, args, ts) {
if (logger2._serialize)
applySerializers(args, logger2._serialize, logger2.serializers, logger2._stdErrSerialize);
const argsCloned = args.slice();
let msg = argsCloned[0];
const o8 = {};
if (ts) {
o8.time = ts;
}
o8.level = logger.levels.values[level];
let lvl = (logger._childLevel | 0) + 1;
o8.level = logger2.levels.values[level];
let lvl = (logger2._childLevel | 0) + 1;
if (lvl < 1)
lvl = 1;
if (msg !== null && typeof msg === "object") {
@ -9398,27 +9398,27 @@ var require_browser2 = __commonJS({
return parent[level].apply(this, args);
};
}
function transmit(logger, opts, args) {
function transmit(logger2, opts, args) {
const send = opts.send;
const ts = opts.ts;
const methodLevel = opts.methodLevel;
const methodValue = opts.methodValue;
const val = opts.val;
const bindings = logger._logEvent.bindings;
const bindings = logger2._logEvent.bindings;
applySerializers(
args,
logger._serialize || Object.keys(logger.serializers),
logger.serializers,
logger._stdErrSerialize === void 0 ? true : logger._stdErrSerialize
logger2._serialize || Object.keys(logger2.serializers),
logger2.serializers,
logger2._stdErrSerialize === void 0 ? true : logger2._stdErrSerialize
);
logger._logEvent.ts = ts;
logger._logEvent.messages = args.filter(function(arg) {
logger2._logEvent.ts = ts;
logger2._logEvent.messages = args.filter(function(arg) {
return bindings.indexOf(arg) === -1;
});
logger._logEvent.level.label = methodLevel;
logger._logEvent.level.value = methodValue;
send(methodLevel, logger._logEvent, val);
logger._logEvent = createLogEventShape(bindings);
logger2._logEvent.level.label = methodLevel;
logger2._logEvent.level.value = methodValue;
send(methodLevel, logger2._logEvent, val);
logger2._logEvent = createLogEventShape(bindings);
}
function createLogEventShape(bindings) {
return {
@ -12224,7 +12224,7 @@ function v4(options, buf, offset) {
var v4_default = v4;
// src/TabbyAgent.ts
var import_deep_equal = __toESM(require_deep_equal());
var import_deep_equal2 = __toESM(require_deep_equal());
var import_deepmerge = __toESM(require_cjs());
// src/generated/index.ts
@ -32327,6 +32327,46 @@ var CompletionCache = class {
}
};
// src/postprocess.ts
init_global();
init_dirname();
init_filename();
init_buffer2();
init_process2();
var import_deep_equal = __toESM(require_deep_equal());
var logger = rootLogger.child({ component: "Postprocess" });
var removeDuplicateLines = (context) => {
return (input) => {
const suffix = context.text.slice(context.position);
const suffixLines = splitLines(suffix);
const inputLines = splitLines(input);
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
if ((0, import_deep_equal.default)(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
return input.slice(0, index);
}
}
return input;
};
};
var dropBlank = (input) => {
return isBlank(input) ? null : input;
};
var applyFilter = (filter2) => {
return async (response) => {
response.choices = (await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter2(choice.text);
return choice;
})
)).filter(Boolean);
return response;
};
};
async function postprocess(request2, response) {
return new Promise((resolve4) => resolve4(response)).then(applyFilter(removeDuplicateLines(request2))).then(applyFilter(dropBlank));
}
// src/TabbyAgent.ts
var TabbyAgent = class extends EventEmitter {
constructor() {
@ -32338,7 +32378,7 @@ var TabbyAgent = class extends EventEmitter {
this.onConfigUpdated();
}
onConfigUpdated() {
allLoggers.forEach((logger) => logger.level = this.config.logs.level);
allLoggers.forEach((logger2) => logger2.level = this.config.logs.level);
this.api = new TabbyApi({ BASE: this.config.server.endpoint });
this.ping();
}
@ -32403,14 +32443,14 @@ var TabbyAgent = class extends EventEmitter {
this.updateConfig(params.config);
}
if (params.client) {
allLoggers.forEach((logger) => logger.setBindings && logger.setBindings({ client: params.client }));
allLoggers.forEach((logger2) => logger2.setBindings && logger2.setBindings({ client: params.client }));
}
this.logger.debug({ params }, "Initialized");
return true;
}
updateConfig(config2) {
const mergedConfig = (0, import_deepmerge.default)(this.config, config2);
if (!(0, import_deep_equal.default)(this.config, mergedConfig)) {
if (!(0, import_deep_equal2.default)(this.config, mergedConfig)) {
this.config = mergedConfig;
this.onConfigUpdated();
const event = { event: "configUpdated", config: this.config };
@ -32448,6 +32488,8 @@ var TabbyAgent = class extends EventEmitter {
});
return cancelable(
promise.then((response) => {
return postprocess(request2, response);
}).then((response) => {
this.completionCache.set(request2, response);
return response;
}),

File diff suppressed because one or more lines are too long

View File

@ -8,6 +8,7 @@ import { sleep, cancelable, splitLines, isBlank } from "./utils";
import { Agent, AgentEvent, AgentInitOptions, CompletionRequest, CompletionResponse, LogEventRequest } from "./Agent";
import { AgentConfig, defaultAgentConfig } from "./AgentConfig";
import { CompletionCache } from "./CompletionCache";
import { postprocess } from "./postprocess";
import { rootLogger, allLoggers } from "./logger";
export class TabbyAgent extends EventEmitter implements Agent {
@ -150,10 +151,14 @@ export class TabbyAgent extends EventEmitter implements Agent {
segments,
});
return cancelable(
promise.then((response: CompletionResponse) => {
this.completionCache.set(request, response);
return response;
}),
promise
.then((response) => {
return postprocess(request, response);
})
.then((response) => {
this.completionCache.set(request, response);
return response;
}),
() => {
promise.cancel();
}

View File

@ -0,0 +1,51 @@
import { CompletionRequest, CompletionResponse } from "./Agent";
import deepEqual from "deep-equal";
import { isBlank, splitLines } from "./utils";
import { rootLogger } from "./logger";
type PostprocessContext = CompletionRequest;
type PostprocessFilter = (item: string) => string | null | Promise<string | null>;
const logger = rootLogger.child({ component: "Postprocess" });
const removeDuplicateLines: (context: PostprocessContext) => PostprocessFilter = (context) => {
return (input) => {
const suffix = context.text.slice(context.position);
const suffixLines = splitLines(suffix);
const inputLines = splitLines(input);
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
if (deepEqual(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
return input.slice(0, index);
}
}
return input;
};
};
const dropBlank: PostprocessFilter = (input) => {
return isBlank(input) ? null : input;
};
const applyFilter = (filter: PostprocessFilter) => {
return async (response: CompletionResponse) => {
response.choices = (
await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter(choice.text);
return choice;
})
)
).filter(Boolean);
return response;
};
};
export async function postprocess(
request: CompletionRequest,
response: CompletionResponse
): Promise<CompletionResponse> {
return new Promise((resolve) => resolve(response))
.then(applyFilter(removeDuplicateLines(request)))
.then(applyFilter(dropBlank));
}