diff --git a/api/app/clients/ChatGPTClient.js b/api/app/clients/ChatGPTClient.js index 6a7ba7b9896..5450300a178 100644 --- a/api/app/clients/ChatGPTClient.js +++ b/api/app/clients/ChatGPTClient.js @@ -13,7 +13,6 @@ const { const { extractBaseURL, constructAzureURL, genAzureChatCompletion } = require('~/utils'); const { createContextHandlers } = require('./prompts'); const { createCoherePayload } = require('./llm'); -const { Agent, ProxyAgent } = require('undici'); const BaseClient = require('./BaseClient'); const { logger } = require('~/config'); @@ -186,10 +185,6 @@ class ChatGPTClient extends BaseClient { headers: { 'Content-Type': 'application/json', }, - dispatcher: new Agent({ - bodyTimeout: 0, - headersTimeout: 0, - }), }; if (this.isVisionModel) { @@ -275,10 +270,6 @@ class ChatGPTClient extends BaseClient { opts.headers['X-Title'] = 'LibreChat'; } - if (this.options.proxy) { - opts.dispatcher = new ProxyAgent(this.options.proxy); - } - /* hacky fixes for Mistral AI API: - Re-orders system message to the top of the messages payload, as not allowed anywhere else - If there is only one message and it's a system message, change the role to user diff --git a/api/app/clients/GoogleClient.js b/api/app/clients/GoogleClient.js index 8fce279bf14..8ae20accd57 100644 --- a/api/app/clients/GoogleClient.js +++ b/api/app/clients/GoogleClient.js @@ -1,22 +1,23 @@ const { google } = require('googleapis'); -const { Agent, ProxyAgent } = require('undici'); +const { concat } = require('@langchain/core/utils/stream'); const { ChatVertexAI } = require('@langchain/google-vertexai'); -const { GoogleVertexAI } = require('@langchain/google-vertexai'); -const { ChatGoogleVertexAI } = require('@langchain/google-vertexai'); const { ChatGoogleGenerativeAI } = require('@langchain/google-genai'); const { GoogleGenerativeAI: GenAI } = require('@google/generative-ai'); -const { AIMessage, HumanMessage, SystemMessage } = require('@langchain/core/messages'); +const { HumanMessage, SystemMessage } = require('@langchain/core/messages'); const { + googleGenConfigSchema, validateVisionModel, getResponseSender, endpointSettings, EModelEndpoint, VisionModes, + ErrorTypes, Constants, AuthKeys, } = require('librechat-data-provider'); const { encodeAndFormat } = require('~/server/services/Files/images'); const Tokenizer = require('~/server/services/Tokenizer'); +const { spendTokens } = require('~/models/spendTokens'); const { getModelMaxTokens } = require('~/utils'); const { sleep } = require('~/server/utils'); const { logger } = require('~/config'); @@ -49,9 +50,10 @@ class GoogleClient extends BaseClient { const serviceKey = creds[AuthKeys.GOOGLE_SERVICE_KEY] ?? {}; this.serviceKey = serviceKey && typeof serviceKey === 'string' ? JSON.parse(serviceKey) : serviceKey ?? 
{}; + /** @type {string | null | undefined} */ + this.project_id = this.serviceKey.project_id; this.client_email = this.serviceKey.client_email; this.private_key = this.serviceKey.private_key; - this.project_id = this.serviceKey.project_id; this.access_token = null; this.apiKey = creds[AuthKeys.GOOGLE_API_KEY]; @@ -60,6 +62,15 @@ class GoogleClient extends BaseClient { this.authHeader = options.authHeader; + /** @type {UsageMetadata | undefined} */ + this.usage; + /** The key for the usage object's input tokens + * @type {string} */ + this.inputTokensKey = 'input_tokens'; + /** The key for the usage object's output tokens + * @type {string} */ + this.outputTokensKey = 'output_tokens'; + if (options.skipSetOptions) { return; } @@ -119,22 +130,13 @@ class GoogleClient extends BaseClient { this.options = options; } - this.options.examples = (this.options.examples ?? []) - .filter((ex) => ex) - .filter((obj) => obj.input.content !== '' && obj.output.content !== ''); - this.modelOptions = this.options.modelOptions || {}; this.options.attachments?.then((attachments) => this.checkVisionRequest(attachments)); /** @type {boolean} Whether using a "GenerativeAI" Model */ - this.isGenerativeModel = this.modelOptions.model.includes('gemini'); - const { isGenerativeModel } = this; - this.isChatModel = !isGenerativeModel && this.modelOptions.model.includes('chat'); - const { isChatModel } = this; - this.isTextModel = - !isGenerativeModel && !isChatModel && /code|text/.test(this.modelOptions.model); - const { isTextModel } = this; + this.isGenerativeModel = + this.modelOptions.model.includes('gemini') || this.modelOptions.model.includes('learnlm'); this.maxContextTokens = this.options.maxContextTokens ?? @@ -170,40 +172,18 @@ class GoogleClient extends BaseClient { this.userLabel = this.options.userLabel || 'User'; this.modelLabel = this.options.modelLabel || 'Assistant'; - if (isChatModel || isGenerativeModel) { - // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves. - // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason, - // without tripping the stop sequences, so I'm using "||>" instead. - this.startToken = '||>'; - this.endToken = ''; - } else if (isTextModel) { - this.startToken = '||>'; - this.endToken = ''; - } else { - // Previously I was trying to use "<|endoftext|>" but there seems to be some bug with OpenAI's token counting - // system that causes only the first "<|endoftext|>" to be counted as 1 token, and the rest are not treated - // as a single token. So we're using this instead. - this.startToken = '||>'; - this.endToken = ''; - } - - if (!this.modelOptions.stop) { - const stopTokens = [this.startToken]; - if (this.endToken && this.endToken !== this.startToken) { - stopTokens.push(this.endToken); - } - stopTokens.push(`\n${this.userLabel}:`); - stopTokens.push('<|diff_marker|>'); - // I chose not to do one for `modelLabel` because I've never seen it happen - this.modelOptions.stop = stopTokens; - } - if (this.options.reverseProxyUrl) { this.completionsUrl = this.options.reverseProxyUrl; } else { this.completionsUrl = this.constructUrl(); } + let promptPrefix = (this.options.promptPrefix ?? '').trim(); + if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { + promptPrefix = `${promptPrefix ?? 
''}\n${this.options.artifactsPrompt}`.trim(); + } + this.options.promptPrefix = promptPrefix; + this.initializeClient(); return this; } @@ -336,7 +316,6 @@ class GoogleClient extends BaseClient { messages: [new HumanMessage(formatMessage({ message: latestMessage }))], }, ], - parameters: this.modelOptions, }; return { prompt: payload }; } @@ -352,23 +331,58 @@ class GoogleClient extends BaseClient { return { prompt: formattedMessages }; } - async buildMessages(messages = [], parentMessageId) { + /** + * @param {TMessage[]} [messages=[]] + * @param {string} [parentMessageId] + */ + async buildMessages(_messages = [], parentMessageId) { if (!this.isGenerativeModel && !this.project_id) { - throw new Error( - '[GoogleClient] a Service Account JSON Key is required for PaLM 2 and Codey models (Vertex AI)', - ); + throw new Error('[GoogleClient] PaLM 2 and Codey models are no longer supported.'); } - if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) { - return await this.buildGenerativeMessages(messages); + if (this.options.promptPrefix) { + const instructionsTokenCount = this.getTokenCount(this.options.promptPrefix); + + this.maxContextTokens = this.maxContextTokens - instructionsTokenCount; + if (this.maxContextTokens < 0) { + const info = `${instructionsTokenCount} / ${this.maxContextTokens}`; + const errorMessage = `{ "type": "${ErrorTypes.INPUT_LENGTH}", "info": "${info}" }`; + logger.warn(`Instructions token count exceeds max context (${info}).`); + throw new Error(errorMessage); + } } - if (this.options.attachments && this.isGenerativeModel) { - return this.buildVisionMessages(messages, parentMessageId); + for (let i = 0; i < _messages.length; i++) { + const message = _messages[i]; + if (!message.tokenCount) { + _messages[i].tokenCount = this.getTokenCountForMessage({ + role: message.isCreatedByUser ? 'user' : 'assistant', + content: message.content ?? message.text, + }); + } + } + + const { + payload: messages, + tokenCountMap, + promptTokens, + } = await this.handleContextStrategy({ + orderedMessages: _messages, + formattedMessages: _messages, + }); + + if (!this.project_id && !EXCLUDED_GENAI_MODELS.test(this.modelOptions.model)) { + const result = await this.buildGenerativeMessages(messages); + result.tokenCountMap = tokenCountMap; + result.promptTokens = promptTokens; + return result; } - if (this.isTextModel) { - return this.buildMessagesPrompt(messages, parentMessageId); + if (this.options.attachments && this.isGenerativeModel) { + const result = this.buildVisionMessages(messages, parentMessageId); + result.tokenCountMap = tokenCountMap; + result.promptTokens = promptTokens; + return result; } let payload = { @@ -380,25 +394,14 @@ class GoogleClient extends BaseClient { .map((message) => formatMessage({ message, langChain: true })), }, ], - parameters: this.modelOptions, }; - let promptPrefix = (this.options.promptPrefix ?? '').trim(); - if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { - promptPrefix = `${promptPrefix ?? 
''}\n${this.options.artifactsPrompt}`.trim(); - } - - if (promptPrefix) { - payload.instances[0].context = promptPrefix; - } - - if (this.options.examples.length > 0) { - payload.instances[0].examples = this.options.examples; + if (this.options.promptPrefix) { + payload.instances[0].context = this.options.promptPrefix; } logger.debug('[GoogleClient] buildMessages', payload); - - return { prompt: payload }; + return { prompt: payload, tokenCountMap, promptTokens }; } async buildMessagesPrompt(messages, parentMessageId) { @@ -412,10 +415,7 @@ class GoogleClient extends BaseClient { parentMessageId, }); - const formattedMessages = orderedMessages.map((message) => ({ - author: message.isCreatedByUser ? this.userLabel : this.modelLabel, - content: message?.content ?? message.text, - })); + const formattedMessages = orderedMessages.map(this.formatMessages()); let lastAuthor = ''; let groupedMessages = []; @@ -444,16 +444,6 @@ class GoogleClient extends BaseClient { } let promptPrefix = (this.options.promptPrefix ?? '').trim(); - if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { - promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim(); - } - if (promptPrefix) { - // If the prompt prefix doesn't end with the end token, add it. - if (!promptPrefix.endsWith(`${this.endToken}`)) { - promptPrefix = `${promptPrefix.trim()}${this.endToken}\n\n`; - } - promptPrefix = `\nContext:\n${promptPrefix}`; - } if (identityPrefix) { promptPrefix = `${identityPrefix}${promptPrefix}`; @@ -490,7 +480,7 @@ class GoogleClient extends BaseClient { isCreatedByUser || !isEdited ? `\n\n${message.author}:` : `${promptPrefix}\n\n${message.author}:`; - const messageString = `${messagePrefix}\n${message.content}${this.endToken}\n`; + const messageString = `${messagePrefix}\n${message.content}\n`; let newPromptBody = `${messageString}${promptBody}`; context.unshift(message); @@ -556,34 +546,6 @@ class GoogleClient extends BaseClient { return { prompt, context }; } - async _getCompletion(payload, abortController = null) { - if (!abortController) { - abortController = new AbortController(); - } - const { debug } = this.options; - const url = this.completionsUrl; - if (debug) { - logger.debug('GoogleClient _getCompletion', { url, payload }); - } - const opts = { - method: 'POST', - agent: new Agent({ - bodyTimeout: 0, - headersTimeout: 0, - }), - signal: abortController.signal, - }; - - if (this.options.proxy) { - opts.agent = new ProxyAgent(this.options.proxy); - } - - const client = await this.getClient(); - const res = await client.request({ url, method: 'POST', data: payload }); - logger.debug('GoogleClient _getCompletion', { res }); - return res.data; - } - createLLM(clientOptions) { const model = clientOptions.modelName ?? 
clientOptions.model; clientOptions.location = loc; @@ -602,33 +564,20 @@ class GoogleClient extends BaseClient { } } - if (this.project_id && this.isTextModel) { - logger.debug('Creating Google VertexAI client'); - return new GoogleVertexAI(clientOptions); - } else if (this.project_id && this.isChatModel) { - logger.debug('Creating Chat Google VertexAI client'); - return new ChatGoogleVertexAI(clientOptions); - } else if (this.project_id) { + if (this.project_id != null) { logger.debug('Creating VertexAI client'); return new ChatVertexAI(clientOptions); } else if (!EXCLUDED_GENAI_MODELS.test(model)) { logger.debug('Creating GenAI client'); - return new GenAI(this.apiKey).getGenerativeModel({ ...clientOptions, model }, requestOptions); + return new GenAI(this.apiKey).getGenerativeModel({ model }, requestOptions); } logger.debug('Creating Chat Google Generative AI client'); return new ChatGoogleGenerativeAI({ ...clientOptions, apiKey: this.apiKey }); } - async getCompletion(_payload, options = {}) { - const { parameters, instances } = _payload; - const { onProgress, abortController } = options; - const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; - const { messages: _messages, context, examples: _examples } = instances?.[0] ?? {}; - - let examples; - - let clientOptions = { ...parameters, maxRetries: 2 }; + initializeClient() { + let clientOptions = { ...this.modelOptions, maxRetries: 2 }; if (this.project_id) { clientOptions['authOptions'] = { @@ -639,53 +588,34 @@ class GoogleClient extends BaseClient { }; } - if (!parameters) { - clientOptions = { ...clientOptions, ...this.modelOptions }; - } - if (this.isGenerativeModel && !this.project_id) { clientOptions.modelName = clientOptions.model; delete clientOptions.model; } - if (_examples && _examples.length) { - examples = _examples - .map((ex) => { - const { input, output } = ex; - if (!input || !output) { - return undefined; - } - return { - input: new HumanMessage(input.content), - output: new AIMessage(output.content), - }; - }) - .filter((ex) => ex); - - clientOptions.examples = examples; - } + this.client = this.createLLM(clientOptions); + return this.client; + } - const model = this.createLLM(clientOptions); + async getCompletion(_payload, options = {}) { + const safetySettings = this.getSafetySettings(); + const { onProgress, abortController } = options; + const streamRate = this.options.streamRate ?? Constants.DEFAULT_STREAM_RATE; + const modelName = this.modelOptions.modelName ?? this.modelOptions.model ?? ''; let reply = ''; - const messages = this.isTextModel ? _payload.trim() : _messages; - - if (!this.isVisionModel && context && messages?.length > 0) { - messages.unshift(new SystemMessage(context)); - } - const modelName = clientOptions.modelName ?? clientOptions.model ?? ''; if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) { - const client = model; + /** @type {GenAI} */ + const client = this.client; + /** @type {GenerateContentRequest} */ const requestOptions = { + safetySettings, contents: _payload, + generationConfig: googleGenConfigSchema.parse(this.modelOptions), }; - let promptPrefix = (this.options.promptPrefix ?? '').trim(); - if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { - promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim(); - } - + const promptPrefix = (this.options.promptPrefix ?? 
'').trim(); if (promptPrefix.length) { requestOptions.systemInstruction = { parts: [ @@ -696,11 +626,14 @@ class GoogleClient extends BaseClient { }; } - requestOptions.safetySettings = _payload.safetySettings; - const delay = modelName.includes('flash') ? 8 : 15; + /** @type {GenAIUsageMetadata} */ + let usageMetadata; const result = await client.generateContentStream(requestOptions); for await (const chunk of result.stream) { + usageMetadata = !usageMetadata + ? chunk?.usageMetadata + : Object.assign(usageMetadata, chunk?.usageMetadata); const chunkText = chunk.text(); await this.generateTextStream(chunkText, onProgress, { delay, @@ -708,12 +641,29 @@ class GoogleClient extends BaseClient { reply += chunkText; await sleep(streamRate); } + + if (usageMetadata) { + this.usage = { + input_tokens: usageMetadata.promptTokenCount, + output_tokens: usageMetadata.candidatesTokenCount, + }; + } return reply; } - const stream = await model.stream(messages, { + const { instances } = _payload; + const { messages: messages, context } = instances?.[0] ?? {}; + + if (!this.isVisionModel && context && messages?.length > 0) { + messages.unshift(new SystemMessage(context)); + } + + /** @type {import('@langchain/core/messages').AIMessageChunk['usage_metadata']} */ + let usageMetadata; + const stream = await this.client.stream(messages, { signal: abortController.signal, - safetySettings: _payload.safetySettings, + streamUsage: true, + safetySettings, }); let delay = this.options.streamRate || 8; @@ -728,6 +678,9 @@ class GoogleClient extends BaseClient { } for await (const chunk of stream) { + usageMetadata = !usageMetadata + ? chunk?.usage_metadata + : concat(usageMetadata, chunk?.usage_metadata); const chunkText = chunk?.content ?? chunk; await this.generateTextStream(chunkText, onProgress, { delay, @@ -735,88 +688,114 @@ class GoogleClient extends BaseClient { reply += chunkText; } + if (usageMetadata) { + this.usage = usageMetadata; + } return reply; } /** - * Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming + * Get stream usage as returned by this client's API response. + * @returns {UsageMetadata} The stream usage object. */ - async titleChatCompletion(_payload, options = {}) { - const { abortController } = options; - const { parameters, instances } = _payload; - const { messages: _messages, examples: _examples } = instances?.[0] ?? {}; - - let clientOptions = { ...parameters, maxRetries: 2 }; + getStreamUsage() { + return this.usage; + } - logger.debug('Initialized title client options'); + /** + * Calculates the correct token count for the current user message based on the token count map and API usage. + * Edge case: If the calculation results in a negative value, it returns the original estimate. + * If revisiting a conversation with a chat history entirely composed of token estimates, + * the cumulative token count going forward should become more accurate as the conversation progresses. + * @param {Object} params - The parameters for the calculation. + * @param {Record} params.tokenCountMap - A map of message IDs to their token counts. + * @param {string} params.currentMessageId - The ID of the current message to calculate. + * @param {UsageMetadata} params.usage - The usage object returned by the API. + * @returns {number} The correct token count for the current user message. 
+ */ + calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) { + const originalEstimate = tokenCountMap[currentMessageId] || 0; - if (this.project_id) { - clientOptions['authOptions'] = { - credentials: { - ...this.serviceKey, - }, - projectId: this.project_id, - }; + if (!usage || typeof usage.input_tokens !== 'number') { + return originalEstimate; } - if (!parameters) { - clientOptions = { ...clientOptions, ...this.modelOptions }; - } + tokenCountMap[currentMessageId] = 0; + const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => { + const numCount = Number(count); + return sum + (isNaN(numCount) ? 0 : numCount); + }, 0); + const totalInputTokens = usage.input_tokens ?? 0; + const currentMessageTokens = totalInputTokens - totalTokensFromMap; + return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate; + } - if (this.isGenerativeModel && !this.project_id) { - clientOptions.modelName = clientOptions.model; - delete clientOptions.model; - } + /** + * @param {object} params + * @param {number} params.promptTokens + * @param {number} params.completionTokens + * @param {UsageMetadata} [params.usage] + * @param {string} [params.model] + * @param {string} [params.context='message'] + * @returns {Promise} + */ + async recordTokenUsage({ promptTokens, completionTokens, model, context = 'message' }) { + await spendTokens( + { + context, + user: this.user ?? this.options.req?.user?.id, + conversationId: this.conversationId, + model: model ?? this.modelOptions.model, + endpointTokenConfig: this.options.endpointTokenConfig, + }, + { promptTokens, completionTokens }, + ); + } - const model = this.createLLM(clientOptions); + /** + * Stripped-down logic for generating a title. This uses the non-streaming APIs, since the user does not see titles streaming + */ + async titleChatCompletion(_payload, options = {}) { + const { abortController } = options; + const safetySettings = this.getSafetySettings(); let reply = ''; - const messages = this.isTextModel ? _payload.trim() : _messages; - const modelName = clientOptions.modelName ?? clientOptions.model ?? ''; - if (!EXCLUDED_GENAI_MODELS.test(modelName) && !this.project_id) { + const model = this.modelOptions.modelName ?? this.modelOptions.model ?? ''; + if (!EXCLUDED_GENAI_MODELS.test(model) && !this.project_id) { logger.debug('Identified titling model as GenAI version'); /** @type {GenerativeModel} */ - const client = model; + const client = this.client; const requestOptions = { contents: _payload, + safetySettings, + generationConfig: { + temperature: 0.5, + }, }; - let promptPrefix = (this.options.promptPrefix ?? '').trim(); - if (typeof this.options.artifactsPrompt === 'string' && this.options.artifactsPrompt) { - promptPrefix = `${promptPrefix ?? ''}\n${this.options.artifactsPrompt}`.trim(); - } - - if (this.options?.promptPrefix?.length) { - requestOptions.systemInstruction = { - parts: [ - { - text: promptPrefix, - }, - ], - }; - } - - const safetySettings = _payload.safetySettings; - requestOptions.safetySettings = safetySettings; - const result = await client.generateContent(requestOptions); - reply = result.response?.text(); - return reply; } else { - logger.debug('Beginning titling'); - const safetySettings = _payload.safetySettings; - - const titleResponse = await model.invoke(messages, { + const { instances } = _payload; + const { messages } = instances?.[0] ?? 
{}; + const titleResponse = await this.client.invoke(messages, { signal: abortController.signal, timeout: 7000, - safetySettings: safetySettings, + safetySettings, }); + if (titleResponse.usage_metadata) { + await this.recordTokenUsage({ + model, + promptTokens: titleResponse.usage_metadata.input_tokens, + completionTokens: titleResponse.usage_metadata.output_tokens, + context: 'title', + }); + } + reply = titleResponse.content; - // TODO: RECORD TOKEN USAGE return reply; } } @@ -840,15 +819,19 @@ class GoogleClient extends BaseClient { }, ]); + const model = process.env.GOOGLE_TITLE_MODEL ?? this.modelOptions.model; + const availableModels = this.options.modelsConfig?.[EModelEndpoint.google]; + this.isVisionModel = validateVisionModel({ model, availableModels }); + if (this.isVisionModel) { logger.warn( `Current vision model does not support titling without an attachment; falling back to default model ${settings.model.default}`, ); - - payload.parameters = { ...payload.parameters, model: settings.model.default }; + this.modelOptions.model = settings.model.default; } try { + this.initializeClient(); title = await this.titleChatCompletion(payload, { abortController: new AbortController(), onProgress: () => {}, @@ -865,6 +848,7 @@ class GoogleClient extends BaseClient { endpointType: null, artifacts: this.options.artifacts, promptPrefix: this.options.promptPrefix, + maxContextTokens: this.options.maxContextTokens, modelLabel: this.options.modelLabel, iconURL: this.options.iconURL, greeting: this.options.greeting, @@ -878,8 +862,6 @@ class GoogleClient extends BaseClient { } async sendCompletion(payload, opts = {}) { - payload.safetySettings = this.getSafetySettings(); - let reply = ''; reply = await this.getCompletion(payload, opts); return reply.trim(); @@ -931,6 +913,22 @@ class GoogleClient extends BaseClient { return 'cl100k_base'; } + async getVertexTokenCount(text) { + /** @type {ChatVertexAI} */ + const client = this.client ?? this.initializeClient(); + const connection = client.connection; + const gAuthClient = connection.client; + const tokenEndpoint = `https://${connection._endpoint}/${connection.apiVersion}/projects/${this.project_id}/locations/${connection._location}/publishers/google/models/${connection.model}/:countTokens`; + const result = await gAuthClient.request({ + url: tokenEndpoint, + method: 'POST', + data: { + contents: [{ role: 'user', parts: [{ text }] }], + }, + }); + return result; + } + /** * Returns the token count of a given text. It also checks and resets the tokenizers if necessary. * @param {string} text - The text to get the token count for. 
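A central piece of the GoogleClient changes above is the new usage accounting: `getCompletion` captures `usageMetadata` from the stream, `getStreamUsage` exposes it, and `calculateCurrentTokenCount` reconciles the locally estimated per-message counts with the `input_tokens` the API actually reports, falling back to the original estimate whenever the difference comes out non-positive. A minimal runnable sketch of that reconciliation follows; the message IDs and token values are made-up illustrations, not taken from the patch:

```js
// Illustrative sketch of the reconciliation done by calculateCurrentTokenCount.
// The per-message estimates and usage figures below are hypothetical.
const tokenCountMap = { 'msg-1': 120, 'msg-2': 80, 'msg-3': 45 }; // 'msg-3' is the current user message
const usage = { input_tokens: 260, output_tokens: 512 }; // as reported by the API after the request

function calculateCurrentTokenCount({ tokenCountMap, currentMessageId, usage }) {
  const originalEstimate = tokenCountMap[currentMessageId] || 0;
  if (!usage || typeof usage.input_tokens !== 'number') {
    return originalEstimate;
  }
  // Zero out the current message, sum the remaining estimates, and attribute
  // the rest of the reported prompt tokens to the current message.
  tokenCountMap[currentMessageId] = 0;
  const totalTokensFromMap = Object.values(tokenCountMap).reduce((sum, count) => {
    const numCount = Number(count);
    return sum + (isNaN(numCount) ? 0 : numCount);
  }, 0);
  const totalInputTokens = usage.input_tokens ?? 0;
  const currentMessageTokens = totalInputTokens - totalTokensFromMap;
  return currentMessageTokens > 0 ? currentMessageTokens : originalEstimate;
}

// 260 reported - (120 + 80) estimated for earlier messages = 60 tokens for 'msg-3'.
console.log(calculateCurrentTokenCount({ tokenCountMap, currentMessageId: 'msg-3', usage })); // 60
```

The reconciled prompt tokens and the reported output tokens then flow into `recordTokenUsage`, which bills them through `spendTokens` for both regular messages and the titling path.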
diff --git a/api/package.json b/api/package.json index 945fe5e1bc1..ca5b2d59de6 100644 --- a/api/package.json +++ b/api/package.json @@ -103,7 +103,6 @@ "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.2.3", "winston": "^3.11.0", "winston-daily-rotate-file": "^4.7.1", "zod": "^3.22.4" diff --git a/api/server/services/Endpoints/google/build.js b/api/server/services/Endpoints/google/build.js index 45f11940ed9..11b048694f5 100644 --- a/api/server/services/Endpoints/google/build.js +++ b/api/server/services/Endpoints/google/build.js @@ -11,6 +11,7 @@ const buildOptions = (endpoint, parsedBody) => { greeting, spec, artifacts, + maxContextTokens, ...modelOptions } = parsedBody; const endpointOption = removeNullishValues({ @@ -22,6 +23,7 @@ const buildOptions = (endpoint, parsedBody) => { iconURL, greeting, spec, + maxContextTokens, modelOptions, }); diff --git a/api/server/services/Endpoints/google/llm.js b/api/server/services/Endpoints/google/llm.js index f19d0539c72..ae5e268ac3a 100644 --- a/api/server/services/Endpoints/google/llm.js +++ b/api/server/services/Endpoints/google/llm.js @@ -1,9 +1,6 @@ const { Providers } = require('@librechat/agents'); const { AuthKeys } = require('librechat-data-provider'); -// Example internal constant from your code -const EXCLUDED_GENAI_MODELS = /gemini-(?:1\.0|1-0|pro)/; - /** * * @param {boolean} isGemini2 @@ -89,22 +86,12 @@ function getLLMConfig(credentials, options = {}) { /** Used only for Safety Settings */ const isGemini2 = llmConfig.model.includes('gemini-2.0') && !llmConfig.model.includes('thinking'); - const isGenerativeModel = llmConfig.model.includes('gemini'); - const isChatModel = !isGenerativeModel && llmConfig.model.includes('chat'); - const isTextModel = !isGenerativeModel && !isChatModel && /code|text/.test(llmConfig.model); - llmConfig.safetySettings = getSafetySettings(isGemini2); let provider; - if (project_id && isTextModel) { - provider = Providers.VERTEXAI; - } else if (project_id && isChatModel) { + if (project_id) { provider = Providers.VERTEXAI; - } else if (project_id) { - provider = Providers.VERTEXAI; - } else if (!EXCLUDED_GENAI_MODELS.test(llmConfig.model)) { - provider = Providers.GOOGLE; } else { provider = Providers.GOOGLE; } diff --git a/api/typedefs.js b/api/typedefs.js index 186c0e4a528..b1960f4cb60 100644 --- a/api/typedefs.js +++ b/api/typedefs.js @@ -155,6 +155,18 @@ * @memberof typedefs */ +/** + * @exports GenerateContentRequest + * @typedef {import('@google/generative-ai').GenerateContentRequest} GenerateContentRequest + * @memberof typedefs + */ + +/** + * @exports GenAIUsageMetadata + * @typedef {import('@google/generative-ai').UsageMetadata} GenAIUsageMetadata + * @memberof typedefs + */ + /** * @exports AssistantStreamEvent * @typedef {import('openai').default.Beta.AssistantStreamEvent} AssistantStreamEvent diff --git a/client/src/components/Endpoints/MessageEndpointIcon.tsx b/client/src/components/Endpoints/MessageEndpointIcon.tsx index 5c6b35bf7f3..a4d4488579a 100644 --- a/client/src/components/Endpoints/MessageEndpointIcon.tsx +++ b/client/src/components/Endpoints/MessageEndpointIcon.tsx @@ -34,7 +34,10 @@ function getOpenAIColor(_model: string | null | undefined) { function getGoogleIcon(model: string | null | undefined, size: number) { if (model?.toLowerCase().includes('code') === true) { return ; - } else if (model?.toLowerCase().includes('gemini') === true) { + } else if ( + model?.toLowerCase().includes('gemini') === true || + 
model?.toLowerCase().includes('learnlm') === true + ) { return ; } else { return ; @@ -44,7 +47,10 @@ function getGoogleIcon(model: string | null | undefined, size: number) { function getGoogleModelName(model: string | null | undefined) { if (model?.toLowerCase().includes('code') === true) { return 'Codey'; - } else if (model?.toLowerCase().includes('gemini') === true) { + } else if ( + model?.toLowerCase().includes('gemini') === true || + model?.toLowerCase().includes('learnlm') === true + ) { return 'Gemini'; } else { return 'PaLM2'; diff --git a/client/src/components/SidePanel/Parameters/DynamicInput.tsx b/client/src/components/SidePanel/Parameters/DynamicInput.tsx index 36d7fd179aa..cba8dd66040 100644 --- a/client/src/components/SidePanel/Parameters/DynamicInput.tsx +++ b/client/src/components/SidePanel/Parameters/DynamicInput.tsx @@ -48,12 +48,15 @@ function DynamicInput({ const handleInputChange = (e: React.ChangeEvent) => { const value = e.target.value; - if (type === 'number') { - if (!isNaN(Number(value))) { - setInputValue(e, true); - } - } else { + if (type !== 'number') { setInputValue(e); + return; + } + + if (value === '') { + setInputValue(e); + } else if (!isNaN(Number(value))) { + setInputValue(e, true); } }; diff --git a/client/src/hooks/Input/useTextarea.ts b/client/src/hooks/Input/useTextarea.ts index 3a74a5ac5e9..ab881d7417a 100644 --- a/client/src/hooks/Input/useTextarea.ts +++ b/client/src/hooks/Input/useTextarea.ts @@ -191,6 +191,9 @@ export default function useTextarea({ const isNonShiftEnter = e.key === 'Enter' && !e.shiftKey; const isCtrlEnter = e.key === 'Enter' && (e.ctrlKey || e.metaKey); + // NOTE: isComposing and e.key behave differently in Safari compared to other browsers, forcing us to use e.keyCode instead + const isComposingInput = isComposing.current || e.key === 'Process' || e.keyCode === 229; + if (isNonShiftEnter && filesLoading) { e.preventDefault(); } @@ -204,7 +207,7 @@ export default function useTextarea({ !enterToSend && !isCtrlEnter && textAreaRef.current && - !isComposing.current + !isComposingInput ) { e.preventDefault(); insertTextAtCursor(textAreaRef.current, '\n'); @@ -212,7 +215,7 @@ export default function useTextarea({ return; } - if ((isNonShiftEnter || isCtrlEnter) && !isComposing.current) { + if ((isNonShiftEnter || isCtrlEnter) && !isComposingInput) { const globalAudio = document.getElementById(globalAudioId) as HTMLAudioElement | undefined; if (globalAudio) { console.log('Unmuting global audio'); diff --git a/package-lock.json b/package-lock.json index 5b25204141e..5f670c1c77d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -112,7 +112,6 @@ "tiktoken": "^1.0.15", "traverse": "^0.6.7", "ua-parser-js": "^1.0.36", - "undici": "^7.2.3", "winston": "^3.11.0", "winston-daily-rotate-file": "^4.7.1", "zod": "^3.22.4" @@ -1056,66 +1055,6 @@ "resolved": "https://registry.npmjs.org/react-is/-/react-is-17.0.2.tgz", "integrity": "sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==" }, - "client/node_modules/vite": { - "version": "5.4.14", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.14.tgz", - "integrity": "sha512-EK5cY7Q1D8JNhSaPKVK4pwBFvaTmZxEnoKXLG/U9gmdDcihQGNzFlgIvaxezFR4glP1LsuiedwMBqCXH3wZccA==", - "dev": true, - "license": "MIT", - "dependencies": { - "esbuild": "^0.21.3", - "postcss": "^8.4.43", - "rollup": "^4.20.0" - }, - "bin": { - "vite": "bin/vite.js" - }, - "engines": { - "node": "^18.0.0 || >=20.0.0" - }, - "funding": { - "url": 
"https://github.com/vitejs/vite?sponsor=1" - }, - "optionalDependencies": { - "fsevents": "~2.3.3" - }, - "peerDependencies": { - "@types/node": "^18.0.0 || >=20.0.0", - "less": "*", - "lightningcss": "^1.21.0", - "sass": "*", - "sass-embedded": "*", - "stylus": "*", - "sugarss": "*", - "terser": "^5.4.0" - }, - "peerDependenciesMeta": { - "@types/node": { - "optional": true - }, - "less": { - "optional": true - }, - "lightningcss": { - "optional": true - }, - "sass": { - "optional": true - }, - "sass-embedded": { - "optional": true - }, - "stylus": { - "optional": true - }, - "sugarss": { - "optional": true - }, - "terser": { - "optional": true - } - } - }, "node_modules/@aashutoshrathi/word-wrap": { "version": "1.2.6", "resolved": "https://registry.npmjs.org/@aashutoshrathi/word-wrap/-/word-wrap-1.2.6.tgz", @@ -33126,15 +33065,6 @@ "integrity": "sha512-WxONCrssBM8TSPRqN5EmsjVrsv4A8X12J4ArBiiayv3DyyG3ZlIg6yysuuSYdZsVz3TKcTg2fd//Ujd4CHV1iA==", "dev": true }, - "node_modules/undici": { - "version": "7.2.3", - "resolved": "https://registry.npmjs.org/undici/-/undici-7.2.3.tgz", - "integrity": "sha512-2oSLHaDalSt2/O/wHA9M+/ZPAOcU2yrSP/cdBYJ+YxZskiPYDSqHbysLSlD7gq3JMqOoJI5O31RVU3BxX/MnAA==", - "license": "MIT", - "engines": { - "node": ">=20.18.1" - } - }, "node_modules/undici-types": { "version": "5.26.5", "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-5.26.5.tgz", @@ -33752,11 +33682,11 @@ } }, "node_modules/vite": { - "version": "5.4.6", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.6.tgz", - "integrity": "sha512-IeL5f8OO5nylsgzd9tq4qD2QqI0k2CQLGrWD0rCN0EQJZpBK5vJAx0I+GDkMOXxQX/OfFHMuLIx6ddAxGX/k+Q==", + "version": "5.4.14", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.4.14.tgz", + "integrity": "sha512-EK5cY7Q1D8JNhSaPKVK4pwBFvaTmZxEnoKXLG/U9gmdDcihQGNzFlgIvaxezFR4glP1LsuiedwMBqCXH3wZccA==", "dev": true, - "peer": true, + "license": "MIT", "dependencies": { "esbuild": "^0.21.3", "postcss": "^8.4.43", @@ -35101,7 +35031,7 @@ }, "packages/data-provider": { "name": "librechat-data-provider", - "version": "0.7.694", + "version": "0.7.695", "license": "ISC", "dependencies": { "axios": "^1.7.7", diff --git a/packages/data-provider/package.json b/packages/data-provider/package.json index 224b286a261..107504b5c8d 100644 --- a/packages/data-provider/package.json +++ b/packages/data-provider/package.json @@ -1,6 +1,6 @@ { "name": "librechat-data-provider", - "version": "0.7.694", + "version": "0.7.695", "description": "data services for librechat apps", "main": "dist/index.js", "module": "dist/index.es.js", diff --git a/packages/data-provider/src/parsers.ts b/packages/data-provider/src/parsers.ts index 656b54c29e3..71b30449176 100644 --- a/packages/data-provider/src/parsers.ts +++ b/packages/data-provider/src/parsers.ts @@ -272,7 +272,7 @@ export const getResponseSender = (endpointOption: t.TEndpointOption): string => if (endpoint === EModelEndpoint.google) { if (modelLabel) { return modelLabel; - } else if (model && model.includes('gemini')) { + } else if (model && (model.includes('gemini') || model.includes('learnlm'))) { return 'Gemini'; } else if (model && model.includes('code')) { return 'Codey'; diff --git a/packages/data-provider/src/schemas.ts b/packages/data-provider/src/schemas.ts index b20691064a2..8e54e377c31 100644 --- a/packages/data-provider/src/schemas.ts +++ b/packages/data-provider/src/schemas.ts @@ -788,6 +788,25 @@ export const googleSchema = tConversationSchema maxContextTokens: undefined, })); +/** + * TODO: Map the following fields: + 
- presence_penalty -> presencePenalty + - frequency_penalty -> frequencyPenalty + - stop -> stopSequences + */ +export const googleGenConfigSchema = z + .object({ + maxOutputTokens: coerceNumber.optional(), + temperature: coerceNumber.optional(), + topP: coerceNumber.optional(), + topK: coerceNumber.optional(), + presencePenalty: coerceNumber.optional(), + frequencyPenalty: coerceNumber.optional(), + stopSequences: z.array(z.string()).optional(), + }) + .strip() + .optional(); + export const bingAISchema = tConversationSchema .pick({ jailbreak: true,