This commit is contained in:
Cogent Apps
2023-04-15 10:30:02 +00:00
parent 943bca2f4d
commit eb58d900b5
118 changed files with 5785 additions and 2471 deletions

276
app/src/core/backend.ts Normal file
View File

@@ -0,0 +1,276 @@
import EventEmitter from 'events';
import * as Y from 'yjs';
import { encode, decode } from '@msgpack/msgpack';
import { MessageTree } from './chat/message-tree';
import { Chat } from './chat/types';
import { AsyncLoop } from "./utils/async-loop";
import { ChatManager } from '.';
import { getRateLimitResetTimeFromResponse } from './utils';
import { importChat } from './chat/chat-persistance';
const endpoint = '/chatapi';
export let backend: {
current?: Backend | null
} = {};
export interface User {
id: string;
email?: string;
name?: string;
avatar?: string;
services?: string[];
}
export class Backend extends EventEmitter {
public user: User | null = null;
private checkedSession = false;
private sessionInterval = new AsyncLoop(() => this.getSession(), 1000 * 30);
private syncInterval = new AsyncLoop(() => this.sync(), 1000 * 5);
private pendingYUpdate: Uint8Array | null = null;
private lastFullSyncAt = 0;
private legacySync = false;
private rateLimitedUntil = 0;
public constructor(private context: ChatManager) {
super();
if ((window as any).AUTH_PROVIDER) {
backend.current = this;
this.sessionInterval.start();
this.syncInterval.start();
}
}
public isSynced() {
return (this.checkedSession && !this.isAuthenticated) || this.lastFullSyncAt > 0;
}
public async getSession() {
if (Date.now() < this.rateLimitedUntil) {
console.log(`Waiting another ${this.rateLimitedUntil - Date.now()}ms to check session due to rate limiting.`);
return;
}
const wasAuthenticated = this.isAuthenticated;
const session = await this.get(endpoint + '/session');
if (session?.authProvider) {
(window as any).AUTH_PROVIDER = session.authProvider;
}
if (session?.authenticated) {
this.user = {
id: session.userID,
email: session.email,
name: session.name,
avatar: session.picture,
services: session.services,
};
} else {
this.user = null;
}
this.checkedSession = true;
if (wasAuthenticated !== this.isAuthenticated) {
this.emit('authenticated', this.isAuthenticated);
this.lastFullSyncAt = 0;
}
}
public async sync() {
if (!this.isAuthenticated) {
return;
}
if (Date.now() < this.rateLimitedUntil) {
console.log(`Waiting another ${this.rateLimitedUntil - Date.now()}ms before syncing due to rate limiting.`);
return;
}
const encoding = await import('lib0/encoding');
const decoding = await import('lib0/decoding');
const syncProtocol = await import('y-protocols/sync');
const sinceLastFullSync = Date.now() - this.lastFullSyncAt;
const pendingYUpdate = this.pendingYUpdate;
if (pendingYUpdate && pendingYUpdate.length > 4) {
this.pendingYUpdate = null;
const encoder = encoding.createEncoder();
syncProtocol.writeUpdate(encoder, pendingYUpdate);
const response = await fetch(endpoint + '/y-sync', {
method: 'POST',
headers: {
'Content-Type': 'application/octet-stream'
},
body: encoding.toUint8Array(encoder),
});
if (response.status === 429) {
this.rateLimitedUntil = getRateLimitResetTimeFromResponse(response);
}
} else if (sinceLastFullSync > 1000 * 60 * 1) {
this.lastFullSyncAt = Date.now();
const encoder = encoding.createEncoder();
syncProtocol.writeSyncStep1(encoder, this.context.doc.root);
const queue: Uint8Array[] = [
encoding.toUint8Array(encoder),
];
for (let i = 0; i < 4; i++) {
if (!queue.length) {
break;
}
const buffer = queue.shift()!;
const response = await fetch(endpoint + '/y-sync', {
method: 'POST',
headers: {
'Content-Type': 'application/octet-stream'
},
body: buffer,
});
if (!response.ok) {
this.rateLimitedUntil = getRateLimitResetTimeFromResponse(response);
throw new Error(response.statusText);
}
const responseBuffer = await response.arrayBuffer();
const responseChunks = decode(responseBuffer) as Uint8Array[];
for (const chunk of responseChunks) {
if (!chunk.byteLength) {
continue;
}
const encoder = encoding.createEncoder();
const decoder = decoding.createDecoder(chunk);
const messageType = decoding.readVarUint(decoder);
decoder.pos = 0;
syncProtocol.readSyncMessage(decoder, encoder, this.context.doc.root, 'sync');
if (encoding.length(encoder)) {
queue.push(encoding.toUint8Array(encoder));
}
}
}
this.context.emit('update');
}
if (!this.legacySync) {
this.legacySync = true;
const chats = await this.get(endpoint + '/legacy-sync');
this.context.doc.transact(() => {
for (const chat of chats) {
try {
importChat(this.context.doc, chat);
} catch (e) {
console.error(e);
}
}
});
}
}
public receiveYUpdate(update: Uint8Array) {
if (!this.pendingYUpdate) {
this.pendingYUpdate = update;
} else {
this.pendingYUpdate = Y.mergeUpdates([this.pendingYUpdate, update]);
}
}
async signIn() {
window.location.href = endpoint + '/login';
}
get isAuthenticated() {
return this.user !== null;
}
async logout() {
window.location.href = endpoint + '/logout';
}
async shareChat(chat: Chat): Promise<string | null> {
try {
const { id } = await this.post(endpoint + '/share', {
...chat,
messages: chat.messages.serialize(),
});
if (typeof id === 'string') {
return id;
}
} catch (e) {
console.error(e);
}
return null;
}
async getSharedChat(id: string): Promise<Chat | null> {
const format = process.env.REACT_APP_SHARE_URL || (endpoint + '/share/:id');
const url = format.replace(':id', id);
try {
const chat = await this.get(url);
if (chat?.messages?.length) {
chat.messages = new MessageTree(chat.messages);
return chat;
}
} catch (e) {
console.error(e);
}
return null;
}
async deleteChat(id: string) {
if (!this.isAuthenticated) {
return;
}
return this.post(endpoint + '/delete', { id });
}
async get(url: string) {
const response = await fetch(url);
if (response.status === 429) {
this.rateLimitedUntil = getRateLimitResetTimeFromResponse(response);
}
if (!response.ok) {
throw new Error(response.statusText);
}
return response.json();
}
async post(url: string, data: any) {
const response = await fetch(url, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify(data),
});
if (response.status === 429) {
this.rateLimitedUntil = getRateLimitResetTimeFromResponse(response);
}
if (!response.ok) {
throw new Error(response.statusText);
}
return response.json();
}
}

View File

@@ -0,0 +1,71 @@
import * as idb from '../utils/idb';
import * as Y from 'yjs';
import { MessageTree } from './message-tree';
import { Chat } from './types';
import { YChatDoc } from './y-chat';
export async function loadFromPreviousVersion(doc: YChatDoc) {
const serialized = await idb.get('chats');
if (serialized) {
for (const chat of serialized) {
try {
if (chat.deleted) {
continue;
}
if (doc.has(chat.id)) {
continue;
}
const messages = new MessageTree();
for (const m of chat.messages) {
messages.addMessage(m);
}
chat.messages = messages;
importChat(doc, chat);
} catch (e) {
console.error(e);
}
}
}
}
export function importChat(doc: YChatDoc, chat: Chat) {
const ychat = doc.getYChat(chat.id, true);
if (ychat.deleted) {
return;
}
if (chat.metadata) {
for (const key of Object.keys(chat.metadata)) {
if (!ychat.importedMetadata.has(key)) {
ychat.importedMetadata.set(key, chat.metadata[key]);
}
}
} else if (chat.title) {
if (!ychat.importedMetadata.has('title')) {
ychat.importedMetadata.set('title', chat.title);
}
}
if (chat.pluginOptions) {
for (const key of Object.keys(chat.pluginOptions)) {
const [pluginID, option] = key.split('.', 2);
if (!ychat.pluginOptions.has(key)) {
ychat.setOption(pluginID, option, chat.pluginOptions[key]);
}
}
}
const messages = chat.messages instanceof MessageTree ? chat.messages.serialize() : chat.messages;
for (const message of messages) {
if (ychat.messages.has(message.id)) {
continue;
}
ychat.messages.set(message.id, message);
ychat.content.set(message.id, new Y.Text(message.content || ''));
if (message.done) {
ychat.done.set(message.id, message.done);
}
}
}

View File

@@ -0,0 +1,185 @@
import EventEmitter from "events";
import { createChatCompletion, createStreamingChatCompletion } from "./openai";
import { PluginContext } from "../plugins/plugin-context";
import { pluginRunner } from "../plugins/plugin-runner";
import { Chat, Message, OpenAIMessage, Parameters, getOpenAIMessageFromMessage } from "./types";
import { EventEmitterAsyncIterator } from "../utils/event-emitter-async-iterator";
import { YChat } from "./y-chat";
import { OptionsManager } from "../options";
export class ReplyRequest extends EventEmitter {
private mutatedMessages: OpenAIMessage[];
private mutatedParameters: Parameters;
private lastChunkReceivedAt: number = 0;
private timer: any;
private done: boolean = false;
private content = '';
private cancelSSE: any;
constructor(private chat: Chat,
private yChat: YChat,
private messages: Message[],
private replyID: string,
private requestedParameters: Parameters,
private pluginOptions: OptionsManager) {
super();
this.mutatedMessages = [...messages];
this.mutatedMessages = messages.map(m => getOpenAIMessageFromMessage(m));
this.mutatedParameters = { ...requestedParameters };
delete this.mutatedParameters.apiKey;
}
pluginContext = (pluginID: string) => ({
getOptions: () => {
return this.pluginOptions.getAllOptions(pluginID, this.chat.id);
},
getCurrentChat: () => {
return this.chat;
},
createChatCompletion: async (messages: OpenAIMessage[], _parameters: Parameters) => {
return await createChatCompletion(messages, {
..._parameters,
apiKey: this.requestedParameters.apiKey,
});
},
setChatTitle: async (title: string) => {
this.yChat.title = title;
},
} as PluginContext);
private scheduleTimeout() {
this.lastChunkReceivedAt = Date.now();
clearInterval(this.timer);
this.timer = setInterval(() => {
const sinceLastChunk = Date.now() - this.lastChunkReceivedAt;
if (sinceLastChunk > 30000 && !this.done) {
this.onError('no response from OpenAI in the last 30 seconds');
}
}, 2000);
}
public async execute() {
try {
this.scheduleTimeout();
await pluginRunner("preprocess-model-input", this.pluginContext, async plugin => {
const output = await plugin.preprocessModelInput(this.mutatedMessages, this.mutatedParameters);
this.mutatedMessages = output.messages;
this.mutatedParameters = output.parameters;
this.lastChunkReceivedAt = Date.now();
});
const { emitter, cancel } = await createStreamingChatCompletion(this.mutatedMessages, {
...this.mutatedParameters,
apiKey: this.requestedParameters.apiKey,
});
this.cancelSSE = cancel;
const eventIterator = new EventEmitterAsyncIterator<string>(emitter, ["data", "done", "error"]);
for await (const event of eventIterator) {
const { eventName, value } = event;
switch (eventName) {
case 'data':
await this.onData(value);
break;
case 'done':
await this.onDone();
break;
case 'error':
if (!this.content || !this.done) {
await this.onError(value);
}
break;
}
}
} catch (e: any) {
console.error(e);
this.onError(e.message);
}
}
public async onData(value: any) {
if (this.done) {
return;
}
this.lastChunkReceivedAt = Date.now();
this.content = value;
await pluginRunner("postprocess-model-output", this.pluginContext, async plugin => {
const output = await plugin.postprocessModelOutput({
role: 'assistant',
content: this.content,
}, this.mutatedMessages, this.mutatedParameters, false);
this.content = output.content;
});
this.yChat.setPendingMessageContent(this.replyID, this.content);
}
public async onDone() {
if (this.done) {
return;
}
clearInterval(this.timer);
this.lastChunkReceivedAt = Date.now();
this.done = true;
this.emit('done');
this.yChat.onMessageDone(this.replyID);
await pluginRunner("postprocess-model-output", this.pluginContext, async plugin => {
const output = await plugin.postprocessModelOutput({
role: 'assistant',
content: this.content,
}, this.mutatedMessages, this.mutatedParameters, true);
this.content = output.content;
});
this.yChat.setMessageContent(this.replyID, this.content);
}
public async onError(error: string) {
if (this.done) {
return;
}
this.done = true;
this.emit('done');
clearInterval(this.timer);
this.cancelSSE?.();
this.content += `\n\nI'm sorry, I'm having trouble connecting to OpenAI (${error || 'no response from the API'}). Please make sure you've entered your OpenAI API key correctly and try again.`;
this.content = this.content.trim();
this.yChat.setMessageContent(this.replyID, this.content);
this.yChat.onMessageDone(this.replyID);
}
public onCancel() {
clearInterval(this.timer);
this.done = true;
this.yChat.onMessageDone(this.replyID);
this.cancelSSE?.();
this.emit('done');
}
// private setMessageContent(content: string) {
// const text = this.yChat.content.get(this.replyID);
// if (text && text.toString() !== content) {
// text?.delete(0, text.length);
// text?.insert(0, content);
// }
// }
}

View File

