chat-with-gpt/app/src/tokenizer/chat-history-trimmer.ts

122 lines
4.2 KiB
TypeScript
Raw Normal View History

import { OpenAIMessage } from '../types';
import * as tokenizer from '.';
/**
 * Options controlling how {@link ChatHistoryTrimmer} shrinks a chat history.
 */
export interface ChatHistoryTrimmerOptions {
    /** Token budget the trimmed history should fit within. */
    maxTokens: number;
    /**
     * When set to a non-zero value, only this many of the most recent messages are kept
     * (plus any preserved messages). This happens before any token-based trimming.
     */
    nMostRecentMessages?: number;
    /**
     * Ask the trimmer to keep the first message whose role is 'system'.
     * The last-resort fallback, which returns only the final message, does not honour this.
     */
    preserveSystemPrompt: boolean;
    /**
     * Ask the trimmer to keep the first message whose role is 'user'.
     * The last-resort fallback, which returns only the final message, does not honour this.
     */
    preserveFirstUserMessage: boolean;
}
/**
 * Shrinks a chat history so that its size does not exceed `options.maxTokens`.
 * Size is measured by `tokenizer.countTokensForMessages`.
 *
 * `process()` applies these steps in order and returns as soon as the history fits:
 *  1. If `nMostRecentMessages` is set (non-zero), keep only that many of the most recent
 *     messages, plus the preserved system prompt / first user message. This step always
 *     runs when the option is set, even if the history already fits.
 *  2. Drop whole messages, oldest first. Preserved messages and the last message are skipped.
 *  3. Truncate the content of the remaining non-preserved, non-last messages.
 *  4. As a last resort, return only the final message, truncated to fit the budget.
 *
 * The caller's array and message objects are never modified by this class. Only
 * copies are handed to the tokenizer.
 */
export class ChatHistoryTrimmer {
    /** Working copy that each strategy reads from and replaces. */
    private output: OpenAIMessage[] = [];

    constructor(private readonly messages: OpenAIMessage[],
                private readonly options: ChatHistoryTrimmerOptions) {
    }

    /** Number of tokens by which `this.output` exceeds the budget (0 if it fits). */
    private countExcessTokens(): number {
        return Math.max(0, tokenizer.countTokensForMessages(this.output) - this.options.maxTokens);
    }

    /**
     * Runs the trimming strategies and returns the trimmed history.
     * Each call starts over from the original messages, so it is safe to call repeatedly.
     */
    public process(): OpenAIMessage[] {
        // Shallow-copy each message so no later step can touch the caller's objects.
        this.output = this.messages.map(m => ({ ...m }));

        // 0 / undefined disables the "most recent N" filter.
        if (this.options.nMostRecentMessages) {
            this.output = this.removeUnwantedMessagesStrategy();
        }
        if (this.countExcessTokens() === 0) {
            return this.output;
        }

        this.output = this.removeMessagesStrategy();
        const excessTokens = this.countExcessTokens();
        if (excessTokens === 0) {
            return this.output;
        }

        this.output = this.trimMessagesStrategy(excessTokens);
        if (this.countExcessTokens() === 0) {
            return this.output;
        }

        this.output = this.truncateLastMessageStrategy();
        return this.output;
    }

    /**
     * Returns the positions in `messages` that the options ask to keep.
     * These are the first 'system' message and/or the first 'user' message.
     */
    private preservedIndices(messages: OpenAIMessage[]): Set<number> {
        const preserved = new Set<number>();
        if (this.options.preserveSystemPrompt) {
            const systemPromptIndex = messages.findIndex(m => m.role === 'system');
            if (systemPromptIndex >= 0) {
                preserved.add(systemPromptIndex);
            }
        }
        if (this.options.preserveFirstUserMessage) {
            const firstUserMessageIndex = messages.findIndex(m => m.role === 'user');
            if (firstUserMessageIndex >= 0) {
                preserved.add(firstUserMessageIndex);
            }
        }
        return preserved;
    }

    /**
     * Keeps the `nMostRecentMessages` most recent messages plus any preserved messages.
     * It only runs when `nMostRecentMessages` is truthy; the `|| 1` fallback is defensive.
     */
    private removeUnwantedMessagesStrategy(): OpenAIMessage[] {
        const preserved = this.preservedIndices(this.output);
        const keepFromIndex = this.output.length - (this.options.nMostRecentMessages || 1);
        return this.output.filter((_, i) => i >= keepFromIndex || preserved.has(i));
    }

    /**
     * Drops whole messages, oldest first, until the history fits the budget or no
     * removable message is left. Preserved messages and the last message are never dropped.
     */
    private removeMessagesStrategy(): OpenAIMessage[] {
        // Indices come from this.output, not this.messages. Step 1 may already have
        // removed messages, which shifts every later position.
        const preserved = this.preservedIndices(this.output);
        const lastMessageIndex = this.output.length - 1;
        const removed = new Set<number>();
        // Removed messages are excluded outright (rather than blanked) so they no longer
        // contribute any tokens to the running count.
        const remaining = () => this.output.filter((_, i) => !removed.has(i));

        for (let i = 0;
             i < lastMessageIndex && tokenizer.countTokensForMessages(remaining()) > this.options.maxTokens;
             i++) {
            if (!preserved.has(i)) {
                removed.add(i);
            }
        }
        return remaining();
    }

    /**
     * Truncates the content of non-preserved, non-last messages, oldest first, until the
     * history fits. Messages whose content ends up empty are dropped.
     *
     * When reached after removeMessagesStrategy, the remaining messages are normally only
     * the preserved ones plus the last. In that case this pass changes nothing; it is kept
     * as a safety net.
     *
     * NOTE(review): each candidate is passed `excessTokens / length` as truncateMessage's
     * limit. That argument appears to be a token limit (the fallback passes a token budget),
     * so this truncates *to* that size rather than removing that many tokens. Confirm this
     * is intended.
     */
    private trimMessagesStrategy(excessTokens: number): OpenAIMessage[] {
        const preserved = this.preservedIndices(this.output);
        const lastMessageIndex = this.output.length - 1;
        const output: OpenAIMessage[] = [...this.output];
        const truncateLength = Math.floor(excessTokens / this.output.length);

        for (let i = 0;
             i < lastMessageIndex && tokenizer.countTokensForMessages(output) > this.options.maxTokens;
             i++) {
            if (preserved.has(i)) {
                continue;
            }
            output[i] = tokenizer.truncateMessage(output[i], truncateLength);
        }
        return output.filter(m => m.content.length > 0);
    }

    /**
     * Last resort: returns only the final original message, truncated so the one-message
     * history fits the budget. Returns [] when there are no messages.
     */
    private truncateLastMessageStrategy(): OpenAIMessage[] {
        const last = this.messages[this.messages.length - 1];
        if (last === undefined) {
            return [];
        }
        // countTokensForMessages may count tokens beyond the content itself (e.g. per-message
        // framing). Measure that with an empty copy and leave room for it. This assumes the
        // count is roughly content tokens plus that overhead — TODO confirm against the tokenizer.
        const overhead = tokenizer.countTokensForMessages([{ ...last, content: '' }]);
        const contentBudget = Math.max(0, this.options.maxTokens - overhead);
        return [tokenizer.truncateMessage({ ...last }, contentBudget)];
    }
}