feat: Add agent completion post-process. (#221)

improve-workflow
Zhiming Ma 2023-06-09 01:19:10 +08:00 committed by GitHub
parent 5870f8e868
commit 4ea3298bc9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 262 additions and 127 deletions

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -59,7 +59,7 @@ module.exports = __toCommonJS(src_exports);
var import_axios2 = __toESM(require("axios")); var import_axios2 = __toESM(require("axios"));
var import_events = require("events"); var import_events = require("events");
var import_uuid = require("uuid"); var import_uuid = require("uuid");
var import_deep_equal = __toESM(require("deep-equal")); var import_deep_equal2 = __toESM(require("deep-equal"));
var import_deepmerge = __toESM(require("deepmerge")); var import_deepmerge = __toESM(require("deepmerge"));
// src/generated/core/BaseHttpRequest.ts // src/generated/core/BaseHttpRequest.ts
@ -659,6 +659,41 @@ var CompletionCache = class {
} }
}; };
// src/postprocess.ts
var import_deep_equal = __toESM(require("deep-equal"));
var logger = rootLogger.child({ component: "Postprocess" });
var removeDuplicateLines = (context) => {
return (input) => {
const suffix = context.text.slice(context.position);
const suffixLines = splitLines(suffix);
const inputLines = splitLines(input);
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
if ((0, import_deep_equal.default)(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
return input.slice(0, index);
}
}
return input;
};
};
var dropBlank = (input) => {
return isBlank(input) ? null : input;
};
var applyFilter = (filter) => {
return async (response) => {
response.choices = (await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter(choice.text);
return choice;
})
)).filter(Boolean);
return response;
};
};
async function postprocess(request2, response) {
return new Promise((resolve2) => resolve2(response)).then(applyFilter(removeDuplicateLines(request2))).then(applyFilter(dropBlank));
}
// src/TabbyAgent.ts // src/TabbyAgent.ts
var TabbyAgent = class extends import_events.EventEmitter { var TabbyAgent = class extends import_events.EventEmitter {
constructor() { constructor() {
@ -670,7 +705,7 @@ var TabbyAgent = class extends import_events.EventEmitter {
this.onConfigUpdated(); this.onConfigUpdated();
} }
onConfigUpdated() { onConfigUpdated() {
allLoggers.forEach((logger) => logger.level = this.config.logs.level); allLoggers.forEach((logger2) => logger2.level = this.config.logs.level);
this.api = new TabbyApi({ BASE: this.config.server.endpoint }); this.api = new TabbyApi({ BASE: this.config.server.endpoint });
this.ping(); this.ping();
} }
@ -735,14 +770,14 @@ var TabbyAgent = class extends import_events.EventEmitter {
this.updateConfig(params.config); this.updateConfig(params.config);
} }
if (params.client) { if (params.client) {
allLoggers.forEach((logger) => logger.setBindings && logger.setBindings({ client: params.client })); allLoggers.forEach((logger2) => logger2.setBindings && logger2.setBindings({ client: params.client }));
} }
this.logger.debug({ params }, "Initialized"); this.logger.debug({ params }, "Initialized");
return true; return true;
} }
updateConfig(config) { updateConfig(config) {
const mergedConfig = (0, import_deepmerge.default)(this.config, config); const mergedConfig = (0, import_deepmerge.default)(this.config, config);
if (!(0, import_deep_equal.default)(this.config, mergedConfig)) { if (!(0, import_deep_equal2.default)(this.config, mergedConfig)) {
this.config = mergedConfig; this.config = mergedConfig;
this.onConfigUpdated(); this.onConfigUpdated();
const event = { event: "configUpdated", config: this.config }; const event = { event: "configUpdated", config: this.config };
@ -780,6 +815,8 @@ var TabbyAgent = class extends import_events.EventEmitter {
}); });
return cancelable( return cancelable(
promise.then((response) => { promise.then((response) => {
return postprocess(request2, response);
}).then((response) => {
this.completionCache.set(request2, response); this.completionCache.set(request2, response);
return response; return response;
}), }),

File diff suppressed because one or more lines are too long

View File

@ -5950,7 +5950,7 @@ var require_deep_equal = __commonJS({
} }
return true; return true;
} }
module.exports = function deepEqual3(a7, b5, opts) { module.exports = function deepEqual4(a7, b5, opts) {
return internalDeepEqual(a7, b5, opts, getSideChannel()); return internalDeepEqual(a7, b5, opts, getSideChannel());
}; };
} }
@ -9189,13 +9189,13 @@ var require_browser2 = __commonJS({
if (opts.enabled === false || opts.browser.disabled) if (opts.enabled === false || opts.browser.disabled)
opts.level = "silent"; opts.level = "silent";
const level = opts.level || "info"; const level = opts.level || "info";
const logger = Object.create(proto); const logger2 = Object.create(proto);
if (!logger.log) if (!logger2.log)
logger.log = noop3; logger2.log = noop3;
Object.defineProperty(logger, "levelVal", { Object.defineProperty(logger2, "levelVal", {
get: getLevelVal get: getLevelVal
}); });
Object.defineProperty(logger, "level", { Object.defineProperty(logger2, "level", {
get: getLevel, get: getLevel,
set: setLevel set: setLevel
}); });
@ -9206,15 +9206,15 @@ var require_browser2 = __commonJS({
levels, levels,
timestamp: getTimeFunction(opts) timestamp: getTimeFunction(opts)
}; };
logger.levels = getLevels(opts); logger2.levels = getLevels(opts);
logger.level = level; logger2.level = level;
logger.setMaxListeners = logger.getMaxListeners = logger.emit = logger.addListener = logger.on = logger.prependListener = logger.once = logger.prependOnceListener = logger.removeListener = logger.removeAllListeners = logger.listeners = logger.listenerCount = logger.eventNames = logger.write = logger.flush = noop3; logger2.setMaxListeners = logger2.getMaxListeners = logger2.emit = logger2.addListener = logger2.on = logger2.prependListener = logger2.once = logger2.prependOnceListener = logger2.removeListener = logger2.removeAllListeners = logger2.listeners = logger2.listenerCount = logger2.eventNames = logger2.write = logger2.flush = noop3;
logger.serializers = serializers; logger2.serializers = serializers;
logger._serialize = serialize; logger2._serialize = serialize;
logger._stdErrSerialize = stdErrSerialize; logger2._stdErrSerialize = stdErrSerialize;
logger.child = child; logger2.child = child;
if (transmit2) if (transmit2)
logger._logEvent = createLogEventShape(); logger2._logEvent = createLogEventShape();
function getLevelVal() { function getLevelVal() {
return this.level === "silent" ? Infinity : this.levels.values[this.level]; return this.level === "silent" ? Infinity : this.levels.values[this.level];
} }
@ -9226,14 +9226,14 @@ var require_browser2 = __commonJS({
throw Error("unknown level " + level2); throw Error("unknown level " + level2);
} }
this._level = level2; this._level = level2;
set(setOpts, logger, "error", "log"); set(setOpts, logger2, "error", "log");
set(setOpts, logger, "fatal", "error"); set(setOpts, logger2, "fatal", "error");
set(setOpts, logger, "warn", "error"); set(setOpts, logger2, "warn", "error");
set(setOpts, logger, "info", "log"); set(setOpts, logger2, "info", "log");
set(setOpts, logger, "debug", "log"); set(setOpts, logger2, "debug", "log");
set(setOpts, logger, "trace", "log"); set(setOpts, logger2, "trace", "log");
customLevels.forEach(function(level3) { customLevels.forEach(function(level3) {
set(setOpts, logger, level3, "log"); set(setOpts, logger2, level3, "log");
}); });
} }
function child(bindings, childOptions) { function child(bindings, childOptions) {
@ -9272,7 +9272,7 @@ var require_browser2 = __commonJS({
Child.prototype = this; Child.prototype = this;
return new Child(this); return new Child(this);
} }
return logger; return logger2;
} }
function getLevels(opts) { function getLevels(opts) {
const customLevels = opts.customLevels || {}; const customLevels = opts.customLevels || {};
@ -9310,15 +9310,15 @@ var require_browser2 = __commonJS({
}; };
pino2.stdSerializers = stdSerializers; pino2.stdSerializers = stdSerializers;
pino2.stdTimeFunctions = Object.assign({}, { nullTime, epochTime, unixTime, isoTime }); pino2.stdTimeFunctions = Object.assign({}, { nullTime, epochTime, unixTime, isoTime });
function set(opts, logger, level, fallback) { function set(opts, logger2, level, fallback) {
const proto = Object.getPrototypeOf(logger); const proto = Object.getPrototypeOf(logger2);
logger[level] = logger.levelVal > logger.levels.values[level] ? noop3 : proto[level] ? proto[level] : _console[level] || _console[fallback] || noop3; logger2[level] = logger2.levelVal > logger2.levels.values[level] ? noop3 : proto[level] ? proto[level] : _console[level] || _console[fallback] || noop3;
wrap(opts, logger, level); wrap(opts, logger2, level);
} }
function wrap(opts, logger, level) { function wrap(opts, logger2, level) {
if (!opts.transmit && logger[level] === noop3) if (!opts.transmit && logger2[level] === noop3)
return; return;
logger[level] = function(write2) { logger2[level] = function(write2) {
return function LOG() { return function LOG() {
const ts = opts.timestamp(); const ts = opts.timestamp();
const args = new Array(arguments.length); const args = new Array(arguments.length);
@ -9333,9 +9333,9 @@ var require_browser2 = __commonJS({
else else
write2.apply(proto, args); write2.apply(proto, args);
if (opts.transmit) { if (opts.transmit) {
const transmitLevel = opts.transmit.level || logger.level; const transmitLevel = opts.transmit.level || logger2.level;
const transmitValue = logger.levels.values[transmitLevel]; const transmitValue = logger2.levels.values[transmitLevel];
const methodValue = logger.levels.values[level]; const methodValue = logger2.levels.values[level];
if (methodValue < transmitValue) if (methodValue < transmitValue)
return; return;
transmit(this, { transmit(this, {
@ -9343,25 +9343,25 @@ var require_browser2 = __commonJS({
methodLevel: level, methodLevel: level,
methodValue, methodValue,
transmitLevel, transmitLevel,
transmitValue: logger.levels.values[opts.transmit.level || logger.level], transmitValue: logger2.levels.values[opts.transmit.level || logger2.level],
send: opts.transmit.send, send: opts.transmit.send,
val: logger.levelVal val: logger2.levelVal
}, args); }, args);
} }
}; };
}(logger[level]); }(logger2[level]);
} }
function asObject(logger, level, args, ts) { function asObject(logger2, level, args, ts) {
if (logger._serialize) if (logger2._serialize)
applySerializers(args, logger._serialize, logger.serializers, logger._stdErrSerialize); applySerializers(args, logger2._serialize, logger2.serializers, logger2._stdErrSerialize);
const argsCloned = args.slice(); const argsCloned = args.slice();
let msg = argsCloned[0]; let msg = argsCloned[0];
const o8 = {}; const o8 = {};
if (ts) { if (ts) {
o8.time = ts; o8.time = ts;
} }
o8.level = logger.levels.values[level]; o8.level = logger2.levels.values[level];
let lvl = (logger._childLevel | 0) + 1; let lvl = (logger2._childLevel | 0) + 1;
if (lvl < 1) if (lvl < 1)
lvl = 1; lvl = 1;
if (msg !== null && typeof msg === "object") { if (msg !== null && typeof msg === "object") {
@ -9398,27 +9398,27 @@ var require_browser2 = __commonJS({
return parent[level].apply(this, args); return parent[level].apply(this, args);
}; };
} }
function transmit(logger, opts, args) { function transmit(logger2, opts, args) {
const send = opts.send; const send = opts.send;
const ts = opts.ts; const ts = opts.ts;
const methodLevel = opts.methodLevel; const methodLevel = opts.methodLevel;
const methodValue = opts.methodValue; const methodValue = opts.methodValue;
const val = opts.val; const val = opts.val;
const bindings = logger._logEvent.bindings; const bindings = logger2._logEvent.bindings;
applySerializers( applySerializers(
args, args,
logger._serialize || Object.keys(logger.serializers), logger2._serialize || Object.keys(logger2.serializers),
logger.serializers, logger2.serializers,
logger._stdErrSerialize === void 0 ? true : logger._stdErrSerialize logger2._stdErrSerialize === void 0 ? true : logger2._stdErrSerialize
); );
logger._logEvent.ts = ts; logger2._logEvent.ts = ts;
logger._logEvent.messages = args.filter(function(arg) { logger2._logEvent.messages = args.filter(function(arg) {
return bindings.indexOf(arg) === -1; return bindings.indexOf(arg) === -1;
}); });
logger._logEvent.level.label = methodLevel; logger2._logEvent.level.label = methodLevel;
logger._logEvent.level.value = methodValue; logger2._logEvent.level.value = methodValue;
send(methodLevel, logger._logEvent, val); send(methodLevel, logger2._logEvent, val);
logger._logEvent = createLogEventShape(bindings); logger2._logEvent = createLogEventShape(bindings);
} }
function createLogEventShape(bindings) { function createLogEventShape(bindings) {
return { return {
@ -12224,7 +12224,7 @@ function v4(options, buf, offset) {
var v4_default = v4; var v4_default = v4;
// src/TabbyAgent.ts // src/TabbyAgent.ts
var import_deep_equal = __toESM(require_deep_equal()); var import_deep_equal2 = __toESM(require_deep_equal());
var import_deepmerge = __toESM(require_cjs()); var import_deepmerge = __toESM(require_cjs());
// src/generated/index.ts // src/generated/index.ts
@ -32327,6 +32327,46 @@ var CompletionCache = class {
} }
}; };
// src/postprocess.ts
init_global();
init_dirname();
init_filename();
init_buffer2();
init_process2();
var import_deep_equal = __toESM(require_deep_equal());
var logger = rootLogger.child({ component: "Postprocess" });
var removeDuplicateLines = (context) => {
return (input) => {
const suffix = context.text.slice(context.position);
const suffixLines = splitLines(suffix);
const inputLines = splitLines(input);
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
if ((0, import_deep_equal.default)(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
return input.slice(0, index);
}
}
return input;
};
};
var dropBlank = (input) => {
return isBlank(input) ? null : input;
};
var applyFilter = (filter2) => {
return async (response) => {
response.choices = (await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter2(choice.text);
return choice;
})
)).filter(Boolean);
return response;
};
};
async function postprocess(request2, response) {
return new Promise((resolve4) => resolve4(response)).then(applyFilter(removeDuplicateLines(request2))).then(applyFilter(dropBlank));
}
// src/TabbyAgent.ts // src/TabbyAgent.ts
var TabbyAgent = class extends EventEmitter { var TabbyAgent = class extends EventEmitter {
constructor() { constructor() {
@ -32338,7 +32378,7 @@ var TabbyAgent = class extends EventEmitter {
this.onConfigUpdated(); this.onConfigUpdated();
} }
onConfigUpdated() { onConfigUpdated() {
allLoggers.forEach((logger) => logger.level = this.config.logs.level); allLoggers.forEach((logger2) => logger2.level = this.config.logs.level);
this.api = new TabbyApi({ BASE: this.config.server.endpoint }); this.api = new TabbyApi({ BASE: this.config.server.endpoint });
this.ping(); this.ping();
} }
@ -32403,14 +32443,14 @@ var TabbyAgent = class extends EventEmitter {
this.updateConfig(params.config); this.updateConfig(params.config);
} }
if (params.client) { if (params.client) {
allLoggers.forEach((logger) => logger.setBindings && logger.setBindings({ client: params.client })); allLoggers.forEach((logger2) => logger2.setBindings && logger2.setBindings({ client: params.client }));
} }
this.logger.debug({ params }, "Initialized"); this.logger.debug({ params }, "Initialized");
return true; return true;
} }
updateConfig(config2) { updateConfig(config2) {
const mergedConfig = (0, import_deepmerge.default)(this.config, config2); const mergedConfig = (0, import_deepmerge.default)(this.config, config2);
if (!(0, import_deep_equal.default)(this.config, mergedConfig)) { if (!(0, import_deep_equal2.default)(this.config, mergedConfig)) {
this.config = mergedConfig; this.config = mergedConfig;
this.onConfigUpdated(); this.onConfigUpdated();
const event = { event: "configUpdated", config: this.config }; const event = { event: "configUpdated", config: this.config };
@ -32448,6 +32488,8 @@ var TabbyAgent = class extends EventEmitter {
}); });
return cancelable( return cancelable(
promise.then((response) => { promise.then((response) => {
return postprocess(request2, response);
}).then((response) => {
this.completionCache.set(request2, response); this.completionCache.set(request2, response);
return response; return response;
}), }),

File diff suppressed because one or more lines are too long

View File

@ -8,6 +8,7 @@ import { sleep, cancelable, splitLines, isBlank } from "./utils";
import { Agent, AgentEvent, AgentInitOptions, CompletionRequest, CompletionResponse, LogEventRequest } from "./Agent"; import { Agent, AgentEvent, AgentInitOptions, CompletionRequest, CompletionResponse, LogEventRequest } from "./Agent";
import { AgentConfig, defaultAgentConfig } from "./AgentConfig"; import { AgentConfig, defaultAgentConfig } from "./AgentConfig";
import { CompletionCache } from "./CompletionCache"; import { CompletionCache } from "./CompletionCache";
import { postprocess } from "./postprocess";
import { rootLogger, allLoggers } from "./logger"; import { rootLogger, allLoggers } from "./logger";
export class TabbyAgent extends EventEmitter implements Agent { export class TabbyAgent extends EventEmitter implements Agent {
@ -150,7 +151,11 @@ export class TabbyAgent extends EventEmitter implements Agent {
segments, segments,
}); });
return cancelable( return cancelable(
promise.then((response: CompletionResponse) => { promise
.then((response) => {
return postprocess(request, response);
})
.then((response) => {
this.completionCache.set(request, response); this.completionCache.set(request, response);
return response; return response;
}), }),

View File

@ -0,0 +1,51 @@
import { CompletionRequest, CompletionResponse } from "./Agent";
import deepEqual from "deep-equal";
import { isBlank, splitLines } from "./utils";
import { rootLogger } from "./logger";
type PostprocessContext = CompletionRequest;
type PostprocessFilter = (item: string) => string | null | Promise<string | null>;
const logger = rootLogger.child({ component: "Postprocess" });
const removeDuplicateLines: (context: PostprocessContext) => PostprocessFilter = (context) => {
return (input) => {
const suffix = context.text.slice(context.position);
const suffixLines = splitLines(suffix);
const inputLines = splitLines(input);
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
if (deepEqual(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
return input.slice(0, index);
}
}
return input;
};
};
const dropBlank: PostprocessFilter = (input) => {
return isBlank(input) ? null : input;
};
const applyFilter = (filter: PostprocessFilter) => {
return async (response: CompletionResponse) => {
response.choices = (
await Promise.all(
response.choices.map(async (choice) => {
choice.text = await filter(choice.text);
return choice;
})
)
).filter(Boolean);
return response;
};
};
export async function postprocess(
request: CompletionRequest,
response: CompletionResponse
): Promise<CompletionResponse> {
return new Promise((resolve) => resolve(response))
.then(applyFilter(removeDuplicateLines(request)))
.then(applyFilter(dropBlank));
}