@@ -0,0 +1,197 @@
import { Message } from "./types";
/**
* MessageNode interface that extends the Message type and includes parent and replies properties.
* This allows creating a tree structure from messages.
*/
export interface MessageNode extends Message {
parent: MessageNode | null;
replies: Set<MessageNode>;
}
/**
* Function to create a new MessageNode from a given message.
*
* @param {Message} message - The message to be converted to a MessageNode.
* @returns {MessageNode} - The newly created MessageNode.
*/
export function createMessageNode(message: Message): MessageNode {
return {
...message,
parent: null,
replies: new Set(),
};
}
/**
* MessageTree class for representing and managing a tree structure of messages.
* The tree is made up of MessageNode objects, which extend the `Message` type and can have parent and replies relationships.
* The purpose of the tree structure is to represent a hierarchy of messages, where one message might have multiple
* replies, and each reply has a parent message that it is replying to.
*/
export class MessageTree {
public messageNodes: Map<string, MessageNode> = new Map(); // TODO make private
constructor(messages: (Message | MessageNode)[] = []) {
this.addMessages(messages);
}
/**
* Getter method for retrieving root messages (messages without a parent) in the tree.
* @returns {MessageNode[]} - An array of root messages.
*/
public get roots(): MessageNode[] {
return Array.from(this.messageNodes.values())
.filter((messageNode) => messageNode.parent === null);
}
/**
* Getter method for retrieving leaf messages (messages without any replies) in the tree.
* @returns {MessageNode[]} - An array of leaf messages.
*/
public get leafs(): MessageNode[] {
return Array.from(this.messageNodes.values())
.filter((messageNode) => messageNode.replies.size === 0);
}
/**
* Getter method for retrieving the first message in the most recent message chain.
* @returns {MessageNode | null} - The first message in the most recent message chain, or null if the tree is empty.
*/
public get first(): MessageNode | null {
const leaf = this.mostRecentLeaf();
let first: MessageNode | null = leaf;
while (first?.parent) {
first = first.parent;
}
return first;
}
/**
* Method to get a message node from the tree by its ID.
* @param {string} id - The ID of the message node to retrieve.
* @returns {MessageNode | null} - The message node with the given ID, or null if it does not exist in the tree.
*/
public get(id: string): MessageNode | null {
return this.messageNodes.get(id) || null;
}
/**
* Method to add a message to the tree. If a message with the same ID already exists in the tree, this method does nothing.
* @param {Message} message - The message to add to the tree.
*/
public addMessage(inputMessage: Message, content: string | undefined = '', done: boolean | undefined = false): void {
const message = {
...inputMessage,
content: content || inputMessage.content || '',
done: typeof done === 'boolean' ? done : inputMessage.done,
};
if (this.messageNodes.get(message.id)?.content) {
return;
}
const messageNode = createMessageNode(message);
this.messageNodes.set(messageNode.id, messageNode);
if (messageNode.parentID) {
let parent = this.messageNodes.get(messageNode.parentID);
if (!parent) {
parent = createMessageNode({
id: messageNode.parentID,
} as Message);
this.messageNodes.set(parent.id, parent);
}
parent.replies.add(messageNode);
messageNode.parent = parent;
}
for (const other of Array.from(this.messageNodes.values())) {
if (other.parentID === messageNode.id) {
messageNode.replies.add(other);
other.parent = messageNode;
}
}
}
/**
* Method to add multiple messages to the tree.
* @param {Message[]} messages - An array of messages to add to the tree.
*/
public addMessages(messages: Message[]): void {
for (const message of messages) {
try {
this.addMessage(message);
} catch (e) {
console.error(`Error adding message with id: ${message.id}`, e);
}
}
}
/**
* Method to update the content, timestamp, and done status of an existing message in the tree.
* @param {Message} message - The updated message.
*/
public updateMessage(message: Message): void {
const messageNode = this.messageNodes.get(message.id);
if (!messageNode) {
return;
}
messageNode.content = message.content;
messageNode.timestamp = message.timestamp;
messageNode.done = message.done;
}
/**
* Method to get the message chain leading to a specific message by its ID.
* @param {string} messageID - The ID of the target message.
* @returns {MessageNode[]} - An array of message nodes in the chain leading to the target message.
*/
public getMessageChainTo(messageID: string): MessageNode[] {
const message = this.messageNodes.get(messageID);
if (!message) {
return [];
}
const chain = [message];
let current = message;
while (current.parent) {
chain.unshift(current.parent);
current = current.parent;
}
return chain;
}
/**
* Method to serialize the message tree into an array of message nodes, excluding parent and replies properties.
* @returns {Omit<MessageNode, 'parent' | 'replies'>[]} - An array of serialized message nodes.
*/
public serialize(): Omit<MessageNode, 'parent' | 'replies'>[] {
return Array.from(this.messageNodes.values())
.map((messageNode) => {
const n: any = { ...messageNode };
delete n.parent;
delete n.replies;
return n;
});
}
/**
* Method to get the most recent leaf message in the message tree.
* @returns {MessageNode | null} - The most recent leaf message, or null if the tree is empty.
*/
public mostRecentLeaf(): MessageNode | null {
return this.leafs.sort((a, b) => b.timestamp - a.timestamp)[0] || null;
}
}

145
app/src/core/chat/openai.ts Normal file
View File

@@ -0,0 +1,145 @@
import EventEmitter from "events";
import { Configuration, OpenAIApi } from "openai";
import SSE from "../utils/sse";
import { OpenAIMessage, Parameters } from "./types";
import { backend } from "../backend";
export const defaultModel = 'gpt-3.5-turbo';
export function isProxySupported() {
return !!backend.current?.user?.services?.includes('openai');
}
function shouldUseProxy(apiKey: string | undefined | null) {
return !apiKey && isProxySupported();
}
function getEndpoint(proxied = false) {
return proxied ? '/chatapi/proxies/openai' : 'https://api.openai.com';
}
export interface OpenAIResponseChunk {
id?: string;
done: boolean;
choices?: {
delta: {
content: string;
};
index: number;
finish_reason: string | null;
}[];
model?: string;
}
function parseResponseChunk(buffer: any): OpenAIResponseChunk {
const chunk = buffer.toString().replace('data: ', '').trim();
if (chunk === '[DONE]') {
return {
done: true,
};
}
const parsed = JSON.parse(chunk);
return {
id: parsed.id,
done: false,
choices: parsed.choices,
model: parsed.model,
};
}
export async function createChatCompletion(messages: OpenAIMessage[], parameters: Parameters): Promise<string> {
const proxied = shouldUseProxy(parameters.apiKey);
const endpoint = getEndpoint(proxied);
if (!proxied && !parameters.apiKey) {
throw new Error('No API key provided');
}
const response = await fetch(endpoint + '/v1/chat/completions', {
method: "POST",
headers: {
'Accept': 'application/json, text/plain, */*',
'Authorization': !proxied ? `Bearer ${parameters.apiKey}` : '',
'Content-Type': 'application/json',
},
body: JSON.stringify({
"model": parameters.model,
"messages": messages,
"temperature": parameters.temperature,
}),
});
const data = await response.json();
return data.choices[0].message?.content?.trim() || '';
}
export async function createStreamingChatCompletion(messages: OpenAIMessage[], parameters: Parameters) {
const emitter = new EventEmitter();
const proxied = shouldUseProxy(parameters.apiKey);
const endpoint = getEndpoint(proxied);
if (!proxied && !parameters.apiKey) {
throw new Error('No API key provided');
}
const eventSource = new SSE(endpoint + '/v1/chat/completions', {
method: "POST",
headers: {
'Accept': 'application/json, text/plain, */*',
'Authorization': !proxied ? `Bearer ${parameters.apiKey}` : '',
'Content-Type': 'application/json',
},
payload: JSON.stringify({
"model": parameters.model,
"messages": messages,
"temperature": parameters.temperature,
"stream": true,
}),
}) as SSE;
let contents = '';
eventSource.addEventListener('error', (event: any) => {
if (!contents) {
let error = event.data;
try {
error = JSON.parse(error).error.message;
} catch (e) {}
emitter.emit('error', error);
}
});
eventSource.addEventListener('message', async (event: any) => {
if (event.data === '[DONE]') {
emitter.emit('done');
return;
}
try {
const chunk = parseResponseChunk(event.data);
if (chunk.choices && chunk.choices.length > 0) {
contents += chunk.choices[0]?.delta?.content || '';
emitter.emit('data', contents);
}
} catch (e) {
console.error(e);
}
});
eventSource.stream();
return {
emitter,
cancel: () => eventSource.close(),
};
}
export const maxTokensByModel = {
"chatgpt-3.5-turbo": 2048,
"gpt-4": 8096,
}

View File

@@ -0,0 +1,35 @@
import { defaultModel } from "./openai";
import { Parameters } from "./types";
export const defaultParameters: Parameters = {
temperature: 0.5,
model: defaultModel
};
export function loadParameters(id: string | null | undefined = null): Parameters {
const apiKey = localStorage.getItem('openai-api-key') || undefined;
const key = id ? `parameters-${id}` : 'parameters';
try {
const raw = localStorage.getItem(key);
if (raw) {
const parameters = JSON.parse(raw) as Parameters;
parameters.apiKey = apiKey;
return parameters;
}
} catch (e) { }
return id ? loadParameters() : { ...defaultParameters, apiKey };
}
export function saveParameters(id: string, parameters: Parameters) {
if (parameters) {
const apiKey = parameters.apiKey;
delete parameters.apiKey;
localStorage.setItem(`parameters-${id}`, JSON.stringify(parameters));
localStorage.setItem('parameters', JSON.stringify(parameters));
if (apiKey) {
localStorage.setItem(`openai-api-key`, apiKey);
}
}
}

View File

@@ -0,0 +1,63 @@
import { MessageTree } from "./message-tree";
export interface Parameters {
temperature: number;
apiKey?: string;
initialSystemPrompt?: string;
model: string;
}
export interface Message {
id: string;
chatID: string;
parentID?: string;
timestamp: number;
role: string;
model?: string;
content: string;
parameters?: Parameters;
done?: boolean;
}
export interface UserSubmittedMessage {
chatID: string;
parentID?: string;
content: string;
requestedParameters: Parameters;
}
export interface OpenAIMessage {
role: string;
content: string;
}
export function getOpenAIMessageFromMessage(message: Message): OpenAIMessage {
return {
role: message.role,
content: message.content,
};
}
export interface Chat {
id: string;
messages: MessageTree;
metadata?: Record<string, any>;
pluginOptions?: Record<string, any>;
title?: string | null;
created: number;
updated: number;
deleted?: boolean;
}
export function serializeChat(chat: Chat): string {
return JSON.stringify({
...chat,
messages: chat.messages.serialize(),
});
}
export function deserializeChat(serialized: string) {
const chat = JSON.parse(serialized);
chat.messages = new MessageTree(chat.messages);
return chat as Chat;
}

View File

@@ -0,0 +1,76 @@
import { useCallback, useEffect, useState } from "react";
import { backend } from "../backend";
import { ChatManager } from "..";
import { Chat, Message } from './types';
export interface UseChatResult {
chat: Chat | null | undefined;
chatLoadedAt: number;
messages: Message[];
messagesToDisplay: Message[];
leaf: Message | null | undefined;
}
export function useChat(chatManager: ChatManager, id: string | undefined | null, share = false): UseChatResult {
const [chat, setChat] = useState<Chat | null | undefined>(null);
const [_, setVersion] = useState(0); // eslint-disable-line @typescript-eslint/no-unused-vars
// used to prevent auto-scroll when chat is first opened
const [chatLoadedAt, setLoadedAt] = useState(0);
const update = useCallback(async () => {
if (id) {
if (!share) {
const c = chatManager.get(id);
if (c) {
setChat(c);
setVersion(v => v + 1);
return;
}
} else {
const c = await backend.current?.getSharedChat(id);
if (c) {
setChat(c);
setVersion(v => v + 1);
return;
}
}
}
setChat(null);
}, [id, share, chatManager]);
useEffect(() => {
if (id) {
update();
chatManager.on(id, update);
setChat(chatManager.get(id));
setLoadedAt(Date.now());
} else {
setChat(null);
setLoadedAt(0);
}
return () => {
if (id) {
chatManager.off(id, update);
}
};
}, [id, update, chatManager]);
const leaf = chat?.messages.mostRecentLeaf();
let messages: Message[] = [];
let messagesToDisplay: Message[] = [];
if (leaf) {
messages = (chat?.messages.getMessageChainTo(leaf?.id) || []);
messagesToDisplay = messages.filter(m => ['user', 'assistant'].includes(m.role)) || [];
}
return {
chat,
chatLoadedAt,
messages,
messagesToDisplay,
leaf,
};
}

307
app/src/core/chat/y-chat.ts Normal file
View File

