initial work on server-side token counting
This commit is contained in:
		| @@ -12,4 +12,8 @@ export async function basicHandler(req: express.Request, res: express.Response) | |||||||
|     }) |     }) | ||||||
|  |  | ||||||
|     res.json(response.data); |     res.json(response.data); | ||||||
|  |  | ||||||
|  |     const promptTokens = response.data.usage.prompt_tokens as number; | ||||||
|  |     const completionTokens = response.data.usage.completion_tokens as number; | ||||||
|  |     // console.log(`prompt tokens: ${promptTokens}, completion tokens: ${completionTokens}, model: ${req.body.model}`); | ||||||
| } | } | ||||||
							
								
								
									
										4
									
								
								server/src/endpoints/service-proxies/openai/message.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								server/src/endpoints/service-proxies/openai/message.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | export interface OpenAIMessage { | ||||||
|  |     role: string; | ||||||
|  |     content: string; | ||||||
|  | } | ||||||
| @@ -2,6 +2,7 @@ | |||||||
| import { EventSource } from "launchdarkly-eventsource"; | import { EventSource } from "launchdarkly-eventsource"; | ||||||
| import express from 'express'; | import express from 'express'; | ||||||
| import { apiKey } from "."; | import { apiKey } from "."; | ||||||
|  | import { countTokensForMessages } from "./tokenizer"; | ||||||
|  |  | ||||||
| export async function streamingHandler(req: express.Request, res: express.Response) { | export async function streamingHandler(req: express.Request, res: express.Response) { | ||||||
|     res.set({ |     res.set({ | ||||||
| @@ -10,6 +11,11 @@ export async function streamingHandler(req: express.Request, res: express.Respon | |||||||
|         Connection: 'keep-alive', |         Connection: 'keep-alive', | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
|  |     const messages = req.body.messages; | ||||||
|  |     const promptTokens = countTokensForMessages(messages); | ||||||
|  |  | ||||||
|  |     let completion = ''; | ||||||
|  |  | ||||||
|     const eventSource = new EventSource('https://api.openai.com/v1/chat/completions', { |     const eventSource = new EventSource('https://api.openai.com/v1/chat/completions', { | ||||||
|         method: "POST", |         method: "POST", | ||||||
|         headers: { |         headers: { | ||||||
| @@ -30,6 +36,26 @@ export async function streamingHandler(req: express.Request, res: express.Respon | |||||||
|         if (event.data === '[DONE]') { |         if (event.data === '[DONE]') { | ||||||
|             res.end(); |             res.end(); | ||||||
|             eventSource.close(); |             eventSource.close(); | ||||||
|  |  | ||||||
|  |             const totalTokens = countTokensForMessages([ | ||||||
|  |                 ...messages, | ||||||
|  |                 { | ||||||
|  |                     role: "assistant", | ||||||
|  |                     content: completion, | ||||||
|  |                 }, | ||||||
|  |             ]); | ||||||
|  |             const completionTokens = totalTokens - promptTokens; | ||||||
|  |             // console.log(`prompt tokens: ${promptTokens}, completion tokens: ${completionTokens}, model: ${req.body.model}`); | ||||||
|  |             return; | ||||||
|  |         } | ||||||
|  |  | ||||||
|  |         try { | ||||||
|  |             const chunk = parseResponseChunk(event.data); | ||||||
|  |             if (chunk.choices && chunk.choices.length > 0) { | ||||||
|  |                 completion += chunk.choices[0]?.delta?.content || ''; | ||||||
|  |             } | ||||||
|  |         } catch (e) { | ||||||
|  |             console.error(e); | ||||||
|         } |         } | ||||||
|     }); |     }); | ||||||
|  |  | ||||||
| @@ -48,4 +74,23 @@ export async function streamingHandler(req: express.Request, res: express.Respon | |||||||
|     res.on('error', e => { |     res.on('error', e => { | ||||||
|         eventSource.close(); |         eventSource.close(); | ||||||
|     }); |     }); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | function parseResponseChunk(buffer: any) { | ||||||
|  |     const chunk = buffer.toString().replace('data: ', '').trim(); | ||||||
|  |  | ||||||
|  |     if (chunk === '[DONE]') { | ||||||
|  |         return { | ||||||
|  |             done: true, | ||||||
|  |         }; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const parsed = JSON.parse(chunk); | ||||||
|  |  | ||||||
|  |     return { | ||||||
|  |         id: parsed.id, | ||||||
|  |         done: false, | ||||||
|  |         choices: parsed.choices, | ||||||
|  |         model: parsed.model, | ||||||
|  |     }; | ||||||
| } | } | ||||||
							
								
								
									
										241
									
								
								server/src/endpoints/service-proxies/openai/tokenizer/bpe.ts
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										241
									
								
								server/src/endpoints/service-proxies/openai/tokenizer/bpe.ts
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,241 @@ | |||||||
|  | const MAX_NUM_THREADS = 128; | ||||||
|  |  | ||||||
|  | type MergeRange = { start: number, end: number }; | ||||||
|  |  | ||||||
|  | export class RankMap { | ||||||
|  |     private values = new Map<string, number>(); | ||||||
|  |  | ||||||
|  |     public static from(texts: string[]) { | ||||||
|  |         const map = new RankMap(); | ||||||
|  |         for (let i = 0; i < texts.length; i++) { | ||||||
|  |             map.values.set(texts[i], i); | ||||||
|  |         } | ||||||
|  |         return map; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     public set(bytes: Uint8Array, rank: number) { | ||||||
|  |         const key = Buffer.from(bytes).toString(); | ||||||
|  |         this.values.set(key, rank); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     public get(bytes: Uint8Array) { | ||||||
|  |         const key = Buffer.from(bytes).toString(); | ||||||
|  |         return this.values.get(key); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     public keys() { | ||||||
|  |         return Array.from(this.values.keys()).map(k => Buffer.from(k)); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     public inverted() { | ||||||
|  |         const inverted = new Map<number, Uint8Array>(); | ||||||
|  |         for (const [key, value] of Array.from(this.values.entries())) { | ||||||
|  |             inverted.set(value, new Uint8Array(Buffer.from(key))); | ||||||
|  |         } | ||||||
|  |         return inverted; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | function bytePairMerge(piece: Uint8Array, ranks: RankMap): MergeRange[] { | ||||||
|  |     let parts: MergeRange[] = Array.from({ length: piece.length }, (_, i) => ({ start: i, end: i + 1 })); | ||||||
|  |     while (true) { | ||||||
|  |         if (parts.length === 1) { | ||||||
|  |             break; | ||||||
|  |         } | ||||||
|  |         let minRank: [number, number] | null = null; | ||||||
|  |         for (let i = 0; i < parts.length - 1; i++) { | ||||||
|  |             const rank = ranks.get(piece.slice(parts[i].start, parts[i + 1].end)); | ||||||
|  |             if (rank === undefined) { | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |             if (minRank === null || rank < minRank[0]) { | ||||||
|  |                 minRank = [rank, i]; | ||||||
|  |             } | ||||||
|  |         } | ||||||
|  |         if (minRank !== null) { | ||||||
|  |             const [_, i] = minRank; | ||||||
|  |             parts[i] = { start: parts[i].start, end: parts[i + 1].end }; | ||||||
|  |             parts.splice(i + 1, 1); | ||||||
|  |         } else { | ||||||
|  |             break; | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  |     return parts; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | function bytePairEncode(piece: Uint8Array, ranks: RankMap): number[] { | ||||||
|  |     if (piece.length === 1) { | ||||||
|  |         return [ranks.get(piece)!]; | ||||||
|  |     } | ||||||
|  |     return bytePairMerge(piece, ranks).map((p) => ranks.get(piece.slice(p.start, p.end))!); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | function bytePairSplit(piece: Uint8Array, ranks: RankMap): Uint8Array[] { | ||||||
|  |     if (piece.length === 1) { | ||||||
|  |         return [piece]; | ||||||
|  |     } | ||||||
|  |     return bytePairMerge(piece, ranks).map((p) => piece.slice(p.start, p.end)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export class CoreBPE { | ||||||
|  |     encoder: RankMap; | ||||||
|  |     specialTokensEncoder: Map<string, number>; | ||||||
|  |     decoder: Map<number, Uint8Array>; | ||||||
|  |     specialTokensDecoder: Map<number, Uint8Array>; | ||||||
|  |     regexTls: RegExp[]; | ||||||
|  |     specialRegexTls: RegExp[]; | ||||||
|  |     sortedTokenBytes: Uint8Array[]; | ||||||
|  |  | ||||||
|  |     constructor( | ||||||
|  |         encoder: RankMap, | ||||||
|  |         specialTokensEncoder: Map<string, number>, | ||||||
|  |         regex: RegExp | ||||||
|  |     ) { | ||||||
|  |         const specialRegex = new RegExp( | ||||||
|  |             Array.from(specialTokensEncoder.keys()) | ||||||
|  |                 .map((s) => s.replace(/[.*+\-?^${}()|[\]\\]/g, "\\$&")) | ||||||
|  |                 .join("|") | ||||||
|  |         ); | ||||||
|  |  | ||||||
|  |         const decoder: Map<number, Uint8Array> = encoder.inverted(); | ||||||
|  |  | ||||||
|  |         const specialTokensDecoder: Map<number, Uint8Array> = new Map( | ||||||
|  |             Array.from(specialTokensEncoder.entries()).map(([k, v]) => [v, new Uint8Array(Buffer.from(k))]) | ||||||
|  |         ); | ||||||
|  |         const sortedTokenBytes: Uint8Array[] = Array.from(encoder.keys()); | ||||||
|  |         sortedTokenBytes.sort((a, b) => Buffer.compare(a, b)); | ||||||
|  |  | ||||||
|  |         this.encoder = encoder; | ||||||
|  |         this.specialTokensEncoder = specialTokensEncoder; | ||||||
|  |         this.decoder = decoder; | ||||||
|  |         this.specialTokensDecoder = specialTokensDecoder; | ||||||
|  |         this.regexTls = Array(MAX_NUM_THREADS).fill(regex); | ||||||
|  |         this.specialRegexTls = Array(MAX_NUM_THREADS).fill(specialRegex); | ||||||
|  |         this.sortedTokenBytes = sortedTokenBytes; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private _getTlRegex(): RegExp { | ||||||
|  |         return this.regexTls[Math.floor(Math.random() * MAX_NUM_THREADS)]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private _getTlSpecialRegex(): RegExp { | ||||||
|  |         return this.specialRegexTls[Math.floor(Math.random() * MAX_NUM_THREADS)]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private _decodeNative(tokens: number[]): Uint8Array { | ||||||
|  |         const ret: number[] = []; | ||||||
|  |         for (const token of tokens) { | ||||||
|  |             const tokenBytes = this.decoder.get(token) || this.specialTokensDecoder.get(token)!; | ||||||
|  |             ret.push(...Array.from(tokenBytes)); | ||||||
|  |         } | ||||||
|  |         return new Uint8Array(ret); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private _encodeOrdinaryNative(text: string): number[] { | ||||||
|  |         const regex = this._getTlRegex(); | ||||||
|  |         const ret: number[] = []; | ||||||
|  |         let match: RegExpExecArray | null; | ||||||
|  |         while ((match = regex.exec(text)) !== null) { | ||||||
|  |             const piece = new Uint8Array(Buffer.from(match[0])); | ||||||
|  |             const token = this.encoder.get(piece); | ||||||
|  |             if (token !== undefined) { | ||||||
|  |                 ret.push(token); | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  |             ret.push(...bytePairEncode(piece, this.encoder)); | ||||||
|  |         } | ||||||
|  |         return ret; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     private _encodeNative(text: string, allowedSpecial: Set<string>): [number[], number] { | ||||||
|  |         const specialRegex = this._getTlSpecialRegex(); | ||||||
|  |         const regex = this._getTlRegex(); | ||||||
|  |         const ret: number[] = []; | ||||||
|  |  | ||||||
|  |         let start = 0; | ||||||
|  |         let lastPieceTokenLen = 0; | ||||||
|  |         while (true) { | ||||||
|  |             let nextSpecial: RegExpExecArray | null; | ||||||
|  |             let startFind = start; | ||||||
|  |             while (true) { | ||||||
|  |                 nextSpecial = specialRegex.exec(text.slice(startFind)); | ||||||
|  |                 if (nextSpecial === null || allowedSpecial.has(nextSpecial[0])) { | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  |                 startFind = nextSpecial.index + 1; | ||||||
|  |             } | ||||||
|  |             const end = nextSpecial === null ? text.length : nextSpecial.index; | ||||||
|  |             let match: RegExpExecArray | null; | ||||||
|  |             while ((match = regex.exec(text.slice(start, end))) !== null) { | ||||||
|  |                 const piece = new Uint8Array(Buffer.from(match[0])); | ||||||
|  |                 const token = this.encoder.get(piece); | ||||||
|  |                 if (token !== undefined) { | ||||||
|  |                     lastPieceTokenLen = 1; | ||||||
|  |                     ret.push(token); | ||||||
|  |                     continue; | ||||||
|  |                 } | ||||||
|  |                 const tokens = bytePairEncode(piece, this.encoder); | ||||||
|  |                 lastPieceTokenLen = tokens.length; | ||||||
|  |                 ret.push(...tokens); | ||||||
|  |             } | ||||||
|  |  | ||||||
|  |             if (nextSpecial === null) { | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             const piece = nextSpecial[0]; | ||||||
|  |             const token = this.specialTokensEncoder.get(piece)!; | ||||||
|  |             ret.push(token); | ||||||
|  |             start = nextSpecial.index + piece.length; | ||||||
|  |             lastPieceTokenLen = 0; | ||||||
|  |         } | ||||||
|  |         return [ret, lastPieceTokenLen]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     encodeOrdinary(text: string): number[] { | ||||||
|  |         return this._encodeOrdinaryNative(text); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     encode(text: string, allowedSpecial: Set<string>): number[] { | ||||||
|  |         return this._encodeNative(text, allowedSpecial)[0]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     encodeWithUnstable(text: string, allowedSpecial: Set<string>): [number[], Set<number[]>] { | ||||||
|  |         throw new Error("Not implemented"); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     encodeSingleToken(piece: Uint8Array): number { | ||||||
|  |         const token = this.encoder.get(piece); | ||||||
|  |         if (token !== undefined) { | ||||||
|  |             return token; | ||||||
|  |         } | ||||||
|  |         const pieceStr = Buffer.from(piece).toString("utf-8"); | ||||||
|  |         if (this.specialTokensEncoder.has(pieceStr)) { | ||||||
|  |             return this.specialTokensEncoder.get(pieceStr)!; | ||||||
|  |         } | ||||||
|  |         throw new Error("Key not found"); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     encodeSinglePiece(piece: Uint8Array): number[] { | ||||||
|  |         const token = this.encoder.get(piece); | ||||||
|  |         if (token !== undefined) { | ||||||
|  |             return [token]; | ||||||
|  |         } | ||||||
|  |         return bytePairEncode(piece, this.encoder); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     decodeBytes(tokens: number[]): Uint8Array { | ||||||
|  |         return this._decodeNative(tokens); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     decodeSingleTokenBytes(token: number): Uint8Array { | ||||||
|  |         const bytes = this.decoder.get(token) || this.specialTokensDecoder.get(token); | ||||||
|  |         if (bytes !== undefined) { | ||||||
|  |             return bytes; | ||||||
|  |         } | ||||||
|  |         throw new Error("Key not found"); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     tokenByteValues(): Uint8Array[] { | ||||||
|  |         return this.sortedTokenBytes; | ||||||
|  |     } | ||||||
|  | } | ||||||
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | import { CoreBPE, RankMap } from "./bpe"; | ||||||
|  | import fs from 'fs'; | ||||||
|  | import path from 'path'; | ||||||
|  | import { OpenAIMessage } from "../message"; | ||||||
|  |  | ||||||
|  | const ranks = JSON.parse(fs.readFileSync(path.join(__dirname, './cl100k_base.json'), 'utf8')); | ||||||
|  |  | ||||||
|  | const special_tokens: any = { | ||||||
|  |     "<|endoftext|>": 100257, | ||||||
|  |     "<|fim_prefix|>": 100258, | ||||||
|  |     "<|fim_middle|>": 100259, | ||||||
|  |     "<|fim_suffix|>": 100260, | ||||||
|  |     "<|endofprompt|>": 100276, | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | const special_tokens_map = new Map<string, number>(); | ||||||
|  | for (const text of Object.keys(special_tokens)) { | ||||||
|  |     special_tokens_map.set(text, special_tokens[text]); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const pattern = /('s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+/giu; | ||||||
|  |  | ||||||
|  | const tokenizer = new CoreBPE(RankMap.from(ranks), special_tokens_map, pattern); | ||||||
|  |  | ||||||
|  | const overheadTokens = { | ||||||
|  |     perMessage: 5, | ||||||
|  |     perPrompt: 2, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const tokenCache = new Map<string, number>(); | ||||||
|  |  | ||||||
|  | export function countTokensForText(text: string) { | ||||||
|  |     const cacheKey = text; | ||||||
|  |     if (tokenCache.has(cacheKey)) { | ||||||
|  |         return tokenCache.get(cacheKey)!; | ||||||
|  |     } | ||||||
|  |     let t1 = Date.now(); | ||||||
|  |     const tokens = tokenizer.encodeOrdinary(text).length; | ||||||
|  |     tokenCache.set(cacheKey, tokens); | ||||||
|  |     return tokens; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export function countTokensForMessage(message: OpenAIMessage) { | ||||||
|  |     return countTokensForText(message.content) + overheadTokens.perMessage; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | export function countTokensForMessages(messages: OpenAIMessage[]) { | ||||||
|  |     let tokens = overheadTokens.perPrompt; | ||||||
|  |     for (const m of messages) { | ||||||
|  |         tokens += countTokensForMessage(m); | ||||||
|  |     } | ||||||
|  |     return tokens; | ||||||
|  | } | ||||||
		Reference in New Issue
	
	Block a user