122 lines
4.2 KiB
TypeScript
122 lines
4.2 KiB
TypeScript
|
import { OpenAIMessage } from '../types';
|
||
|
import * as tokenizer from '.';
|
||
|
|
||
|
/**
 * Settings that control how a `ChatHistoryTrimmer` shrinks a conversation.
 */
export interface ChatHistoryTrimmerOptions {
    /** Token budget: the trimmer keeps reducing the history while its token count exceeds this. */
    maxTokens: number;

    /**
     * When set to a non-zero value, only this many messages from the end of the
     * history are kept up front. Leave it unset to skip that step.
     */
    nMostRecentMessages?: number;

    /** Requests that the first message whose role is `system` be kept. */
    preserveSystemPrompt: boolean;

    /** Requests that the first message whose role is `user` be kept. */
    preserveFirstUserMessage: boolean;
}
|
||
|
|
||
|
/**
 * Shrinks a chat history until it fits within `options.maxTokens`.
 *
 * `process()` works on shallow copies of the messages given to the constructor and
 * escalates through increasingly aggressive strategies, returning as soon as
 * `tokenizer.countTokensForMessages` reports that the result is within budget:
 *
 * 1. keep only the `nMostRecentMessages` trailing messages (when that option is set);
 * 2. drop whole messages, starting from the front of the history;
 * 3. truncate the messages that are left;
 * 4. fall back to the last message alone, truncated to the budget.
 *
 * During steps 1-3 the first `system` message and the first `user` message are left
 * alone when the matching `preserve*` option is enabled. Steps 2 and 3 additionally
 * never touch the last message.
 */
export class ChatHistoryTrimmer {
    /** Working copy of the history; each strategy replaces it as `process()` escalates. */
    private output: OpenAIMessage[] = [];

    /**
     * @param messages Chat history to trim; the tail is treated as the most recent part.
     *                 The array and its message objects are never modified.
     * @param options  Token budget and preservation flags.
     */
    constructor(private messages: OpenAIMessage[],
                private readonly options: ChatHistoryTrimmerOptions) {
    }

    /** Number of tokens by which the working copy exceeds the budget (0 when it fits). */
    private countExcessTokens() {
        return Math.max(0, tokenizer.countTokensForMessages(this.output) - this.options.maxTokens);
    }

    /**
     * Indexes into `this.output` that the strategies must leave alone because of the
     * `preserveSystemPrompt` / `preserveFirstUserMessage` options.
     *
     * The indexes are always derived from the current working copy so that they stay
     * valid after an earlier strategy has removed messages.
     */
    private findProtectedIndexes(): Set<number> {
        const protectedIndexes = new Set<number>();

        if (this.options.preserveSystemPrompt) {
            const systemPromptIndex = this.output.findIndex(m => m.role === 'system');
            if (systemPromptIndex !== -1) {
                protectedIndexes.add(systemPromptIndex);
            }
        }

        if (this.options.preserveFirstUserMessage) {
            const firstUserMessageIndex = this.output.findIndex(m => m.role === 'user');
            if (firstUserMessageIndex !== -1) {
                protectedIndexes.add(firstUserMessageIndex);
            }
        }

        return protectedIndexes;
    }

    /**
     * Runs the trimming pipeline and returns the reduced history.
     *
     * @returns A new array of message objects (copies, truncated or filtered as needed).
     *          An empty array is returned when there are no messages to fall back on.
     */
    public process() {
        // Copy every message so that no strategy can affect the caller's objects.
        this.output = this.messages.map(m => ({ ...m }));

        if (this.options.nMostRecentMessages) {
            this.output = this.removeUnwantedMessagesStrategy();
        }

        let excessTokens = this.countExcessTokens();
        if (excessTokens === 0) {
            return this.output;
        }

        this.output = this.removeMessagesStrategy();

        excessTokens = this.countExcessTokens();
        if (excessTokens === 0) {
            return this.output;
        }

        this.output = this.trimMessagesStrategy(excessTokens);

        excessTokens = this.countExcessTokens();
        if (excessTokens === 0) {
            return this.output;
        }

        // Last resort: keep only the most recent message, cut down to the whole budget.
        const last = this.messages[this.messages.length - 1];
        this.output = last ? [tokenizer.truncateMessage(last, this.options.maxTokens)] : [];

        return this.output;
    }

    /**
     * Keeps the `nMostRecentMessages` trailing messages plus any protected message,
     * preserving the original order.
     */
    private removeUnwantedMessagesStrategy() {
        const protectedIndexes = this.findProtectedIndexes();
        const keepFromIndex = this.output.length - (this.options.nMostRecentMessages || 1);

        return this.output.filter((_, i) => protectedIndexes.has(i) || i >= keepFromIndex);
    }

    /**
     * Drops whole messages from the front of the history until the remainder fits the
     * budget. Protected messages and the last message are never dropped.
     *
     * Messages whose content is empty are filtered out as well, matching the filter
     * applied by `trimMessagesStrategy`.
     */
    private removeMessagesStrategy() {
        const protectedIndexes = this.findProtectedIndexes();
        const lastMessageIndex = this.output.length - 1;

        const dropped = new Set<number>();
        // The budget check is made against exactly the list that would be returned.
        const survivors = () => this.output.filter((m, i) => !dropped.has(i) && m.content.length > 0);

        for (let i = 0; i < this.output.length && tokenizer.countTokensForMessages(survivors()) > this.options.maxTokens; i++) {
            if (i === lastMessageIndex || protectedIndexes.has(i)) {
                continue;
            }
            dropped.add(i);
        }

        return survivors();
    }

    /**
     * Truncates unprotected messages (never the last one) from the front of the history
     * until the result fits the budget, then discards messages left with no content.
     *
     * @param excessTokens Number of tokens the current working copy is over budget.
     */
    private trimMessagesStrategy(excessTokens: number) {
        const protectedIndexes = this.findProtectedIndexes();
        const lastMessageIndex = this.output.length - 1;

        const output: OpenAIMessage[] = [...this.output];
        // NOTE(review): the fallback in process() passes a token *budget* as the second
        // argument of truncateMessage, so here each message is presumably capped at
        // (excess / message count) tokens rather than shortened by that amount.
        // Behaviour kept as-is; confirm against the tokenizer implementation.
        const truncateLength = Math.floor(excessTokens / this.output.length);

        for (let i = 0; i < this.output.length && tokenizer.countTokensForMessages(output) > this.options.maxTokens; i++) {
            if (i === lastMessageIndex || protectedIndexes.has(i)) {
                continue;
            }
            output[i] = tokenizer.truncateMessage(output[i], truncateLength);
        }

        return output.filter(m => m.content.length > 0);
    }
}
|