@@ -0,0 +1,307 @@
import * as Y from 'yjs';
import { Chat, Message } from './types';
import EventEmitter from 'events';
import { v4 as uuidv4 } from 'uuid';
import { MessageTree } from './message-tree';
const METADATA_KEY = 'metadata';
const IMPORTED_METADATA_KEY = 'imported-metadata';
const PLUGIN_OPTIONS_KEY = 'plugin-options';
const MESSAGES_KEY = 'messages';
const CONTENT_KEY = 'messages:content';
const DONE_KEY = 'messages:done';
export class YChat {
private callback: any;
private pendingContent = new Map<string, string>();
private prefix = 'chat.' + this.id + '.';
public static from(root: Y.Doc, id: string) {
// const id = data.get('metadata').get('id') as string;
return new YChat(id, root);
}
constructor(public readonly id: string, public root: Y.Doc) {
this.purgeDeletedValues();
}
public observeDeep(callback: any) {
this.callback = callback;
this.metadata?.observeDeep(callback);
this.pluginOptions?.observeDeep(callback);
this.messages?.observeDeep(callback);
this.content?.observeDeep(callback);
this.done?.observeDeep(callback);
}
public get deleted(): boolean {
return this.metadata.get('deleted') || false;
}
public get metadata(): Y.Map<any> {
return this.root.getMap<any>(this.prefix + METADATA_KEY);
}
public get importedMetadata(): Y.Map<any> {
return this.root.getMap<any>(this.prefix + IMPORTED_METADATA_KEY);
}
public get pluginOptions(): Y.Map<any> {
return this.root.getMap<any>(this.prefix + PLUGIN_OPTIONS_KEY);
}
public get messages(): Y.Map<Message> {
return this.root.getMap<Message>(this.prefix + MESSAGES_KEY);
}
public get content(): Y.Map<Y.Text> {
return this.root.getMap<Y.Text>(this.prefix + CONTENT_KEY);
}
public get done(): Y.Map<boolean> {
return this.root.getMap<boolean>(this.prefix + DONE_KEY);
}
public get title() {
return (this.metadata.get('title') as string) || (this.importedMetadata.get('title') as string) || null;
}
public set title(value: string | null) {
if (value) {
this.metadata.set('title', value);
}
}
public setPendingMessageContent(messageID: string, value: string) {
this.pendingContent.set(messageID, value);
this.callback?.();
}
public setMessageContent(messageID: string, value: string) {
this.pendingContent.delete(messageID);
this.content.set(messageID, new Y.Text(value));
}
public getMessageContent(messageID: string) {
return this.pendingContent.get(messageID) || this.content.get(messageID)?.toString() || "";
}
public onMessageDone(messageID: string) {
this.done.set(messageID, true);
}
public getOption(pluginID: string, optionID: string): any {
const key = pluginID + "." + optionID;
return this.pluginOptions?.get(key) || null;
}
public setOption(pluginID: string, optionID: string, value: any) {
const key = pluginID + "." + optionID;
return this.pluginOptions.set(key, value);
}
public hasOption(pluginID: string, optionID: string) {
const key = pluginID + "." + optionID;
return this.pluginOptions.has(key);
}
public delete() {
if (!this.deleted) {
this.metadata.clear();
this.pluginOptions.clear();
this.messages.clear();
this.content.clear();
this.done.clear();
} else {
this.purgeDeletedValues();
}
}
private purgeDeletedValues() {
if (this.deleted) {
if (this.metadata.size > 1) {
for (const key of Array.from(this.metadata.keys())) {
if (key !== 'deleted') {
this.metadata.delete(key);
}
}
}
if (this.pluginOptions.size > 0) {
this.pluginOptions.clear();
}
if (this.messages.size > 0) {
this.messages.clear();
}
if (this.content.size > 0) {
this.content.clear();
}
if (this.done.size > 0) {
this.done.clear();
}
}
}
}
export class YChatDoc extends EventEmitter {
public root = new Y.Doc();
// public chats = this.root.getMap<Y.Map<any>>('chats');
// public deletedChatIDs = this.root.getArray<string>('deletedChatIDs');
public deletedChatIDsSet = new Set<string>();
public options = this.root.getMap<Y.Map<any>>('options');
private yChats = new Map<string, YChat>();
private observed = new Set<string>();
constructor() {
super();
this.root.whenLoaded.then(() => {
const chatIDs = Array.from(this.root.getMap('chats').keys());
for (const id of chatIDs) {
this.observeChat(id);
}
});
}
private observeChat(id: string, yChat = this.getYChat(id)) {
if (!this.observed.has(id)) {
yChat?.observeDeep(() => this.emit('update', id));
this.observed.add(id);
}
}
// public set(id: string, chat: YChat) {
// this.chats.set(id, chat.data);
// if (!this.observed.has(id)) {
// this.getYChat(id)?.observeDeep(() => this.emit('update', id));
// this.observed.add(id);
// }
// }
public get chatIDMap() {
return this.root.getMap('chatIDs');
}
public getYChat(id: string, expectContent = false) {
let yChat = this.yChats.get(id);
if (!yChat) {
yChat = YChat.from(this.root, id);
this.yChats.set(id, yChat);
}
if (expectContent && !this.chatIDMap.has(id)) {
this.chatIDMap.set(id, true);
}
this.observeChat(id, yChat);
return yChat;
}
public delete(id: string) {
this.getYChat(id)?.delete();
}
public has(id: string) {
return this.chatIDMap.has(id) && !YChat.from(this.root, id).deleted;
}
public getChatIDs() {
return Array.from(this.chatIDMap.keys());
}
public getAllYChats() {
return this.getChatIDs().map(id => this.getYChat(id)!);
}
public transact(cb: () => void) {
return this.root.transact(cb);
}
public addMessage(message: Message) {
const chat = this.getYChat(message.chatID, true);
if (!chat) {
throw new Error('Chat not found');
}
this.transact(() => {
chat.messages.set(message.id, {
...message,
content: '',
});
chat.content.set(message.id, new Y.Text(message.content || ''));
if (message.done) {
chat.done.set(message.id, message.done);
}
});
}
public createYChat(id = uuidv4()) {
// return new YChat(id, this.root);
// this.set(id, chat);
return id;
}
public getMessageTree(chatID: string): MessageTree {
const tree = new MessageTree();
const chat = this.getYChat(chatID);
chat?.messages?.forEach(m => {
try {
const content = chat.getMessageContent(m.id);
const done = chat.done.get(m.id) || false;
tree.addMessage(m, content, done);
} catch (e) {
console.warn(`failed to load message ${m.id}`, e);
}
});
return tree;
}
public getMessagesPrecedingMessage(chatID: string, messageID: string) {
const tree = this.getMessageTree(chatID);
const message = tree.get(messageID);
if (!message) {
throw new Error("message not found: " + messageID);
}
const messages: Message[] = message.parentID
? tree.getMessageChainTo(message.parentID)
: [];
return messages;
}
public getChat(id: string): Chat {
const chat = this.getYChat(id);
const tree = this.getMessageTree(id);
return {
id,
messages: tree,
title: chat.title,
metadata: {
...chat.importedMetadata.toJSON(),
...chat.metadata.toJSON(),
},
pluginOptions: chat?.pluginOptions?.toJSON() || {},
deleted: !chat.deleted,
created: tree.first?.timestamp || 0,
updated: tree.mostRecentLeaf()?.timestamp || 0,
}
}
public getOption(pluginID: string, optionID: string): any {
const key = pluginID + "." + optionID;
return this.options.get(key);
}
public setOption(pluginID: string, optionID: string, value: any) {
const key = pluginID + "." + optionID;
return this.options.set(key, value);
}
}

256
app/src/core/context.tsx Normal file
View File

@@ -0,0 +1,256 @@
import React, { useState, useRef, useMemo, useEffect, useCallback } from "react";
import { v4 as uuidv4 } from 'uuid';
import { IntlShape, useIntl } from "react-intl";
import { Backend, User } from "./backend";
import { ChatManager } from "./";
import { useAppDispatch } from "../store";
import { openOpenAIApiKeyPanel } from "../store/settings-ui";
import { Message, Parameters } from "./chat/types";
import { useChat, UseChatResult } from "./chat/use-chat";
import { TTSContextProvider } from "./tts/use-tts";
import { useLocation, useNavigate, useParams } from "react-router-dom";
import { isProxySupported } from "./chat/openai";
import { audioContext, resetAudioContext } from "./tts/audio-file-player";
export interface Context {
authenticated: boolean;
sessionExpired: boolean;
chat: ChatManager;
user: User | null;
intl: IntlShape;
id: string | undefined | null;
currentChat: UseChatResult;
isHome: boolean;
isShare: boolean;
generating: boolean;
onNewMessage: (message?: string) => Promise<string | false>;
regenerateMessage: (message: Message) => Promise<boolean>;
editMessage: (message: Message, content: string) => Promise<boolean>;
}
const AppContext = React.createContext<Context>({} as any);
const chatManager = new ChatManager();
const backend = new Backend(chatManager);
let intl: IntlShape;
export function useCreateAppContext(): Context {
const { id: _id } = useParams();
const [nextID, setNextID] = useState(uuidv4());
const id = _id ?? nextID;
const dispatch = useAppDispatch();
intl = useIntl();
const { pathname } = useLocation();
const isHome = pathname === '/';
const isShare = pathname.startsWith('/s/');
const currentChat = useChat(chatManager, id, isShare);
const [authenticated, setAuthenticated] = useState(backend?.isAuthenticated || false);
const [wasAuthenticated, setWasAuthenticated] = useState(backend?.isAuthenticated || false);
useEffect(() => {
chatManager.on('y-update', update => backend?.receiveYUpdate(update))
}, []);
const updateAuth = useCallback((authenticated: boolean) => {
setAuthenticated(authenticated);
if (authenticated && backend.user) {
chatManager.login(backend.user.email || backend.user.id);
}
if (authenticated) {
setWasAuthenticated(true);
localStorage.setItem('registered', 'true');
}
}, []);
useEffect(() => {
updateAuth(backend?.isAuthenticated || false);
backend?.on('authenticated', updateAuth);
return () => {
backend?.off('authenticated', updateAuth)
};
}, [updateAuth]);
const onNewMessage = useCallback(async (message?: string) => {
resetAudioContext();
if (isShare) {
return false;
}
if (!message?.trim().length) {
return false;
}
// const openaiApiKey = store.getState().apiKeys.openAIApiKey;
const openaiApiKey = chatManager.options.getOption<string>('openai', 'apiKey');
if (!openaiApiKey && !isProxySupported()) {
dispatch(openOpenAIApiKeyPanel());
return false;
}
const parameters: Parameters = {
model: chatManager.options.getOption<string>('parameters', 'model', id),
temperature: chatManager.options.getOption<number>('parameters', 'temperature', id),
};
if (id === nextID) {
setNextID(uuidv4());
const autoPlay = chatManager.options.getOption<boolean>('tts', 'autoplay');
if (autoPlay) {
const ttsService = chatManager.options.getOption<string>('tts', 'service');
if (ttsService === 'web-speech') {
const utterance = new SpeechSynthesisUtterance('Generating');
utterance.volume = 0;
speechSynthesis.speak(utterance);
}
}
}
// if (chatManager.has(id)) {
// chatManager.sendMessage({
// chatID: id,
// content: message.trim(),
// requestedParameters: {
// ...parameters,
// apiKey: openaiApiKey,
// },
// parentID: currentChat.leaf?.id,
// });
// } else {
// await chatManager.createChat(id);
chatManager.sendMessage({
chatID: id,
content: message.trim(),
requestedParameters: {
...parameters,
apiKey: openaiApiKey,
},
parentID: currentChat.leaf?.id,
});
// }
return id;
}, [dispatch, id, currentChat.leaf, isShare]);
const regenerateMessage = useCallback(async (message: Message) => {
resetAudioContext();
if (isShare) {
return false;
}
// const openaiApiKey = store.getState().apiKeys.openAIApiKey;
const openaiApiKey = chatManager.options.getOption<string>('openai', 'apiKey');
if (!openaiApiKey && !isProxySupported()) {
dispatch(openOpenAIApiKeyPanel());
return false;
}
const parameters: Parameters = {
model: chatManager.options.getOption<string>('parameters', 'model', id),
temperature: chatManager.options.getOption<number>('parameters', 'temperature', id),
};
await chatManager.regenerate(message, {
...parameters,
apiKey: openaiApiKey,
});
return true;
}, [dispatch, isShare]);
const editMessage = useCallback(async (message: Message, content: string) => {
resetAudioContext();
if (isShare) {
return false;
}
if (!content?.trim().length) {
return false;
}
// const openaiApiKey = store.getState().apiKeys.openAIApiKey;
const openaiApiKey = chatManager.options.getOption<string>('openai', 'apiKey');
if (!openaiApiKey && !isProxySupported()) {
dispatch(openOpenAIApiKeyPanel());
return false;
}
const parameters: Parameters = {
model: chatManager.options.getOption<string>('parameters', 'model', id),
temperature: chatManager.options.getOption<number>('parameters', 'temperature', id),
};
if (id && chatManager.has(id)) {
await chatManager.sendMessage({
chatID: id,
content: content.trim(),
requestedParameters: {
...parameters,
apiKey: openaiApiKey,
},
parentID: message.parentID,
});
} else {
const id = await chatManager.createChat();
await chatManager.sendMessage({
chatID: id,
content: content.trim(),
requestedParameters: {
...parameters,
apiKey: openaiApiKey,
},
parentID: message.parentID,
});
}
return true;
}, [dispatch, id, isShare]);
const generating = currentChat?.messagesToDisplay?.length > 0
? !currentChat.messagesToDisplay[currentChat.messagesToDisplay.length - 1].done
: false;
const context = useMemo<Context>(() => ({
authenticated,
sessionExpired: !authenticated && wasAuthenticated,
id,
user: backend.user,
intl,
chat: chatManager,
currentChat,
isHome,
isShare,
generating,
onNewMessage,
regenerateMessage,
editMessage,
}), [authenticated, wasAuthenticated, generating, onNewMessage, regenerateMessage, editMessage, currentChat, id, isHome, isShare, intl]);
return context;
}
export function useAppContext() {
return React.useContext(AppContext);
}
export function AppContextProvider(props: { children: React.ReactNode }) {
const context = useCreateAppContext();
return <AppContext.Provider value={context}>
<TTSContextProvider>
{props.children}
</TTSContextProvider>
</AppContext.Provider>;
}

254
app/src/core/index.ts Normal file
View File

@@ -0,0 +1,254 @@
import { BroadcastChannel } from 'broadcast-channel';
import EventEmitter from 'events';
import { v4 as uuidv4 } from 'uuid';
import { Chat, Message, Parameters, UserSubmittedMessage } from './chat/types';
import * as Y from 'yjs';
import { IndexeddbPersistence } from 'y-indexeddb';
import { YChatDoc } from './chat/y-chat';
import { loadFromPreviousVersion as loadSavedChatsFromPreviousVersion } from './chat/chat-persistance';
import { Search } from './search';
import { ReplyRequest } from './chat/create-reply';
import { OptionsManager } from './options';
import { Option } from './options/option';
import { pluginMetadata } from './plugins/metadata';
import { pluginRunner } from "./plugins/plugin-runner";
import { createBasicPluginContext } from './plugins/plugin-context';
export const channel = new BroadcastChannel('chats');
export class ChatManager extends EventEmitter {
public doc!: YChatDoc;
private provider!: IndexeddbPersistence;
private search!: Search;
public options!: OptionsManager;
private username: string | null = "anonymous";
private activeReplies = new Map<string, ReplyRequest>();
private changedIDs = new Set<string>();
public lastReplyID: string | null = null;
constructor() {
super();
this.setMaxListeners(1000);
console.log('initializing chat manager');
this.doc = this.attachYDoc('anonymous');
loadSavedChatsFromPreviousVersion(this.doc)
.then(() => this.emit('update'));
setInterval(() => this.emitChanges());
channel.onmessage = message => {
if (message.type === 'y-update') {
this.applyYUpdate(message.data);
}
};
(window as any).chat = this;
}
public login(username: string) {
if (username && this.username !== username) {
this.username = username;
this.attachYDoc(username);
}
}
private attachYDoc(username: string) {
console.log('attaching y-doc for ' + username);
// detach current doc
const doc = this.doc as YChatDoc | undefined;
const provider = this.provider as IndexeddbPersistence | undefined;
doc?.removeAllListeners();
const pluginOptionsManager = this.options as OptionsManager | undefined;
pluginOptionsManager?.destroy();
// attach new doc
this.doc = new YChatDoc();
this.doc.on('update', chatID => this.changedIDs.add(chatID));
this.doc.root.on('update', (update, origin) => {
if (!(origin instanceof IndexeddbPersistence) && origin !== 'sync') {
this.emit('y-update', update);
channel.postMessage({ type: 'y-update', data: update });
} else {
console.log("IDB/sync update");
}
});
this.search = new Search(this);
// connect new doc to persistance, scoped to the current username
this.provider = new IndexeddbPersistence('chats:' + username, this.doc.root);
this.provider.whenSynced.then(() => {
this.doc.getChatIDs().map(id => this.emit(id));
this.emit('update');
});
this.options = new OptionsManager(this.doc, pluginMetadata);
this.options.on('update', (...args) => this.emit('plugin-options-update', ...args));
pluginRunner(
'init',
pluginID => createBasicPluginContext(pluginID, this.options),
plugin => plugin.initialize(),
);
if (username !== 'anonymous') {
// import chats from the anonymous doc after signing in
provider?.whenSynced.then(() => {
if (doc) {
Y.applyUpdate(this.doc.root, Y.encodeStateAsUpdate(doc.root));
setTimeout(() => provider.clearData(), 10 * 1000);
}
});
}
return this.doc;
}
public applyYUpdate(update: Uint8Array) {
Y.applyUpdate(this.doc.root, update);
}
private emitChanges() {
const ids = Array.from(this.changedIDs);
this.changedIDs.clear();
for (const id of ids) {
this.emit(id);
this.search.update(id);
}
if (ids.length) {
this.emit('update');
}
}
public async sendMessage(userSubmittedMessage: UserSubmittedMessage) {
const chat = this.doc.getYChat(userSubmittedMessage.chatID);
if (!chat) {
throw new Error('Chat not found');
}
const message: Message = {
id: uuidv4(),
parentID: userSubmittedMessage.parentID,
chatID: userSubmittedMessage.chatID,
timestamp: Date.now(),
role: 'user',
content: userSubmittedMessage.content,
done: true,
};
this.doc.addMessage(message);
const messages: Message[] = this.doc.getMessagesPrecedingMessage(message.chatID, message.id);
messages.push(message);
await this.getReply(messages, userSubmittedMessage.requestedParameters);
}
public async regenerate(message: Message, requestedParameters: Parameters) {
const messages = this.doc.getMessagesPrecedingMessage(message.chatID, message.id);
await this.getReply(messages, requestedParameters);
}
private async getReply(messages: Message[], requestedParameters: Parameters) {
const latestMessage = messages[messages.length - 1];
const chatID = latestMessage.chatID;
const parentID = latestMessage.id;
const chat = this.doc.getYChat(latestMessage.chatID);
if (!chat) {
throw new Error('Chat not found');
}
const message: Message = {
id: uuidv4(),
parentID,
chatID,
timestamp: Date.now(),
role: 'assistant',
model: requestedParameters.model,
content: '',
};
this.lastReplyID = message.id;
this.doc.addMessage(message);
const request = new ReplyRequest(this.get(chatID), chat, messages, message.id, requestedParameters, this.options);
request.on('done', () => this.activeReplies.delete(message.id));
request.execute();
this.activeReplies.set(message.id, request);
}
public cancelReply(chatID: string | undefined, id: string) {
this.activeReplies.get(id)?.onCancel();
this.activeReplies.delete(id);
}
public async createChat(id?: string): Promise<string> {
return this.doc.createYChat(id);
}
public get(id: string): Chat {
return this.doc.getChat(id);
}
public has(id: string) {
return this.doc.has(id);
}
public all(): Chat[] {
return this.doc.getChatIDs().map(id => this.get(id));
}
public deleteChat(id: string, broadcast = true) {
this.doc.delete(id);
this.search.delete(id);
}
public searchChats(query: string) {
return this.search.query(query);
}
public getPluginOptions(chatID?: string) {
const pluginOptions: Record<string, Record<string, any>> = {};
for (const description of pluginMetadata) {
pluginOptions[description.id] = this.options.getAllOptions(description.id, chatID);
}
return pluginOptions;
}
public setPluginOption(pluginID: string, optionID: string, value: any, chatID?: string) {
this.options.setOption(pluginID, optionID, value, chatID);
}
public resetPluginOptions(pluginID: string, chatID?: string | null) {
this.options.resetOptions(pluginID, chatID);
}
public getQuickSettings(): Array<{ groupID: string, option: Option }> {
const options = this.options.getAllOptions('quick-settings');
return Object.keys(options)
.filter(key => options[key])
.map(key => {
const groupID = key.split('--')[0];
const optionID = key.split('--')[1];
return {
groupID,
option: this.options.findOption(groupID, optionID)!,
};
})
.filter(o => !!o.option);
}
}

View File

@@ -0,0 +1,209 @@
import { EventEmitter } from "events";
import { PluginDescription } from "../plugins/plugin-description";
import { Option } from "./option";
import { YChat, YChatDoc } from "../chat/y-chat";
import { globalOptions } from "../../global-options";
import { OptionGroup } from "./option-group";
import { BroadcastChannel } from "broadcast-channel";
export const broadcastChannel = new BroadcastChannel("options");
function cacheKey(groupID: string, optionID: string, chatID?: string | null) {
return chatID ? `${chatID}.${groupID}.${optionID}` : `${groupID}.${optionID}`;
}
export class OptionsManager extends EventEmitter {
private optionGroups: OptionGroup[];
private optionsCache: Map<string, any> = new Map();
constructor(private yDoc: YChatDoc, private pluginMetadata: PluginDescription[]) {
super();
this.optionGroups = [...globalOptions, ...this.pluginMetadata];
// Load options from localStorage and YChats
this.loadOptions();
// Listen for update events on the broadcast channel
broadcastChannel.onmessage = (event: MessageEvent) => {
this.loadOptions();
if (event.data?.groupID) {
this.emit('update', event.data.groupID);
}
};
}
private loadOption(groupID: string, option: Option, yChat?: YChat) {
if (option.scope === "chat") {
const key: string = cacheKey(groupID, option.id, yChat?.id);
let value: string | undefined | null;
if (yChat) {
value = yChat.getOption(groupID, option.id);
}
// Fallback to localStorage if value is not found in YChat
if (typeof value === 'undefined' || value === null) {
const fallbackKey = cacheKey(groupID, option.id);
const raw = localStorage.getItem(fallbackKey);
value = raw ? JSON.parse(raw) : option.defaultValue;
}
this.optionsCache.set(key, value);
} else if (option.scope === "user") {
const key = cacheKey(groupID, option.id);
const value = this.yDoc.getOption(groupID, option.id) || option.defaultValue;
this.optionsCache.set(key, value);
} else {
const key = cacheKey(groupID, option.id);
const raw = localStorage.getItem(key);
const value = raw ? JSON.parse(raw) : option.defaultValue;
this.optionsCache.set(key, value);
}
}
private loadOptions() {
// Load browser and user-scoped options
this.optionGroups.forEach(group => {
group.options.forEach(option => {
this.loadOption(group.id, option);
});
});
// Load chat-scoped options from YChats
this.yDoc.getChatIDs().forEach(chatID => {
const yChat = this.yDoc.getYChat(chatID)!;
this.optionGroups.forEach(group => {
group.options.forEach(option => {
if (option.scope === "chat") {
this.loadOption(group.id, option, yChat);
}
});
});
});
(window as any).options = this;
this.emit("update");
}
public resetOptions(groupID: string, chatID?: string | null) {
console.log(`resetting ${groupID} options with chatID = ${chatID}`);
const group = this.optionGroups.find(group => group.id === groupID);
group?.options.forEach(option => {
if (option.resettable) {
this.setOption(group.id, option.id, option.defaultValue, option.scope === 'chat' ? chatID : null);
}
});
}
public getAllOptions(groupID: string, chatID?: string | null): Record<string, any> {
const options: Record<string, any> = {};
const group = this.optionGroups.find(group => group.id === groupID);
group?.options.forEach(option => {
options[option.id] = this.getOption(groupID, option.id, chatID);
});
return options;
}
public getOption<T=any>(groupID: string, optionID: string, chatID?: string | null, validate = false): T {
const option = this.findOption(groupID, optionID);
if (!option) {
throw new Error(`option not found (group = ${groupID}), option = ${optionID}`);
}
const key = cacheKey(groupID, optionID, option.scope === 'chat' ? chatID : null);
let value = this.optionsCache.get(key);
if (typeof value !== 'undefined' && value !== null) {
if (validate) {
const valid = !option.validate || option.validate(value, this);
if (valid) {
return value;
}
} else {
return value;
}
}
const fallbackKey = cacheKey(groupID, optionID);
value = this.optionsCache.get(fallbackKey);
if (typeof value !== 'undefined' && value !== null) {
if (validate) {
const valid = !option.validate || option.validate(value, this);
if (valid) {
return value;
}
} else {
return value;
}
}
return option.defaultValue;
}
public getValidatedOption(groupID: string, optionID: string, chatID?: string | null): any {
return this.getOption(groupID, optionID, chatID, true);
}
public setOption(groupID: string, optionID: string, value: any, chatID?: string | null) {
const option = this.findOption(groupID, optionID);
if (!option) {
console.warn(`option not found (group = ${groupID}), option = ${optionID}`);
return;
}
const key = cacheKey(groupID, optionID, option.scope === 'chat' ? chatID : null);
value = value ?? null;
if (option.scope === "chat") {
if (!chatID) {
console.warn(`cannot set option for chat without chatID (group = ${groupID}), option = ${optionID}, chatID = ${chatID}`);
return;
}
const yChat = this.yDoc.getYChat(chatID);
yChat?.setOption(groupID, optionID, value);
const fallbackKey = cacheKey(groupID, optionID);
localStorage.setItem(fallbackKey, JSON.stringify(value));
} else if (option.scope === 'user') {
this.yDoc.setOption(groupID, optionID, value);
} else {
localStorage.setItem(key, JSON.stringify(value));
}
console.log(`setting ${groupID}.${optionID} = ${value} (${typeof value})`)
// Update cache and emit update event
this.optionsCache.set(key, value);
this.emit("update", groupID, optionID);
// Notify other tabs through the broadcast channel
broadcastChannel.postMessage({ groupID, optionID });
}
public findOption(groupID: string, optionID: string): Option | undefined {
const group = this.optionGroups.find(group => group.id === groupID);
const option = group?.options.find(option => option.id === optionID);
if (option) {
return option;
}
console.warn("couldn't find option " + groupID + "." + optionID);
return undefined;
}
public destroy() {
this.removeAllListeners();
broadcastChannel.onmessage = null;
}
}

View File

@@ -0,0 +1,20 @@
import { Option } from "./option";
import type { OptionsManager } from ".";
import { ReactNode } from "react";
/**
* @interface OptionGroup
* @description Represents a group of options within the OptionsManager. Each group is identified by a unique ID and can have a name, description, and a set of options. The group can be hidden based on a boolean value or a function that evaluates the visibility condition using the OptionsManager instance.
* @property {string} id - The unique identifier for the option group.
* @property {string} [name] - The display name for the option group.
* @property {string | ReactNode} [description] - A description for the option group, which can be a string or a ReactNode.
* @property {boolean | ((options: OptionsManager) => boolean)} [hidden] - Determines if the option group should be hidden. Can be a boolean value or a function that returns a boolean value based on the OptionsManager instance.
* @property {Option[]} options - An array of options within the group.
*/
export interface OptionGroup {
id: string;
name?: string;
description?: string | ReactNode;
hidden?: boolean | ((options: OptionsManager) => boolean);
options: Option[];
}

View File

@@ -0,0 +1,34 @@
import type { OptionsManager } from ".";
import { Context } from "../context";
import { RenderProps, RenderPropsBuilder } from "./render-props";
/**
* Represents an option in the settings UI.
* @typedef {Object} Option
* @property {string} id - The unique identifier for the option.
* @property {any} defaultValue - The default value for the option.
* @property {'speech' | 'chat' | 'user'} tab - The tab ID in the settings UI where the option will be displayed.
* @property {boolean} [resettable] - Whether the option can be reset to its default value.
* @property {'chat' | 'user' | 'browser'} [scope] - Determines how the option value is saved (browser = local storage, user = synced to the user's account across devices, chat = saved for specific chat).
* @property {boolean} [displayAsSeparateSection] - Whether the option should be displayed inline in the settings UI or as a 'block' with a heading and separate section.
* @property {RenderProps | RenderPropsBuilder} renderProps - Customizes the appearance of the option's UI in the settings UI, and can see other options and app state.
* @property {(value: any, options: OptionsManager) => boolean} [validate] - If this function returns false, the defaultValue will be used instead.
*/
export interface Option {
id: string;
defaultValue: any;
scope?: 'chat' | 'user' | 'browser';
displayOnSettingsScreen: 'speech' | 'chat' | 'plugins' | 'ui' | 'user';
displayAsSeparateSection?: boolean;
resettable?: boolean;
renderProps: RenderProps | RenderPropsBuilder;
validate?: (value: any, options: OptionsManager) => boolean;
displayInQuickSettings?: {
name: string;
displayByDefault?: boolean;
label: string | ((value: any, options: OptionsManager, context: Context) => string);
};
}

View File

@@ -0,0 +1,41 @@
import type { OptionsManager } from ".";
import type { Context } from "../context";
/**
* Represents the properties used to render an option in the settings UI.
* @typedef {Object} RenderProps
* @property {'text' | 'textarea' | 'select' | 'number' | 'slider' | 'checkbox'} type - The type of input for the option.
* @property {any} [label] - The label for the option.
* @property {any} [description] - The description for the option.
* @property {any} [placeholder] - The placeholder for the option.
* @property {boolean} [disabled] - Whether the option is disabled in the settings UI.
* @property {boolean} [hidden] - Whether the option is hidden in the settings UI.
* @property {number} [step] - The step value for number and slider inputs.
* @property {number} [min] - The minimum value for number and slider inputs.
* @property {number} [max] - The maximum value for number and slider inputs.
* @property {Array<{ label: string; value: string; }>} [options] - The options for the select input.
*/
export interface RenderProps {
type: 'text' | 'textarea' | 'select' | 'number' | 'slider' | 'checkbox' | 'password';
label?: any;
description?: any;
placeholder?: any;
disabled?: boolean;
hidden?: boolean;
// Number and slider input properties
step?: number;
min?: number;
max?: number;
// Select input options property
options?: Array<{ label: string; value: string; }>;
}
/**
* Represents a function that builds RenderProps based on the current value, options, and context.
* @typedef {(value: any, options: OptionsManager, context: Context) => RenderProps} RenderPropsBuilder
*/
export type RenderPropsBuilder = ((value: any, options: OptionsManager, context: Context) => RenderProps);

View File

@@ -0,0 +1,45 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { Context, useAppContext } from "../context";
import { RenderProps } from "./render-props";
export function useOption<T=any>(groupID: string, optionID: string, chatID?: string): [T, (value: T) => void, RenderProps, number] {
const context = useAppContext();
const [value, setValue] = useState(context.chat.options.getValidatedOption(groupID, optionID, chatID));
const [version, setVersion] = useState(0);
const timer = useRef<any>();
const onUpdate = useCallback((updatedGroupID: string) => {
if (groupID === updatedGroupID) {
setValue(context.chat.options.getValidatedOption(groupID, optionID, chatID));
setVersion(v => v + 1);
} else {
clearTimeout(timer.current);
timer.current = setTimeout(() => {
setValue(context.chat.options.getValidatedOption(groupID, optionID, chatID));
setVersion(v => v + 1);
}, 500);
}
}, [groupID, optionID, chatID]);
useEffect(() => {
context.chat.on('plugin-options-update', onUpdate);
return () => {
context.chat.off('plugin-options-update', onUpdate);
};
}, [chatID, onUpdate]);
const setOptionValue = useCallback((value: any) => {
context.chat.options.setOption(groupID, optionID, value, chatID);
}, [groupID, optionID, chatID]);
const option = context.chat.options.findOption(groupID, optionID)!;
return [
value,
setOptionValue,
typeof option.renderProps === 'function' ? option.renderProps(value, context.chat.options, context) : option.renderProps,
version,
];
}

View File

@@ -0,0 +1,10 @@
import { Context } from "../context";
import { OptionsManager } from "../options";
export interface Command {
name: string;
params: Array<{ name: string, type: string }>
returnType: string;
run: any;
disabled?: (options: OptionsManager, context: Context) => boolean;
}

View File

@@ -0,0 +1,30 @@
import { OpenAIMessage, Parameters } from "../chat/types";
import { PluginContext } from "./plugin-context";
import { PluginDescription } from "./plugin-description";
export default class Plugin<T=any> {
constructor(public context?: PluginContext) {
}
async initialize() {
}
describe(): PluginDescription {
throw new Error('not implemented');
}
get options(): T | undefined {
return this.context?.getOptions();
}
async preprocessModelInput(messages: OpenAIMessage[], parameters: Parameters): Promise<{
messages: OpenAIMessage[],
parameters: Parameters,
}> {
return { messages, parameters };
}
async postprocessModelOutput(message: OpenAIMessage, context: OpenAIMessage[], parameters: Parameters, done: boolean): Promise<OpenAIMessage> {
return message;
}
}

View File

@@ -0,0 +1,15 @@
import type { PluginDescription } from "./plugin-description";
import TTSPlugin from "../tts/tts-plugin";
import { registeredPlugins } from "../../plugins";
export const pluginMetadata: Array<PluginDescription> = registeredPlugins.map(p => new p().describe());
export const pluginIDs: string[] = pluginMetadata.map(d => d.id);
export const ttsPlugins = registeredPlugins.filter(p => {
const instance = new p();
return instance instanceof TTSPlugin;
});
export function getPluginByName(name: string) {
return registeredPlugins.find(p => new p().describe().name === name);
}

View File

@@ -0,0 +1,18 @@
import { Chat, OpenAIMessage, Parameters } from "../chat/types";
import { OptionsManager } from "../options";
export interface PluginContext {
getOptions(): any;
getCurrentChat(): Chat;
createChatCompletion(messages: OpenAIMessage[], parameters: Parameters): Promise<string>;
setChatTitle(title: string): Promise<void>;
}
export function createBasicPluginContext(pluginID: string, pluginOptions: OptionsManager, chatID?: string | null, chat?: Chat | null) {
return {
getOptions: (_pluginID = pluginID) => pluginOptions.getAllOptions(_pluginID, chatID),
getCurrentChat: () => chat,
createChatCompletion: async () => '',
setChatTitle: async (title: string) => { },
} as PluginContext;
}

View File

@@ -0,0 +1,9 @@
import type { Command } from "./command";
import type { OptionGroup } from "../options/option-group";
export interface PluginDescription extends OptionGroup {
name: string;
commands?: Command[];
category?: "internal" | "knowledge-sources" | "tts";
}

View File

@@ -0,0 +1,24 @@
import type { PluginContext } from "./plugin-context";
import type Plugin from ".";
import { pluginMetadata } from "./metadata";
import { registeredPlugins } from "../../plugins";
export async function pluginRunner(name: string, pluginContext: (pluginID: string) => PluginContext, callback: (p: Plugin<any>) => Promise<any>) {
const startTime = Date.now();
for (let i = 0; i < registeredPlugins.length; i++) {
const description = pluginMetadata[i];
const impl = registeredPlugins[i];
const plugin = new impl(pluginContext(description.id));
try {
await callback(plugin);
} catch (e) {
console.warn(`[plugins:${name}] error in ` + description.name, e);
}
}
const runtime = Date.now() - startTime;
// console.log(`[plugins:${name}] ran all plugins in ${runtime.toFixed(1)} ms`);
}

90
app/src/core/search.ts Normal file
View File

@@ -0,0 +1,90 @@
import MiniSearch, { SearchResult } from 'minisearch'
import { ellipsize } from './utils';
import { ChatManager } from '.';
import { Chat, Message } from './chat/types';
export class Search {
private index = new MiniSearch({
fields: ['value'],
storeFields: ['id', 'value'],
});
constructor(private context: ChatManager) {
}
public update(id: string) {
const chat = this.context.get(id);
if (!chat) {
return;
}
const messages = chat.messages.serialize();
const contents = messages.map((m: Message) => m.content).join('\n\n');
const doc = {
id,
value: chat.title ? (chat.title + '\n\n' + contents) : contents,
};
if (!this.index.has(id)) {
this.index.add(doc);
} else {
this.index.replace(doc);
}
}
public delete(id: string) {
if (this.index.has(id)) {
this.index.discard(id);
this.index.vacuum();
}
}
public query(query: string) {
if (!query?.trim().length) {
const searchResults = this.context.all()
.sort((a, b) => b.updated - a.updated)
.slice(0, 10);
const results = this.processSearchResults(searchResults);
return results;
}
let searchResults = this.index.search(query, { fuzzy: 0.2 });
let output = this.processSearchResults(searchResults);
if (!output.length) {
searchResults = this.index.search(query, { prefix: true });
output = this.processSearchResults(searchResults);
}
return output;
}
private processSearchResults(searchResults: SearchResult[] | Chat[]) {
const output: any[] = [];
for (const item of searchResults) {
const chatID = item.id;
let chat = this.context.get(chatID);
if (!chat) {
continue;
}
chat = { ...chat };
let description = chat.messages?.first?.content || '';
description = ellipsize(description, 400);
if (!chat.title) {
chat.title = ellipsize(description, 100);
}
if (!chat.title || !description) {
continue;
}
output.push({
chatID,
title: chat.title,
description,
});
}
return output;
}
}

View File

@@ -0,0 +1,138 @@
declare global {
interface Window {
SpeechRecognition: SpeechRecognition
}
interface SpeechGrammar {
src: string
weight: number
}
const SpeechGrammar: {
prototype: SpeechGrammar
new(): SpeechGrammar
}
interface SpeechGrammarList {
readonly length: number
addFromString(string: string, weight?: number): void
addFromURI(src: string, weight?: number): void
item(index: number): SpeechGrammar
[index: number]: SpeechGrammar
}
const SpeechGrammarList: {
prototype: SpeechGrammarList
new(): SpeechGrammarList
}
interface SpeechRecognitionEventMap {
audioend: Event
audiostart: Event
end: Event
error: SpeechRecognitionError
nomatch: SpeechRecognitionEvent
result: SpeechRecognitionEvent
soundend: Event
soundstart: Event
speechend: Event
speechstart: Event
start: Event
}
interface SpeechRecognition {
continuous: boolean
grammars: SpeechGrammarList
interimResults: boolean
lang: string
maxAlternatives: number
onaudioend: ((this: SpeechRecognition, ev: Event) => any) | null
onaudiostart: ((this: SpeechRecognition, ev: Event) => any) | null
onend: ((this: SpeechRecognition, ev: Event) => any) | null
onerror:
| ((this: SpeechRecognition, ev: SpeechRecognitionError) => any)
| null
onnomatch:
| ((this: SpeechRecognition, ev: SpeechRecognitionEvent) => any)
| null
onresult:
| ((this: SpeechRecognition, ev: SpeechRecognitionEvent) => any)
| null
onsoundend: ((this: SpeechRecognition, ev: Event) => any) | null
onsoundstart: ((this: SpeechRecognition, ev: Event) => any) | null
onspeechend: ((this: SpeechRecognition, ev: Event) => any) | null
onspeechstart: ((this: SpeechRecognition, ev: Event) => any) | null
onstart: ((this: SpeechRecognition, ev: Event) => any) | null
serviceURI: string
abort(): void
start(): void
stop(): void
addEventListener<K extends keyof SpeechRecognitionEventMap>(
type: K,
listener: (
this: SpeechRecognition,
ev: SpeechRecognitionEventMap[K]
) => any,
options?: boolean | AddEventListenerOptions
): void
addEventListener(
type: string,
listener: EventListenerOrEventListenerObject,
options?: boolean | AddEventListenerOptions
): void
removeEventListener<K extends keyof SpeechRecognitionEventMap>(
type: K,
listener: (
this: SpeechRecognition,
ev: SpeechRecognitionEventMap[K]
) => any,
options?: boolean | EventListenerOptions
): void
removeEventListener(
type: string,
listener: EventListenerOrEventListenerObject,
options?: boolean | EventListenerOptions
): void
}
const SpeechRecognition: {
prototype: SpeechRecognition
new(): SpeechRecognition
}
interface SpeechRecognitionError extends Event {
// readonly error: SpeechRecognitionErrorCode;
readonly message: string
}
const SpeechRecognitionError: {
prototype: SpeechRecognitionError
new(): SpeechRecognitionError
}
interface SpeechRecognitionEvent extends Event {
readonly emma: Document | null
readonly interpretation: any
readonly resultIndex: number
readonly results: SpeechRecognitionResultList
}
const SpeechRecognitionEvent: {
prototype: SpeechRecognitionEvent
new(): SpeechRecognitionEvent
}
}
let speechRecognition: SpeechRecognition | null = null;
if (window.SpeechRecognition) {
speechRecognition = new SpeechRecognition()
} else if ((window as any).webkitSpeechRecognition) {
speechRecognition = new (window as any).webkitSpeechRecognition() as SpeechRecognition;
}
const supportsSpeechRecognition = speechRecognition !== null;
export {
speechRecognition,
supportsSpeechRecognition,
}

View File

@@ -0,0 +1,241 @@
const MAX_NUM_THREADS = 128;
type MergeRange = { start: number, end: number };
export class RankMap {
private values = new Map<string, number>();
public static from(texts: string[]) {
const map = new RankMap();
for (let i = 0; i < texts.length; i++) {
map.values.set(texts[i], i);
}
return map;
}
public set(bytes: Uint8Array, rank: number) {
const key = Buffer.from(bytes).toString();
this.values.set(key, rank);
}
public get(bytes: Uint8Array) {
const key = Buffer.from(bytes).toString();
return this.values.get(key);
}
public keys() {
return Array.from(this.values.keys()).map(k => Buffer.from(k));
}
public inverted() {
const inverted = new Map<number, Uint8Array>();
for (const [key, value] of Array.from(this.values.entries())) {
inverted.set(value, new Uint8Array(Buffer.from(key)));
}
return inverted;
}
}
function bytePairMerge(piece: Uint8Array, ranks: RankMap): MergeRange[] {
let parts: MergeRange[] = Array.from({ length: piece.length }, (_, i) => ({ start: i, end: i + 1 }));
while (true) {
if (parts.length === 1) {
break;
}
let minRank: [number, number] | null = null;
for (let i = 0; i < parts.length - 1; i++) {
const rank = ranks.get(piece.slice(parts[i].start, parts[i + 1].end));
if (rank === undefined) {
continue;
}
if (minRank === null || rank < minRank[0]) {
minRank = [rank, i];
}
}
if (minRank !== null) {
const [_, i] = minRank;
parts[i] = { start: parts[i].start, end: parts[i + 1].end };
parts.splice(i + 1, 1);
} else {
break;
}
}
return parts;
}
function bytePairEncode(piece: Uint8Array, ranks: RankMap): number[] {
if (piece.length === 1) {
return [ranks.get(piece)!];
}
return bytePairMerge(piece, ranks).map((p) => ranks.get(piece.slice(p.start, p.end))!);
}
function bytePairSplit(piece: Uint8Array, ranks: RankMap): Uint8Array[] {
if (piece.length === 1) {
return [piece];
}
return bytePairMerge(piece, ranks).map((p) => piece.slice(p.start, p.end));
}
export class CoreBPE {
encoder: RankMap;
specialTokensEncoder: Map<string, number>;
decoder: Map<number, Uint8Array>;
specialTokensDecoder: Map<number, Uint8Array>;
regexTls: RegExp[];
specialRegexTls: RegExp[];
sortedTokenBytes: Uint8Array[];
constructor(
encoder: RankMap,
specialTokensEncoder: Map<string, number>,
regex: RegExp
) {
const specialRegex = new RegExp(
Array.from(specialTokensEncoder.keys())
.map((s) => s.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&"))
.join("|")
);
const decoder: Map<number, Uint8Array> = encoder.inverted();
const specialTokensDecoder: Map<number, Uint8Array> = new Map(
Array.from(specialTokensEncoder.entries()).map(([k, v]) => [v, new Uint8Array(Buffer.from(k))])
);
const sortedTokenBytes: Uint8Array[] = Array.from(encoder.keys());
sortedTokenBytes.sort((a, b) => Buffer.compare(a, b));
this.encoder = encoder;
this.specialTokensEncoder = specialTokensEncoder;
this.decoder = decoder;
this.specialTokensDecoder = specialTokensDecoder;
this.regexTls = Array(MAX_NUM_THREADS).fill(regex);
this.specialRegexTls = Array(MAX_NUM_THREADS).fill(specialRegex);
this.sortedTokenBytes = sortedTokenBytes;
}
private _getTlRegex(): RegExp {
return this.regexTls[Math.floor(Math.random() * MAX_NUM_THREADS)];
}
private _getTlSpecialRegex(): RegExp {
return this.specialRegexTls[Math.floor(Math.random() * MAX_NUM_THREADS)];
}
private _decodeNative(tokens: number[]): Uint8Array {
const ret: number[] = [];
for (const token of tokens) {
const tokenBytes = this.decoder.get(token) || this.specialTokensDecoder.get(token)!;
ret.push(...Array.from(tokenBytes));
}
return new Uint8Array(ret);
}
private _encodeOrdinaryNative(text: string): number[] {
const regex = this._getTlRegex();
const ret: number[] = [];
let match: RegExpExecArray | null;
while ((match = regex.exec(text)) !== null) {
const piece = new Uint8Array(Buffer.from(match[0]));
const token = this.encoder.get(piece);
if (token !== undefined) {
ret.push(token);
continue;
}
ret.push(...bytePairEncode(piece, this.encoder));
}
return ret;
}
private _encodeNative(text: string, allowedSpecial: Set<string>): [number[], number] {
const specialRegex = this._getTlSpecialRegex();
const regex = this._getTlRegex();
const ret: number[] = [];
let start = 0;
let lastPieceTokenLen = 0;
while (true) {
let nextSpecial: RegExpExecArray | null;
let startFind = start;
while (true) {
nextSpecial = specialRegex.exec(text.slice(startFind));
if (nextSpecial === null || allowedSpecial.has(nextSpecial[0])) {
break;
}
startFind = nextSpecial.index + 1;
}
const end = nextSpecial === null ? text.length : nextSpecial.index;
let match: RegExpExecArray | null;
while ((match = regex.exec(text.slice(start, end))) !== null) {
const piece = new Uint8Array(Buffer.from(match[0]));
const token = this.encoder.get(piece);
if (token !== undefined) {
lastPieceTokenLen = 1;
ret.push(token);
continue;
}
const tokens = bytePairEncode(piece, this.encoder);
lastPieceTokenLen = tokens.length;
ret.push(...tokens);
}
if (nextSpecial === null) {
break;
}
const piece = nextSpecial[0];
const token = this.specialTokensEncoder.get(piece)!;
ret.push(token);
start = nextSpecial.index + piece.length;
lastPieceTokenLen = 0;
}
return [ret, lastPieceTokenLen];
}
encodeOrdinary(text: string): number[] {
return this._encodeOrdinaryNative(text);
}
encode(text: string, allowedSpecial: Set<string>): number[] {
return this._encodeNative(text, allowedSpecial)[0];
}
encodeWithUnstable(text: string, allowedSpecial: Set<string>): [number[], Set<number[]>] {
throw new Error("Not implemented");
}
encodeSingleToken(piece: Uint8Array): number {
const token = this.encoder.get(piece);
if (token !== undefined) {
return token;
}
const pieceStr = Buffer.from(piece).toString("utf-8");
if (this.specialTokensEncoder.has(pieceStr)) {
return this.specialTokensEncoder.get(pieceStr)!;
}
throw new Error("Key not found");
}
encodeSinglePiece(piece: Uint8Array): number[] {
const token = this.encoder.get(piece);
if (token !== undefined) {
return [token];
}
return bytePairEncode(piece, this.encoder);
}
decodeBytes(tokens: number[]): Uint8Array {
return this._decodeNative(tokens);
}
decodeSingleTokenBytes(token: number): Uint8Array {
const bytes = this.decoder.get(token) || this.specialTokensDecoder.get(token);
if (bytes !== undefined) {
return bytes;
}
throw new Error("Key not found");
}
tokenByteValues(): Uint8Array[] {
return this.sortedTokenBytes;
}
}

View File

@@ -0,0 +1,141 @@
import { OpenAIMessage } from '../chat/types';
import * as tokenizer from '.';
export interface ChatHistoryTrimmerOptions {
maxTokens: number,
nMostRecentMessages?: number,
preserveSystemPrompt: boolean,
preserveFirstUserMessage: boolean,
}
export class ChatHistoryTrimmer {
private output: OpenAIMessage[] = [];
constructor(private messages: OpenAIMessage[],
private readonly options: ChatHistoryTrimmerOptions) {
}
private countExcessTokens() {
return Math.max(0, tokenizer.countTokensForMessages(this.output) - this.options.maxTokens);
}
public process() {
this.output = this.messages.map(m => ({ ...m }));
if (this.options.nMostRecentMessages) {
this.output = this.removeUnwantedMessagesStrategy();
}
let excessTokens = this.countExcessTokens();
if (excessTokens === 0) {
return this.output;
}
this.output = this.removeMessagesStrategy();
excessTokens = this.countExcessTokens();
if (excessTokens === 0) {
return this.output;
}
this.output = this.trimMessagesStrategy(excessTokens);
excessTokens = this.countExcessTokens();
if (excessTokens === 0) {
return this.output;
}
const systemPrompt = this.messages.find(m => m.role === 'system');
const firstUserMessage = this.messages.find(m => m.role === 'user');
const last = this.messages[this.messages.length - 1];
this.output = [
tokenizer.truncateMessage(systemPrompt!, 100),
];
if (firstUserMessage === last) {
this.output.push(tokenizer.truncateMessage(firstUserMessage, this.options.maxTokens - 100));
} else {
this.output.push(tokenizer.truncateMessage(firstUserMessage!, 100));
this.output.push(tokenizer.truncateMessage(last, this.options.maxTokens - 200));
}
excessTokens = this.countExcessTokens();
if (excessTokens === 0) {
return this.output;
}
this.output = [
tokenizer.truncateMessage(last, this.options.maxTokens),
];
return this.output;
}
private removeUnwantedMessagesStrategy() {
const systemPromptIndex = this.messages.findIndex(m => m.role === 'system');
const firstUserMessageIndex = this.messages.findIndex(m => m.role === 'user');
const keepFromIndex = this.messages.length - (this.options.nMostRecentMessages || 1);
const output: OpenAIMessage[] = [];
for (let i = 0; i < this.output.length; i++) {
if (i === systemPromptIndex && this.options.preserveSystemPrompt) {
output.push(this.output[i]);
} else if (i === firstUserMessageIndex && this.options.preserveFirstUserMessage) {
output.push(this.output[i]);
} else if (i >= keepFromIndex) {
output.push(this.output[i]);
}
}
return output;
}
private removeMessagesStrategy() {
const systemPromptIndex = this.messages.findIndex(m => m.role === 'system');
const firstUserMessageIndex = this.messages.findIndex(m => m.role === 'user');
const lastMessageIndex = this.messages.length - 1;
const output: OpenAIMessage[] = [...this.output];
for (let i = 0; i < this.output.length && tokenizer.countTokensForMessages(output) > this.options.maxTokens; i++) {
if (i === lastMessageIndex) {
continue;
}
if (i === systemPromptIndex || this.options.preserveSystemPrompt) {
continue;
}
if (i === firstUserMessageIndex || this.options.preserveFirstUserMessage) {
continue;
}
output[i].content = '';
}
return output.filter(m => m.content.length > 0);
}
private trimMessagesStrategy(excessTokens: number) {
const systemPromptIndex = this.output.findIndex(m => m.role === 'system');
const firstUserMessageIndex = this.output.findIndex(m => m.role === 'user');
const lastMessageIndex = this.output.length - 1;
const output: OpenAIMessage[] = [...this.output];
const truncateLength = Math.floor(excessTokens / this.output.length);
for (let i = 0; i < this.output.length && tokenizer.countTokensForMessages(output) > this.options.maxTokens; i++) {
if (i === lastMessageIndex) {
continue;
}
if (i === systemPromptIndex && this.options.preserveSystemPrompt) {
continue;
}
if (i === firstUserMessageIndex && this.options.preserveFirstUserMessage) {
continue;
}
output[i] = tokenizer.truncateMessage(output[i], truncateLength);
}
return output.filter(m => m.content.length > 0);
}
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,65 @@
import { OpenAIMessage } from "../chat/types";
import { CoreBPE, RankMap } from "./bpe";
import ranks from './cl100k_base.json';
const special_tokens: any = {
"<|endoftext|>": 100257,
"<|fim_prefix|>": 100258,
"<|fim_middle|>": 100259,
"<|fim_suffix|>": 100260,
"<|endofprompt|>": 100276,
};
const special_tokens_map = new Map<string, number>();
for (const text of Object.keys(special_tokens)) {
special_tokens_map.set(text, special_tokens_map[text]);
}
const pattern = /('s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/giu;
const tokenizer = new CoreBPE(RankMap.from(ranks), special_tokens_map, pattern);
const overheadTokens = {
perMessage: 5,
perPrompt: 2,
}
const tokenCache = new Map<string, number>();
export function countTokensForText(text: string) {
const cacheKey = text;
if (tokenCache.has(cacheKey)) {
return tokenCache.get(cacheKey)!;
}
let t1 = Date.now();
const tokens = tokenizer.encodeOrdinary(text).length;
tokenCache.set(cacheKey, tokens);
return tokens;
}
export function countTokensForMessage(message: OpenAIMessage) {
return countTokensForText(message.content) + overheadTokens.perMessage;
}
export function countTokensForMessages(messages: OpenAIMessage[]) {
let tokens = overheadTokens.perPrompt;
for (const m of messages) {
tokens += countTokensForMessage(m);
}
return tokens;
}
export function truncateText(text: string, tokens: number) {
const encoded = tokenizer.encodeOrdinary(text);
const decoded = tokenizer.decodeBytes(encoded.slice(0, Math.max(0, tokens)));
return new TextDecoder().decode(decoded);
}
export function truncateMessage(message: OpenAIMessage, tokens: number) {
const encoded = tokenizer.encodeOrdinary(message.content);
const decoded = tokenizer.decodeBytes(encoded.slice(0, Math.max(0, tokens - overheadTokens.perMessage)));
return {
role: message.role,
content: new TextDecoder().decode(decoded),
};
}

View File

@@ -0,0 +1,16 @@
import * as methods from ".";
import { OpenAIMessage } from "../chat/types";
import { ChatHistoryTrimmer, ChatHistoryTrimmerOptions } from "./chat-history-trimmer";
export function runChatTrimmer(messages: OpenAIMessage[], options: ChatHistoryTrimmerOptions) {
const trimmer = new ChatHistoryTrimmer(messages, options);
return trimmer.process();
}
export function countTokensForText(text: string) {
return methods.countTokensForText(text);
}
export function countTokensForMessages(messages: OpenAIMessage[]) {
return methods.countTokensForMessages(messages);
}

View File

@@ -0,0 +1,28 @@
import { OpenAIMessage } from "../chat/types";
import type { ChatHistoryTrimmerOptions } from "./chat-history-trimmer";
// @ts-ignore
import tokenizer from 'workerize-loader!./worker';
let worker: any;
async function getWorker() {
if (!worker) {
worker = await tokenizer();
}
return worker;
}
export async function runChatTrimmer(messages: OpenAIMessage[], options: ChatHistoryTrimmerOptions): Promise<OpenAIMessage[]> {
const worker = await getWorker();
return worker.runChatTrimmer(messages, options);
}
export async function countTokens(messages: OpenAIMessage[]) {
const worker = await getWorker();
return await worker.countTokensForMessages(messages);
}
// preload the worker
getWorker().then(w => {
(window as any).worker = w;
})

View File

@@ -0,0 +1,286 @@
import { AbstractTTSPlayer, TTSPlayerState } from './types';
import { cloneArrayBuffer, md5, sleep } from '../utils';
import { AsyncLoop } from "../utils/async-loop";
import * as idb from '../utils/idb';
import TTSPlugin from './tts-plugin';
export let audioContext = new AudioContext();
export let audioContextInUse = false;
export function resetAudioContext() {
if (audioContextInUse) {
const previousAudioContext = audioContext;
audioContext = new AudioContext();
audioContextInUse = false;
setTimeout(() => previousAudioContext.close(), 0);
setTimeout(() => audioContext.suspend(), 100);
}
audioContext.resume();
}
const cache = new Map<string, ArrayBuffer>();
async function getAudioFile(plugin: TTSPlugin<any>, text: string) {
const voice = await plugin.getCurrentVoice();
const hash = await md5(text);
const cacheKey = `audio:${voice?.service}:${voice?.id}:${hash}`;
let buffer = cache.get(cacheKey);
if (!buffer) {
buffer = await idb.get(cacheKey);
}
if (!buffer) {
try {
const result = await plugin.speakToBuffer(text);
if (result) {
buffer = result;
cache.set(cacheKey, cloneArrayBuffer(buffer));
idb.set(cacheKey, cloneArrayBuffer(buffer));
return buffer;
}
} catch (e) {
console.error(e);
}
}
return buffer || null;
}
export default class ExternalTTSAudioFilePlayer extends AbstractTTSPlayer {
private playing = true;
private ended = false;
private requestedSentenceIndex = 0; // sentence index requested by user
private currentSentenceIndex = 0;
private startTime = 0;
private audioArrayBuffers: ArrayBuffer[] = [];
private downloadLoop: AsyncLoop;
private schedulerLoop: AsyncLoop;
private sourceNodes: AudioBufferSourceNode[] = [];
private durations: number[] = [];
private duration = 0;
private destroyed = false;
constructor(private plugin: TTSPlugin) {
super();
this.downloadLoop = new AsyncLoop(this.download, 1000);
this.schedulerLoop = new AsyncLoop(this.schedule, 100);
audioContext.resume();
requestAnimationFrame(async () => {
audioContext.suspend();
this.downloadLoop.start();
this.schedulerLoop.start();
});
(window as any).player = this;
}
private download = async () => {
const sentences = [...this.sentences];
if (!this.complete) {
sentences.pop();
}
const maxSentencesToDownload = this.sourceNodes[this.currentSentenceIndex] ? 2 : 1;
const sentencesToDownload: number[] = [];
for (let i = 0; i < sentences.length; i++) {
if (sentencesToDownload.length >= maxSentencesToDownload) {
break;
}
if (!this.audioArrayBuffers[i]) {
sentencesToDownload.push(i);
}
}
const files = await Promise.all(sentencesToDownload.map(async sentenceIndex => {
try {
const text = sentences[sentenceIndex];
return await getAudioFile(this.plugin, text);
} catch (e) {
console.warn('error downloading tts audio', e);
}
}));
for (let i = 0; i < sentencesToDownload.length; i++) {
const sentenceIndex = sentencesToDownload[i];
const file = files[i];
if (file) {
this.audioArrayBuffers[sentenceIndex] = file;
} else {
await sleep(5000); // back off
}
}
this.emit('state', this.getState());
}
private schedule = async () => {
let time = this.startTime;
if (this.playing && this.sourceNodes[this.currentSentenceIndex] && audioContext.state === 'suspended') {
try {
await this.resumeAudioContext();
} catch (e: any) {
console.error(e);
}
}
try {
for (let i = this.requestedSentenceIndex; i < this.sentences.length; i++) {
if (this.destroyed) {
return;
}
const audioArrayBuffer = this.audioArrayBuffers[i];
if (!audioArrayBuffer) {
break;
}
if (!this.sourceNodes[i]) {
const audioBuffer = await audioContext.decodeAudioData(cloneArrayBuffer(audioArrayBuffer));
this.durations[i] = audioBuffer.duration;
const sourceNode = audioContext.createBufferSource();
sourceNode.buffer = audioBuffer;
if (i === this.requestedSentenceIndex) {
this.startTime = audioContext.currentTime;
time = this.startTime;
}
sourceNode.start(time);
this.duration = time + this.durations[i] - this.startTime;
audioContextInUse = true;
this.sourceNodes[i] = sourceNode;
sourceNode.connect(audioContext.destination);
if (this.playing) {
await this.resumeAudioContext();
}
sourceNode.onended = async () => {
if (this.destroyed) {
return;
}
this.currentSentenceIndex = i + 1;
this.ended = this.complete && this.currentSentenceIndex === this.sentences.length;
const isBuffering = !this.ended && !this.sourceNodes[this.currentSentenceIndex];
if (this.ended || isBuffering) {
await this.suspendAudioContext();
}
if (this.ended) {
this.playing = false;
}
this.emit('state', this.getState());
};
this.emit('state', this.getState());
}
time += this.durations[i] + 0.25;
}
} catch (e: any) {
console.error(e);
}
}
private async resumeAudioContext() {
try {
audioContext.resume();
await sleep(10);
} catch (e) {
console.warn('error resuming audio context', e);
}
}
private async suspendAudioContext() {
try {
await audioContext.suspend();
} catch (e) {
console.warn('error suspending audio context', e);
}
}
public getState(): TTSPlayerState {
return {
playing: this.playing,
ended: this.ended,
buffering: this.playing && !this.ended && !this.sourceNodes[this.currentSentenceIndex],
duration: this.duration,
length: this.sentences.length,
ready: this.audioArrayBuffers.filter(Boolean).length,
index: this.currentSentenceIndex,
downloadable: this.complete && this.sourceNodes.length === this.sentences.length,
} as any;
}
public async pause() {
this.playing = false;
await this.suspendAudioContext();
this.emit('state', this.getState());
}
public async play(index?: number) {
this.playing = true;
if (typeof index === 'number') {
this.requestedSentenceIndex = index;
this.currentSentenceIndex = index;
resetAudioContext();
if (this.sourceNodes.length) {
resetAudioContext();
this.sourceNodes = [];
this.durations = [];
this.duration = 0;
this.ended = false;
}
} else if (this.ended) {
await this.play(0);
} else if (audioContext.currentTime < this.duration) {
await this.resumeAudioContext();
} else {
await this.play(Math.max(0, this.sourceNodes.length - 1));
}
this.emit('state', this.getState());
}
public destroy() {
this.playing = false;
this.destroyed = true;
this.downloadLoop.cancelled = true;
this.schedulerLoop.cancelled = true;
resetAudioContext();
this.sourceNodes = [];
this.durations = [];
this.duration = 0;
this.removeAllListeners();
}
}

View File

@@ -0,0 +1,120 @@
import DirectTTSPlugin from "./direct-tts-plugin";
import { AsyncLoop } from "../utils/async-loop";
import { AbstractTTSPlayer } from "./types";
import WebSpeechPlugin from "../../tts-plugins/web-speech";
export default class DirectTTSPlayer extends AbstractTTSPlayer {
playing = false;
ended = false;
private loop: AsyncLoop;
private currentIndex = 0;
private currentPlaybackIndex = 0;
private promises: any[] = [];
constructor(private plugin: WebSpeechPlugin) {
super();
console.log('tts init, directttsplayer');
this.emit('state', this.getState());
this.loop = new AsyncLoop(() => this.tick(), 100);
this.loop.start();
}
private async tick() {
if (!this.playing) {
return;
}
const sentences = [...this.sentences];
if (!this.complete) {
sentences.pop();
}
if (this.currentPlaybackIndex >= sentences.length) {
if (this.complete) {
console.log(`tts finished 1, current index: ${this.currentPlaybackIndex}, sentences length: ${sentences.length}`);
try {
await Promise.all(this.promises);
} catch (e) {
console.error('an error occured while reading text aloud', e);
}
console.log(`tts finished 2, current index: ${this.currentPlaybackIndex}, sentences length: ${sentences.length}`);
this.playing = false;
this.ended = true;
this.currentIndex = 0;
this.currentPlaybackIndex = 0;
this.promises = [];
this.emit('state', this.getState());
return;
}
}
if (this.currentIndex >= sentences.length) {
return;
}
this.ended = false;
try {
this.emit('state', this.getState());
const text = sentences[this.currentIndex];
console.log(`tts speaking`, text);
const p = this.plugin.speak(text);
p.then(() => {
this.currentPlaybackIndex = this.currentIndex + 1;
});
this.promises.push(p);
this.currentIndex += 1;
} catch (e) {
console.error('an error occured while reading text aloud', e);
}
}
async play(index?: number): Promise<any> {
if (this.playing) {
await this.plugin.stop();
this.promises = [];
}
this.playing = true;
this.ended = false;
if (typeof index === 'number') {
this.currentIndex = index;
this.currentPlaybackIndex = index;
}
await this.plugin.resume();
this.emit('state', this.getState());
}
async pause(): Promise<any> {
await this.plugin.pause();
this.playing = false;
this.emit('state', this.getState());
}
getState() {
return {
playing: this.playing,
ended: this.ended,
buffering: this.playing && !this.plugin.isSpeaking(),
index: this.currentPlaybackIndex,
length: this.sentences.length,
downloadable: false,
}
}
async destroy() {
if (this.playing) {
this.plugin.stop();
}
this.loop.cancelled = true;
this.playing = false;
this.removeAllListeners();
}
}

View File

@@ -0,0 +1,10 @@
import { Voice } from "./types";
import TTSPlugin from "./tts-plugin";
export default class DirectTTSPlugin<T=any> extends TTSPlugin<T> {
async speak(text: string, voice?: Voice) {
}
async stop() {
}
}

View File

@@ -0,0 +1,16 @@
import Plugin from "../plugins";
import { Voice } from "../tts/types";
export default class TTSPlugin<T=any> extends Plugin<T> {
async getVoices(): Promise<Voice[]> {
return [];
}
async getCurrentVoice(): Promise<Voice> {
throw new Error("not implemented");
}
async speakToBuffer(text: string, voice?: Voice): Promise<ArrayBuffer | null | undefined> {
throw new Error("not implemented");
}
}

48
app/src/core/tts/types.ts Normal file
View File

@@ -0,0 +1,48 @@
import EventEmitter from "events";
import { split } from "sentence-splitter";
export interface TTSPlayerState {
playing: boolean;
ended: boolean;
buffering: boolean;
duration?: number;
index: number;
length: number;
ready?: number;
downloadable: boolean;
}
export abstract class AbstractTTSPlayer extends EventEmitter {
private lines: string[] = [];
protected sentences: string[] = [];
protected complete = false;
abstract play(index?: number): Promise<any>;
abstract pause(): Promise<any>;
abstract getState(): TTSPlayerState;
abstract destroy(): any;
public setText(lines: string[], complete: boolean) {
this.lines = lines;
this.complete = complete;
this.updateSentences();
}
private updateSentences() {
const output: string[] = [];
for (const line of this.lines) {
const sentences = split(line);
for (const sentence of sentences) {
output.push(sentence.raw.trim());
}
}
this.sentences = output.filter(s => s.length > 0);
}
}
export interface Voice {
service: string;
id: string;
name?: string;
sampleAudioURL?: string;
}

View File

@@ -0,0 +1,162 @@
import { createContext, useCallback, useContext, useEffect, useMemo, useRef, useState } from "react";
import { useAppContext } from "../context";
import { ttsPlugins } from "../plugins/metadata";
import Plugin from "../plugins";
import { AbstractTTSPlayer, TTSPlayerState, Voice } from "./types";
import { createBasicPluginContext } from "../plugins/plugin-context";
import DirectTTSPlayer from "./direct-tts-player";
import DirectTTSPlugin from "./direct-tts-plugin";
import TTSPlugin from "./tts-plugin";
import ExternalTTSAudioFilePlayer from "./audio-file-player";
import { split } from "sentence-splitter";
import { useOption } from "../options/use-option";
function extractTextSegments(element: HTMLElement) {
const selector = 'p, li, th, td, blockquote, pre code, h1, h2, h3, h3, h5, h6';
const nodes = Array.from(element.querySelectorAll(selector) || []);
const lines: string[] = [];
const blocks = nodes.filter(node => !node.parentElement?.closest(selector) && node.textContent);
for (const block of blocks) {
const tagName = block.tagName.toLowerCase();
if (tagName === 'p' || tagName === 'li' || tagName === 'blockquote') {
const sentences = split(block.textContent!);
for (const sentence of sentences) {
lines.push(sentence.raw.trim());
}
} else {
lines.push(block.textContent!.trim());
}
}
return lines.filter(line => line.length);
}
interface ITTSContext {
key: string | null;
voice: Voice | null;
autoplayEnabled: boolean;
state?: TTSPlayerState;
play(index?: number): void;
pause(): void;
cancel(): void;
setSourceElement(key: string, element: HTMLElement | null): void;
setComplete(complete: boolean): void;
}
export function useTTSPlayerState(): ITTSContext {
const context = useAppContext();
const [ttsPluginID] = useOption<string>('tts', 'service');
const [autoplayEnabled] = useOption<boolean>('tts', 'autoplay');
const [voiceID] = useOption<string>(ttsPluginID, 'voice');
const voice = useMemo(() => ({
service: ttsPluginID,
id: voiceID,
}), [ttsPluginID, voiceID]);
const ttsPluginImpl = useMemo(() => {
const ttsPluginIndex = ttsPlugins.findIndex(p => new p().describe().id === ttsPluginID) || 0;
return ttsPlugins[ttsPluginIndex];
}, [ttsPluginID]);
const plugin = useRef<Plugin|null>(null);
const player = useRef<AbstractTTSPlayer|null>(null);
const elementRef = useRef<HTMLElement|null>(null);
const [key, setKey] = useState<string|null>(null);
const [state, setState] = useState(() => player.current?.getState());
const [complete, setComplete] = useState(false);
const timer = useRef<any>();
const setSourceElement = useCallback((newKey: string | null, element: HTMLElement | null) => {
elementRef.current = element;
if (key !== newKey || !element) {
plugin.current = null;
player.current?.destroy();
player.current = null;
}
setKey(newKey);
if (element) {
if (!plugin.current) {
const pluginContext = createBasicPluginContext(ttsPluginID, context.chat.options, context.id, context.currentChat.chat)
plugin.current = new ttsPluginImpl(pluginContext);
}
if (!player.current) {
if (plugin.current instanceof DirectTTSPlugin) {
player.current = new DirectTTSPlayer(plugin.current as any);
} else if (plugin.current instanceof TTSPlugin) {
player.current = new ExternalTTSAudioFilePlayer(plugin.current);
}
player.current!.on('state', setState);
}
} else {
setState(undefined);
}
}, [ttsPluginID, context, complete, key]);
useEffect(() => {
setSourceElement(null, null);
}, [ttsPluginID, voiceID]);
useEffect(() => {
clearInterval(timer.current);
const update = () => {
if (!player.current || !elementRef.current) {
return;
}
player.current.setText(extractTextSegments(elementRef.current), complete);
};
update();
if (!complete) {
timer.current = setInterval(update, 1000);
}
}, [key, complete]);
return {
key,
voice: voiceID ? voice : null,
autoplayEnabled,
state: !state?.ended ? state : undefined,
play(index?: number) {
player.current?.play(index);
},
pause() {
player.current?.pause();
},
cancel() {
setSourceElement(null, null);
},
setSourceElement,
setComplete,
}
}
const TTSContext = createContext<ITTSContext>({
key: null,
voice: null,
autoplayEnabled: false,
play() {},
pause() {},
cancel() {},
setSourceElement() {},
setComplete() {},
});
export function useTTS() {
return useContext(TTSContext);
}
export function TTSContextProvider(props: { children: React.ReactNode }) {
const context = useTTSPlayerState();
return <TTSContext.Provider value={context}>{props.children}</TTSContext.Provider>;
}

View File

@@ -0,0 +1,46 @@
import { sleep } from '.';
/**
* AsyncLoop class provides a mechanism to execute a given function
* asynchronously in a loop with a specified delay between each execution.
* Unlike setInterval, it ensures that each iteration finishes before
* starting the next one.
*/
export class AsyncLoop {
public cancelled = false;
/**
* Creates a new instance of the AsyncLoop class.
* @param {Function} handler - The function to be executed in the loop.
* @param {number} pauseBetween - The delay (in milliseconds) between each execution of the handler. Default is 1000 ms.
*/
constructor(private handler: any, private pauseBetween: number = 1000) {
}
/**
* Starts the asynchronous loop by calling the loop() method.
*/
public start() {
this.loop().then(() => { });
}
/**
* The main loop function that executes the given handler function
* while the loop is not cancelled. It catches any errors thrown by
* the handler function and logs them to the console.
* @private
* @returns {Promise<void>} A Promise that resolves when the loop is cancelled.
*/
private async loop() {
while (!this.cancelled) {
try {
await this.handler();
} catch (e) {
console.error(e);
}
await sleep(this.pauseBetween);
}
}
}

View File

@@ -0,0 +1,91 @@
import EventEmitter from 'events';
export interface EventEmitterAsyncIteratorOutput<T> {
eventName: string;
value: T;
}
/**
* The EventEmitterAsyncIterator class provides a way to create an async iterator
* that listens to multiple events from an EventEmitter instance, and yields
* the emitted event name and value as an EventEmitterAsyncIteratorOutput object.
*
* This class implements the AsyncIterableIterator interface, which allows it
* to be used in for-await-of loops and other asynchronous iteration contexts.
*
* @typeparam T - The type of values emitted by the events.
*
* @example
* const eventEmitter = new EventEmitter();
* const asyncIterator = new EventEmitterAsyncIterator(eventEmitter, ['event1', 'event2']);
*
* for await (const event of asyncIterator) {
* console.log(`Received event: ${event.eventName} with value: ${event.value}`);
* }
*/
export class EventEmitterAsyncIterator<T> implements AsyncIterableIterator<EventEmitterAsyncIteratorOutput<T>> {
private eventQueue: EventEmitterAsyncIteratorOutput<T>[] = [];
private resolveQueue: ((value: IteratorResult<EventEmitterAsyncIteratorOutput<T>>) => void)[] = [];
/**
* Constructor takes an EventEmitter instance and an array of event names to listen to.
* For each event name, it binds the pushEvent method with the eventName, which
* will be called when the event is emitted.
*
* @param eventEmitter - The EventEmitter instance to listen to events from.
* @param eventNames - An array of event names to listen to.
*/
constructor(private eventEmitter: EventEmitter, eventNames: string[]) {
for (const eventName of eventNames) {
this.eventEmitter.on(eventName, this.pushEvent.bind(this, eventName));
}
}
/**
* The next method is called when the iterator is requested to return the next value.
* If there is an event in the eventQueue, it will return the next event from the queue.
* If the eventQueue is empty, it will return a Promise that resolves when a new event is received.
*
* @returns A Promise that resolves with the next event or waits for a new event if the queue is empty.
*/
async next(): Promise<IteratorResult<EventEmitterAsyncIteratorOutput<T>>> {
if (this.eventQueue.length > 0) {
const value = this.eventQueue.shift();
return { value: value as EventEmitterAsyncIteratorOutput<T>, done: false };
} else {
return new Promise<IteratorResult<EventEmitterAsyncIteratorOutput<T>>>(resolve => {
this.resolveQueue.push(value => {
resolve(value);
});
});
}
}
/**
* The pushEvent method is called when an event is emitted from the EventEmitter.
* If there is a pending Promise in the resolveQueue, it resolves the Promise with the new event.
* If there is no pending Promise, it adds the event to the eventQueue.
*
* @param eventName - The name of the emitted event.
* @param value - The value emitted with the event.
*/
private pushEvent(eventName: string, value: T): void {
const output: EventEmitterAsyncIteratorOutput<T> = {
eventName,
value,
};
if (this.resolveQueue.length > 0) {
const resolve = this.resolveQueue.shift();
if (resolve) {
resolve({ value: output, done: false });
}
} else {
this.eventQueue.push(output);
}
}
[Symbol.asyncIterator](): AsyncIterableIterator<EventEmitterAsyncIteratorOutput<T>> {
return this;
}
}

120
app/src/core/utils/idb.ts Normal file
View File

@@ -0,0 +1,120 @@
/*
* This file provides a wrapper for IndexedDB (IDB), specifically designed to handle cases
* where IDB is unavailable, such as when the user is in private browsing mode. The wrapper
* uses the 'idb-keyval' library for interacting with IDB and maintains an in-memory cache
* as a fallback mechanism when IDB is not accessible.
*
* The module exports various functions for working with key-value pairs, such as getting,
* setting, deleting, and retrieving keys and entries. These functions first attempt to
* interact with IDB, and if it fails (e.g., due to unavailability), they fall back to
* the in-memory cache. This ensures that the application can continue to function even
* in cases where IDB is not supported or disabled.
*
* The wrapper performs an initial test to check whether IDB is supported in the current
* environment. If not, it sets the 'supported' flag to false, and all subsequent operations
* will rely on the in-memory cache.
*/
import * as idb from 'idb-keyval';
let supported = true;
const inMemoryCache = new Map<string, any>();
const testDB = indexedDB.open('idb-test');
testDB.onerror = () => {
supported = false;
};
export async function keys() {
if (supported) {
try {
const keys = await idb.keys();
return Array.from(keys).map(k => k.toString());
} catch (e) {}
}
return Array.from(inMemoryCache.keys());
}
export async function set(key, value) {
// all values are saved in memory in case IDB fails later, but only retrieved after IDB fails.
inMemoryCache.set(key, value);
if (supported) {
try {
await idb.set(key, value);
return;
} catch (e) {}
}
}
export async function get(key) {
if (supported) {
try {
return await idb.get(key);
}
catch (e) {}
}
return inMemoryCache.get(key);
}
export async function getMany(keys) {
if (supported) {
try {
return await idb.getMany(keys);
}
catch (e) {}
}
const values: any[] = [];
for (const key of keys) {
values.push(inMemoryCache.get(key));
}
return values;
}
export async function setMany(items: [string, any][]) {
// all values are saved in memory in case IDB fails later, but only retrieved after IDB fails.
for (const [key, value] of items) {
inMemoryCache.set(key, value);
}
if (supported) {
try {
await idb.setMany(items);
return;
} catch (e) {}
}
}
export async function entries() {
if (supported) {
try {
const entries = await idb.entries();
return Array.from(entries)
.map(([key, value]) => [key.toString(), value]);
} catch (e) {}
}
return Array.from(inMemoryCache.entries());
}
export async function del(key: string) {
// all values are saved in memory in case IDB fails later, but only retrieved after IDB fails.
inMemoryCache.delete(key);
if (supported) {
try {
await idb.del(key);
return;
} catch (e) {}
}
}
export async function delMany(keys: string[]) {
// all values are saved in memory in case IDB fails later, but only retrieved after IDB fails.
for (const key of keys) {
inMemoryCache.delete(key);
}
if (supported) {
try {
await idb.delMany(keys);
return;
} catch (e) {}
}
}

101
app/src/core/utils/index.ts Normal file
View File

@@ -0,0 +1,101 @@
import * as hashes from 'jshashes';
/**
* Pauses the execution of the function for a specified duration.
*
* @export
* @param {number} ms - The duration (in milliseconds) to pause the execution.
* @returns {Promise} A Promise that resolves after the specified duration.
*/
export function sleep(ms: number): Promise<any> {
return new Promise(resolve => setTimeout(resolve, ms));
}
/**
* Truncates a given string to a specified length and appends ellipsis (...) if needed.
*
* @export
* @param {string} text - The input string to be ellipsized.
* @param {number} maxLength - The maximum length of the output string (including the ellipsis).
* @returns {string} The ellipsized string.
*/
export function ellipsize(text: string, maxLength: number): string {
if (text.length > maxLength) {
return text.substring(0, maxLength) + '...';
}
return text;
}
/**
* Creates a deep clone of the given ArrayBuffer.
*
* @export
* @param {ArrayBuffer} buffer - The ArrayBuffer to clone.
* @returns {ArrayBuffer} A new ArrayBuffer containing the same binary data as the input buffer.
*/
export function cloneArrayBuffer(buffer: ArrayBuffer): ArrayBuffer {
const newBuffer = new ArrayBuffer(buffer.byteLength);
new Uint8Array(newBuffer).set(new Uint8Array(buffer));
return newBuffer;
}
/**
* Shares the specified text using the Web Share API if available in the user's browser.
*
* @function
* @async
* @param {string} text - The text to be shared.
* @example
* share("Hello, World!");
*/
export async function share(text: string) {
if (navigator.share) {
await navigator.share({
text,
});
}
}
/*
Hashing
*/
const hasher = new hashes.MD5();
const hashCache = new Map<string, string>();
export async function md5(data: string): Promise<string> {
if (!hashCache.has(data)) {
const hashHex = hasher.hex(data);
hashCache.set(data, hashHex);
}
return hashCache.get(data)!;
}
/*
Rate limiting
*/
export function getRateLimitResetTimeFromResponse(response: Response): number {
const now = Date.now();
const fallbackValue = now + 20*1000;
const maxValue = now + 2*60*1000;
const rateLimitReset = response.headers.get("x-ratelimit-reset");
if (!rateLimitReset) {
return fallbackValue;
}
let resetTime = parseInt(rateLimitReset, 10);
if (isNaN(resetTime)) {
return fallbackValue;
}
resetTime *= 1000;
if (resetTime > fallbackValue) {
return maxValue;
}
return resetTime;
}

255
app/src/core/utils/sse.ts Normal file
View File

@@ -0,0 +1,255 @@
/**
* This class is an implementation of Server-Side Events (SSE) that allows sending POST request bodies.
*
* It's an adapted version of an open-source implementation, and it's designed to support streaming
* completions for OpenAI requests
*
* Original Copyright:
* Copyright (C) 2016 Maxime Petazzoni <maxime.petazzoni@bulix.org>.
* All rights reserved.
*/
export default class SSE {
// Constants representing the ready state of the SSE connection
public INITIALIZING = -1;
public CONNECTING = 0;
public OPEN = 1;
public CLOSED = 2;
// Connection settings
private headers = this.options.headers || {};
private payload = this.options.payload !== undefined ? this.options.payload : '';
private method = this.options.method ? this.options.method : (this.payload ? 'POST' : 'GET');
private withCredentials = !!this.options.withCredentials;
// Internal properties
private FIELD_SEPARATOR = ':';
private listeners: any = {};
private xhr: any = null;
private readyState = this.INITIALIZING;
private progress = 0;
private chunk = '';
public constructor(public url: string, public options: any) { }
/**
* Starts streaming data from the SSE connection.
*/
public stream = () => {
this.setReadyState(this.CONNECTING);
this.xhr = new XMLHttpRequest();
this.xhr.addEventListener('progress', this.onStreamProgress);
this.xhr.addEventListener('load', this.onStreamLoaded);
this.xhr.addEventListener('readystatechange', this.checkStreamClosed);
this.xhr.addEventListener('error', this.onStreamFailure);
this.xhr.addEventListener('abort', this.onStreamAbort);
this.xhr.open(this.method, this.url);
for (var header in this.headers) {
this.xhr.setRequestHeader(header, this.headers[header]);
}
this.xhr.withCredentials = this.withCredentials;
this.xhr.send(this.payload);
};
/**
* Closes the SSE connection.
*/
public close = () => {
if (this.readyState === this.CLOSED) {
return;
}
try {
this.xhr.abort();
this.xhr = null;
this.setReadyState(this.CLOSED);
} catch (e) {
console.error(e);
}
};
/**
* Processes incoming data from the SSE connection and dispatches events based on the received data.
*/
private onStreamProgress = (e: any) => {
if (!this.xhr) {
return;
}
if (this.xhr.status !== 200) {
this.onStreamFailure(e);
return;
}
if (this.readyState === this.CONNECTING) {
this.dispatchEvent(new CustomEvent('open'));
this.setReadyState(this.OPEN);
}
const data = this.xhr.responseText.substring(this.progress);
this.progress += data.length;
data.split(/(\r\n|\r|\n){2}/g).forEach((part: string) => {
if (part.trim().length === 0) {
this.dispatchEvent(this.parseEventChunk(this.chunk.trim()));
this.chunk = '';
} else {
this.chunk += part;
}
});
};
/**
* Parses a received SSE event chunk and constructs an event object based on the chunk data.
*/
private parseEventChunk = (chunk: string) => {
if (!chunk || chunk.length === 0) {
return null;
}
const e: any = { 'id': null, 'retry': null, 'data': '', 'event': 'message' };
chunk.split(/\n|\r\n|\r/).forEach((line: string) => {
line = line.trimRight();
const index = line.indexOf(this.FIELD_SEPARATOR);
if (index <= 0) {
// Line was either empty, or started with a separator and is a comment.
// Either way, ignore.
return;
}
const field = line.substring(0, index);
if (!(field in e)) {
return;
}
const value = line.substring(index + 1).trimLeft();
if (field === 'data') {
e[field] += value;
} else {
e[field] = value;
}
});
const event: any = new CustomEvent(e.event);
event.data = e.data;
event.id = e.id;
return event;
};
/**
* Handles the 'load' event for the SSE connection and processes the remaining data.
*/
private onStreamLoaded = (e: any) => {
this.onStreamProgress(e);
// Parse the last chunk.
this.dispatchEvent(this.parseEventChunk(this.chunk));
this.chunk = '';
};
/**
* Adds an event listener for a given event type.
*/
public addEventListener = (type: string, listener: any) => {
if (this.listeners[type] === undefined) {
this.listeners[type] = [];
}
if (this.listeners[type].indexOf(listener) === -1) {
this.listeners[type].push(listener);
}
};
/**
* Removes an event listener for a given event type.
*/
public removeEventListener = (type: string, listener: any) => {
if (this.listeners[type] === undefined) {
return;
}
const filtered: any[] = [];
this.listeners[type].forEach((element: any) => {
if (element !== listener) {
filtered.push(element);
}
});
if (filtered.length === 0) {
delete this.listeners[type];
} else {
this.listeners[type] = filtered;
}
};
/**
* Dispatches an event to all registered listeners.
*/
private dispatchEvent = (e: any) => {
if (!e) {
return true;
}
e.source = this;
const onHandler = 'on' + e.type;
if (this.hasOwnProperty(onHandler)) {
// @ts-ignore
this[onHandler].call(this, e);
if (e.defaultPrevented) {
return false;
}
}
if (this.listeners[e.type]) {
return this.listeners[e.type].every((callback: (arg0: any) => void) => {
callback(e);
return !e.defaultPrevented;
});
}
return true;
};
/**
* Sets the ready state of the SSE connection and dispatches a 'readystatechange' event.
*/
private setReadyState = (state: number) => {
const event = new CustomEvent<any>('readystatechange');
// @ts-ignore
event.readyState = state;
this.readyState = state;
this.dispatchEvent(event);
};
/**
* Handles an error during the SSE connection and dispatches an 'error' event.
*/
private onStreamFailure = (e: { currentTarget: { response: any; }; }) => {
const event = new CustomEvent('error');
// @ts-ignore
event.data = e.currentTarget.response;
this.dispatchEvent(event);
this.close();
}
/**
* Handles an abort event during the SSE connection and dispatches an 'abort' event.
*/
private onStreamAbort = (e: any) => {
this.dispatchEvent(new CustomEvent('abort'));
this.close();
}
/**
* Checks if the SSE connection is closed and sets the ready state to CLOSED if needed.
*/
private checkStreamClosed = () => {
if (!this.xhr) {
return;
}
if (this.xhr.readyState === XMLHttpRequest.DONE) {
this.setReadyState(this.CLOSED);
}
};
};