feat: Add agent completion post-process. (#221)
parent
5870f8e868
commit
4ea3298bc9
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
|
|
@ -59,7 +59,7 @@ module.exports = __toCommonJS(src_exports);
|
||||||
var import_axios2 = __toESM(require("axios"));
|
var import_axios2 = __toESM(require("axios"));
|
||||||
var import_events = require("events");
|
var import_events = require("events");
|
||||||
var import_uuid = require("uuid");
|
var import_uuid = require("uuid");
|
||||||
var import_deep_equal = __toESM(require("deep-equal"));
|
var import_deep_equal2 = __toESM(require("deep-equal"));
|
||||||
var import_deepmerge = __toESM(require("deepmerge"));
|
var import_deepmerge = __toESM(require("deepmerge"));
|
||||||
|
|
||||||
// src/generated/core/BaseHttpRequest.ts
|
// src/generated/core/BaseHttpRequest.ts
|
||||||
|
|
@ -659,6 +659,41 @@ var CompletionCache = class {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// src/postprocess.ts
|
||||||
|
var import_deep_equal = __toESM(require("deep-equal"));
|
||||||
|
var logger = rootLogger.child({ component: "Postprocess" });
|
||||||
|
var removeDuplicateLines = (context) => {
|
||||||
|
return (input) => {
|
||||||
|
const suffix = context.text.slice(context.position);
|
||||||
|
const suffixLines = splitLines(suffix);
|
||||||
|
const inputLines = splitLines(input);
|
||||||
|
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
|
||||||
|
if ((0, import_deep_equal.default)(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
|
||||||
|
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
|
||||||
|
return input.slice(0, index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
var dropBlank = (input) => {
|
||||||
|
return isBlank(input) ? null : input;
|
||||||
|
};
|
||||||
|
var applyFilter = (filter) => {
|
||||||
|
return async (response) => {
|
||||||
|
response.choices = (await Promise.all(
|
||||||
|
response.choices.map(async (choice) => {
|
||||||
|
choice.text = await filter(choice.text);
|
||||||
|
return choice;
|
||||||
|
})
|
||||||
|
)).filter(Boolean);
|
||||||
|
return response;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
async function postprocess(request2, response) {
|
||||||
|
return new Promise((resolve2) => resolve2(response)).then(applyFilter(removeDuplicateLines(request2))).then(applyFilter(dropBlank));
|
||||||
|
}
|
||||||
|
|
||||||
// src/TabbyAgent.ts
|
// src/TabbyAgent.ts
|
||||||
var TabbyAgent = class extends import_events.EventEmitter {
|
var TabbyAgent = class extends import_events.EventEmitter {
|
||||||
constructor() {
|
constructor() {
|
||||||
|
|
@ -670,7 +705,7 @@ var TabbyAgent = class extends import_events.EventEmitter {
|
||||||
this.onConfigUpdated();
|
this.onConfigUpdated();
|
||||||
}
|
}
|
||||||
onConfigUpdated() {
|
onConfigUpdated() {
|
||||||
allLoggers.forEach((logger) => logger.level = this.config.logs.level);
|
allLoggers.forEach((logger2) => logger2.level = this.config.logs.level);
|
||||||
this.api = new TabbyApi({ BASE: this.config.server.endpoint });
|
this.api = new TabbyApi({ BASE: this.config.server.endpoint });
|
||||||
this.ping();
|
this.ping();
|
||||||
}
|
}
|
||||||
|
|
@ -735,14 +770,14 @@ var TabbyAgent = class extends import_events.EventEmitter {
|
||||||
this.updateConfig(params.config);
|
this.updateConfig(params.config);
|
||||||
}
|
}
|
||||||
if (params.client) {
|
if (params.client) {
|
||||||
allLoggers.forEach((logger) => logger.setBindings && logger.setBindings({ client: params.client }));
|
allLoggers.forEach((logger2) => logger2.setBindings && logger2.setBindings({ client: params.client }));
|
||||||
}
|
}
|
||||||
this.logger.debug({ params }, "Initialized");
|
this.logger.debug({ params }, "Initialized");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
updateConfig(config) {
|
updateConfig(config) {
|
||||||
const mergedConfig = (0, import_deepmerge.default)(this.config, config);
|
const mergedConfig = (0, import_deepmerge.default)(this.config, config);
|
||||||
if (!(0, import_deep_equal.default)(this.config, mergedConfig)) {
|
if (!(0, import_deep_equal2.default)(this.config, mergedConfig)) {
|
||||||
this.config = mergedConfig;
|
this.config = mergedConfig;
|
||||||
this.onConfigUpdated();
|
this.onConfigUpdated();
|
||||||
const event = { event: "configUpdated", config: this.config };
|
const event = { event: "configUpdated", config: this.config };
|
||||||
|
|
@ -780,6 +815,8 @@ var TabbyAgent = class extends import_events.EventEmitter {
|
||||||
});
|
});
|
||||||
return cancelable(
|
return cancelable(
|
||||||
promise.then((response) => {
|
promise.then((response) => {
|
||||||
|
return postprocess(request2, response);
|
||||||
|
}).then((response) => {
|
||||||
this.completionCache.set(request2, response);
|
this.completionCache.set(request2, response);
|
||||||
return response;
|
return response;
|
||||||
}),
|
}),
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -5950,7 +5950,7 @@ var require_deep_equal = __commonJS({
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
module.exports = function deepEqual3(a7, b5, opts) {
|
module.exports = function deepEqual4(a7, b5, opts) {
|
||||||
return internalDeepEqual(a7, b5, opts, getSideChannel());
|
return internalDeepEqual(a7, b5, opts, getSideChannel());
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
@ -9189,13 +9189,13 @@ var require_browser2 = __commonJS({
|
||||||
if (opts.enabled === false || opts.browser.disabled)
|
if (opts.enabled === false || opts.browser.disabled)
|
||||||
opts.level = "silent";
|
opts.level = "silent";
|
||||||
const level = opts.level || "info";
|
const level = opts.level || "info";
|
||||||
const logger = Object.create(proto);
|
const logger2 = Object.create(proto);
|
||||||
if (!logger.log)
|
if (!logger2.log)
|
||||||
logger.log = noop3;
|
logger2.log = noop3;
|
||||||
Object.defineProperty(logger, "levelVal", {
|
Object.defineProperty(logger2, "levelVal", {
|
||||||
get: getLevelVal
|
get: getLevelVal
|
||||||
});
|
});
|
||||||
Object.defineProperty(logger, "level", {
|
Object.defineProperty(logger2, "level", {
|
||||||
get: getLevel,
|
get: getLevel,
|
||||||
set: setLevel
|
set: setLevel
|
||||||
});
|
});
|
||||||
|
|
@ -9206,15 +9206,15 @@ var require_browser2 = __commonJS({
|
||||||
levels,
|
levels,
|
||||||
timestamp: getTimeFunction(opts)
|
timestamp: getTimeFunction(opts)
|
||||||
};
|
};
|
||||||
logger.levels = getLevels(opts);
|
logger2.levels = getLevels(opts);
|
||||||
logger.level = level;
|
logger2.level = level;
|
||||||
logger.setMaxListeners = logger.getMaxListeners = logger.emit = logger.addListener = logger.on = logger.prependListener = logger.once = logger.prependOnceListener = logger.removeListener = logger.removeAllListeners = logger.listeners = logger.listenerCount = logger.eventNames = logger.write = logger.flush = noop3;
|
logger2.setMaxListeners = logger2.getMaxListeners = logger2.emit = logger2.addListener = logger2.on = logger2.prependListener = logger2.once = logger2.prependOnceListener = logger2.removeListener = logger2.removeAllListeners = logger2.listeners = logger2.listenerCount = logger2.eventNames = logger2.write = logger2.flush = noop3;
|
||||||
logger.serializers = serializers;
|
logger2.serializers = serializers;
|
||||||
logger._serialize = serialize;
|
logger2._serialize = serialize;
|
||||||
logger._stdErrSerialize = stdErrSerialize;
|
logger2._stdErrSerialize = stdErrSerialize;
|
||||||
logger.child = child;
|
logger2.child = child;
|
||||||
if (transmit2)
|
if (transmit2)
|
||||||
logger._logEvent = createLogEventShape();
|
logger2._logEvent = createLogEventShape();
|
||||||
function getLevelVal() {
|
function getLevelVal() {
|
||||||
return this.level === "silent" ? Infinity : this.levels.values[this.level];
|
return this.level === "silent" ? Infinity : this.levels.values[this.level];
|
||||||
}
|
}
|
||||||
|
|
@ -9226,14 +9226,14 @@ var require_browser2 = __commonJS({
|
||||||
throw Error("unknown level " + level2);
|
throw Error("unknown level " + level2);
|
||||||
}
|
}
|
||||||
this._level = level2;
|
this._level = level2;
|
||||||
set(setOpts, logger, "error", "log");
|
set(setOpts, logger2, "error", "log");
|
||||||
set(setOpts, logger, "fatal", "error");
|
set(setOpts, logger2, "fatal", "error");
|
||||||
set(setOpts, logger, "warn", "error");
|
set(setOpts, logger2, "warn", "error");
|
||||||
set(setOpts, logger, "info", "log");
|
set(setOpts, logger2, "info", "log");
|
||||||
set(setOpts, logger, "debug", "log");
|
set(setOpts, logger2, "debug", "log");
|
||||||
set(setOpts, logger, "trace", "log");
|
set(setOpts, logger2, "trace", "log");
|
||||||
customLevels.forEach(function(level3) {
|
customLevels.forEach(function(level3) {
|
||||||
set(setOpts, logger, level3, "log");
|
set(setOpts, logger2, level3, "log");
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
function child(bindings, childOptions) {
|
function child(bindings, childOptions) {
|
||||||
|
|
@ -9272,7 +9272,7 @@ var require_browser2 = __commonJS({
|
||||||
Child.prototype = this;
|
Child.prototype = this;
|
||||||
return new Child(this);
|
return new Child(this);
|
||||||
}
|
}
|
||||||
return logger;
|
return logger2;
|
||||||
}
|
}
|
||||||
function getLevels(opts) {
|
function getLevels(opts) {
|
||||||
const customLevels = opts.customLevels || {};
|
const customLevels = opts.customLevels || {};
|
||||||
|
|
@ -9310,15 +9310,15 @@ var require_browser2 = __commonJS({
|
||||||
};
|
};
|
||||||
pino2.stdSerializers = stdSerializers;
|
pino2.stdSerializers = stdSerializers;
|
||||||
pino2.stdTimeFunctions = Object.assign({}, { nullTime, epochTime, unixTime, isoTime });
|
pino2.stdTimeFunctions = Object.assign({}, { nullTime, epochTime, unixTime, isoTime });
|
||||||
function set(opts, logger, level, fallback) {
|
function set(opts, logger2, level, fallback) {
|
||||||
const proto = Object.getPrototypeOf(logger);
|
const proto = Object.getPrototypeOf(logger2);
|
||||||
logger[level] = logger.levelVal > logger.levels.values[level] ? noop3 : proto[level] ? proto[level] : _console[level] || _console[fallback] || noop3;
|
logger2[level] = logger2.levelVal > logger2.levels.values[level] ? noop3 : proto[level] ? proto[level] : _console[level] || _console[fallback] || noop3;
|
||||||
wrap(opts, logger, level);
|
wrap(opts, logger2, level);
|
||||||
}
|
}
|
||||||
function wrap(opts, logger, level) {
|
function wrap(opts, logger2, level) {
|
||||||
if (!opts.transmit && logger[level] === noop3)
|
if (!opts.transmit && logger2[level] === noop3)
|
||||||
return;
|
return;
|
||||||
logger[level] = function(write2) {
|
logger2[level] = function(write2) {
|
||||||
return function LOG() {
|
return function LOG() {
|
||||||
const ts = opts.timestamp();
|
const ts = opts.timestamp();
|
||||||
const args = new Array(arguments.length);
|
const args = new Array(arguments.length);
|
||||||
|
|
@ -9333,9 +9333,9 @@ var require_browser2 = __commonJS({
|
||||||
else
|
else
|
||||||
write2.apply(proto, args);
|
write2.apply(proto, args);
|
||||||
if (opts.transmit) {
|
if (opts.transmit) {
|
||||||
const transmitLevel = opts.transmit.level || logger.level;
|
const transmitLevel = opts.transmit.level || logger2.level;
|
||||||
const transmitValue = logger.levels.values[transmitLevel];
|
const transmitValue = logger2.levels.values[transmitLevel];
|
||||||
const methodValue = logger.levels.values[level];
|
const methodValue = logger2.levels.values[level];
|
||||||
if (methodValue < transmitValue)
|
if (methodValue < transmitValue)
|
||||||
return;
|
return;
|
||||||
transmit(this, {
|
transmit(this, {
|
||||||
|
|
@ -9343,25 +9343,25 @@ var require_browser2 = __commonJS({
|
||||||
methodLevel: level,
|
methodLevel: level,
|
||||||
methodValue,
|
methodValue,
|
||||||
transmitLevel,
|
transmitLevel,
|
||||||
transmitValue: logger.levels.values[opts.transmit.level || logger.level],
|
transmitValue: logger2.levels.values[opts.transmit.level || logger2.level],
|
||||||
send: opts.transmit.send,
|
send: opts.transmit.send,
|
||||||
val: logger.levelVal
|
val: logger2.levelVal
|
||||||
}, args);
|
}, args);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}(logger[level]);
|
}(logger2[level]);
|
||||||
}
|
}
|
||||||
function asObject(logger, level, args, ts) {
|
function asObject(logger2, level, args, ts) {
|
||||||
if (logger._serialize)
|
if (logger2._serialize)
|
||||||
applySerializers(args, logger._serialize, logger.serializers, logger._stdErrSerialize);
|
applySerializers(args, logger2._serialize, logger2.serializers, logger2._stdErrSerialize);
|
||||||
const argsCloned = args.slice();
|
const argsCloned = args.slice();
|
||||||
let msg = argsCloned[0];
|
let msg = argsCloned[0];
|
||||||
const o8 = {};
|
const o8 = {};
|
||||||
if (ts) {
|
if (ts) {
|
||||||
o8.time = ts;
|
o8.time = ts;
|
||||||
}
|
}
|
||||||
o8.level = logger.levels.values[level];
|
o8.level = logger2.levels.values[level];
|
||||||
let lvl = (logger._childLevel | 0) + 1;
|
let lvl = (logger2._childLevel | 0) + 1;
|
||||||
if (lvl < 1)
|
if (lvl < 1)
|
||||||
lvl = 1;
|
lvl = 1;
|
||||||
if (msg !== null && typeof msg === "object") {
|
if (msg !== null && typeof msg === "object") {
|
||||||
|
|
@ -9398,27 +9398,27 @@ var require_browser2 = __commonJS({
|
||||||
return parent[level].apply(this, args);
|
return parent[level].apply(this, args);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
function transmit(logger, opts, args) {
|
function transmit(logger2, opts, args) {
|
||||||
const send = opts.send;
|
const send = opts.send;
|
||||||
const ts = opts.ts;
|
const ts = opts.ts;
|
||||||
const methodLevel = opts.methodLevel;
|
const methodLevel = opts.methodLevel;
|
||||||
const methodValue = opts.methodValue;
|
const methodValue = opts.methodValue;
|
||||||
const val = opts.val;
|
const val = opts.val;
|
||||||
const bindings = logger._logEvent.bindings;
|
const bindings = logger2._logEvent.bindings;
|
||||||
applySerializers(
|
applySerializers(
|
||||||
args,
|
args,
|
||||||
logger._serialize || Object.keys(logger.serializers),
|
logger2._serialize || Object.keys(logger2.serializers),
|
||||||
logger.serializers,
|
logger2.serializers,
|
||||||
logger._stdErrSerialize === void 0 ? true : logger._stdErrSerialize
|
logger2._stdErrSerialize === void 0 ? true : logger2._stdErrSerialize
|
||||||
);
|
);
|
||||||
logger._logEvent.ts = ts;
|
logger2._logEvent.ts = ts;
|
||||||
logger._logEvent.messages = args.filter(function(arg) {
|
logger2._logEvent.messages = args.filter(function(arg) {
|
||||||
return bindings.indexOf(arg) === -1;
|
return bindings.indexOf(arg) === -1;
|
||||||
});
|
});
|
||||||
logger._logEvent.level.label = methodLevel;
|
logger2._logEvent.level.label = methodLevel;
|
||||||
logger._logEvent.level.value = methodValue;
|
logger2._logEvent.level.value = methodValue;
|
||||||
send(methodLevel, logger._logEvent, val);
|
send(methodLevel, logger2._logEvent, val);
|
||||||
logger._logEvent = createLogEventShape(bindings);
|
logger2._logEvent = createLogEventShape(bindings);
|
||||||
}
|
}
|
||||||
function createLogEventShape(bindings) {
|
function createLogEventShape(bindings) {
|
||||||
return {
|
return {
|
||||||
|
|
@ -12224,7 +12224,7 @@ function v4(options, buf, offset) {
|
||||||
var v4_default = v4;
|
var v4_default = v4;
|
||||||
|
|
||||||
// src/TabbyAgent.ts
|
// src/TabbyAgent.ts
|
||||||
var import_deep_equal = __toESM(require_deep_equal());
|
var import_deep_equal2 = __toESM(require_deep_equal());
|
||||||
var import_deepmerge = __toESM(require_cjs());
|
var import_deepmerge = __toESM(require_cjs());
|
||||||
|
|
||||||
// src/generated/index.ts
|
// src/generated/index.ts
|
||||||
|
|
@ -32327,6 +32327,46 @@ var CompletionCache = class {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// src/postprocess.ts
|
||||||
|
init_global();
|
||||||
|
init_dirname();
|
||||||
|
init_filename();
|
||||||
|
init_buffer2();
|
||||||
|
init_process2();
|
||||||
|
var import_deep_equal = __toESM(require_deep_equal());
|
||||||
|
var logger = rootLogger.child({ component: "Postprocess" });
|
||||||
|
var removeDuplicateLines = (context) => {
|
||||||
|
return (input) => {
|
||||||
|
const suffix = context.text.slice(context.position);
|
||||||
|
const suffixLines = splitLines(suffix);
|
||||||
|
const inputLines = splitLines(input);
|
||||||
|
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
|
||||||
|
if ((0, import_deep_equal.default)(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
|
||||||
|
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
|
||||||
|
return input.slice(0, index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
var dropBlank = (input) => {
|
||||||
|
return isBlank(input) ? null : input;
|
||||||
|
};
|
||||||
|
var applyFilter = (filter2) => {
|
||||||
|
return async (response) => {
|
||||||
|
response.choices = (await Promise.all(
|
||||||
|
response.choices.map(async (choice) => {
|
||||||
|
choice.text = await filter2(choice.text);
|
||||||
|
return choice;
|
||||||
|
})
|
||||||
|
)).filter(Boolean);
|
||||||
|
return response;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
async function postprocess(request2, response) {
|
||||||
|
return new Promise((resolve4) => resolve4(response)).then(applyFilter(removeDuplicateLines(request2))).then(applyFilter(dropBlank));
|
||||||
|
}
|
||||||
|
|
||||||
// src/TabbyAgent.ts
|
// src/TabbyAgent.ts
|
||||||
var TabbyAgent = class extends EventEmitter {
|
var TabbyAgent = class extends EventEmitter {
|
||||||
constructor() {
|
constructor() {
|
||||||
|
|
@ -32338,7 +32378,7 @@ var TabbyAgent = class extends EventEmitter {
|
||||||
this.onConfigUpdated();
|
this.onConfigUpdated();
|
||||||
}
|
}
|
||||||
onConfigUpdated() {
|
onConfigUpdated() {
|
||||||
allLoggers.forEach((logger) => logger.level = this.config.logs.level);
|
allLoggers.forEach((logger2) => logger2.level = this.config.logs.level);
|
||||||
this.api = new TabbyApi({ BASE: this.config.server.endpoint });
|
this.api = new TabbyApi({ BASE: this.config.server.endpoint });
|
||||||
this.ping();
|
this.ping();
|
||||||
}
|
}
|
||||||
|
|
@ -32403,14 +32443,14 @@ var TabbyAgent = class extends EventEmitter {
|
||||||
this.updateConfig(params.config);
|
this.updateConfig(params.config);
|
||||||
}
|
}
|
||||||
if (params.client) {
|
if (params.client) {
|
||||||
allLoggers.forEach((logger) => logger.setBindings && logger.setBindings({ client: params.client }));
|
allLoggers.forEach((logger2) => logger2.setBindings && logger2.setBindings({ client: params.client }));
|
||||||
}
|
}
|
||||||
this.logger.debug({ params }, "Initialized");
|
this.logger.debug({ params }, "Initialized");
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
updateConfig(config2) {
|
updateConfig(config2) {
|
||||||
const mergedConfig = (0, import_deepmerge.default)(this.config, config2);
|
const mergedConfig = (0, import_deepmerge.default)(this.config, config2);
|
||||||
if (!(0, import_deep_equal.default)(this.config, mergedConfig)) {
|
if (!(0, import_deep_equal2.default)(this.config, mergedConfig)) {
|
||||||
this.config = mergedConfig;
|
this.config = mergedConfig;
|
||||||
this.onConfigUpdated();
|
this.onConfigUpdated();
|
||||||
const event = { event: "configUpdated", config: this.config };
|
const event = { event: "configUpdated", config: this.config };
|
||||||
|
|
@ -32448,6 +32488,8 @@ var TabbyAgent = class extends EventEmitter {
|
||||||
});
|
});
|
||||||
return cancelable(
|
return cancelable(
|
||||||
promise.then((response) => {
|
promise.then((response) => {
|
||||||
|
return postprocess(request2, response);
|
||||||
|
}).then((response) => {
|
||||||
this.completionCache.set(request2, response);
|
this.completionCache.set(request2, response);
|
||||||
return response;
|
return response;
|
||||||
}),
|
}),
|
||||||
|
|
|
||||||
File diff suppressed because one or more lines are too long
|
|
@ -8,6 +8,7 @@ import { sleep, cancelable, splitLines, isBlank } from "./utils";
|
||||||
import { Agent, AgentEvent, AgentInitOptions, CompletionRequest, CompletionResponse, LogEventRequest } from "./Agent";
|
import { Agent, AgentEvent, AgentInitOptions, CompletionRequest, CompletionResponse, LogEventRequest } from "./Agent";
|
||||||
import { AgentConfig, defaultAgentConfig } from "./AgentConfig";
|
import { AgentConfig, defaultAgentConfig } from "./AgentConfig";
|
||||||
import { CompletionCache } from "./CompletionCache";
|
import { CompletionCache } from "./CompletionCache";
|
||||||
|
import { postprocess } from "./postprocess";
|
||||||
import { rootLogger, allLoggers } from "./logger";
|
import { rootLogger, allLoggers } from "./logger";
|
||||||
|
|
||||||
export class TabbyAgent extends EventEmitter implements Agent {
|
export class TabbyAgent extends EventEmitter implements Agent {
|
||||||
|
|
@ -150,10 +151,14 @@ export class TabbyAgent extends EventEmitter implements Agent {
|
||||||
segments,
|
segments,
|
||||||
});
|
});
|
||||||
return cancelable(
|
return cancelable(
|
||||||
promise.then((response: CompletionResponse) => {
|
promise
|
||||||
this.completionCache.set(request, response);
|
.then((response) => {
|
||||||
return response;
|
return postprocess(request, response);
|
||||||
}),
|
})
|
||||||
|
.then((response) => {
|
||||||
|
this.completionCache.set(request, response);
|
||||||
|
return response;
|
||||||
|
}),
|
||||||
() => {
|
() => {
|
||||||
promise.cancel();
|
promise.cancel();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
import { CompletionRequest, CompletionResponse } from "./Agent";
|
||||||
|
import deepEqual from "deep-equal";
|
||||||
|
import { isBlank, splitLines } from "./utils";
|
||||||
|
import { rootLogger } from "./logger";
|
||||||
|
|
||||||
|
type PostprocessContext = CompletionRequest;
|
||||||
|
type PostprocessFilter = (item: string) => string | null | Promise<string | null>;
|
||||||
|
|
||||||
|
const logger = rootLogger.child({ component: "Postprocess" });
|
||||||
|
|
||||||
|
const removeDuplicateLines: (context: PostprocessContext) => PostprocessFilter = (context) => {
|
||||||
|
return (input) => {
|
||||||
|
const suffix = context.text.slice(context.position);
|
||||||
|
const suffixLines = splitLines(suffix);
|
||||||
|
const inputLines = splitLines(input);
|
||||||
|
for (let index = Math.max(0, inputLines.length - suffixLines.length); index < inputLines.length; index++) {
|
||||||
|
if (deepEqual(inputLines.slice(index), suffixLines.slice(0, input.length - index))) {
|
||||||
|
logger.debug({ input, suffix, duplicateAt: index }, "Remove duplicate lines");
|
||||||
|
return input.slice(0, index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return input;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
const dropBlank: PostprocessFilter = (input) => {
|
||||||
|
return isBlank(input) ? null : input;
|
||||||
|
};
|
||||||
|
|
||||||
|
const applyFilter = (filter: PostprocessFilter) => {
|
||||||
|
return async (response: CompletionResponse) => {
|
||||||
|
response.choices = (
|
||||||
|
await Promise.all(
|
||||||
|
response.choices.map(async (choice) => {
|
||||||
|
choice.text = await filter(choice.text);
|
||||||
|
return choice;
|
||||||
|
})
|
||||||
|
)
|
||||||
|
).filter(Boolean);
|
||||||
|
return response;
|
||||||
|
};
|
||||||
|
};
|
||||||
|
|
||||||
|
export async function postprocess(
|
||||||
|
request: CompletionRequest,
|
||||||
|
response: CompletionResponse
|
||||||
|
): Promise<CompletionResponse> {
|
||||||
|
return new Promise((resolve) => resolve(response))
|
||||||
|
.then(applyFilter(removeDuplicateLines(request)))
|
||||||
|
.then(applyFilter(dropBlank));
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue