From ec6f4f709876e2c6d1e376b00885f020e30fa381 Mon Sep 17 00:00:00 2001 From: Sweaterdog Date: Tue, 28 Jan 2025 13:43:33 -0800 Subject: [PATCH] Update groq.js Fixed small error that would endlessly retry groqcloud response if Deepseek-R1 was chosen --- src/models/groq.js | 111 ++++++++++++++++++++++++++++----------------- 1 file changed, 69 insertions(+), 42 deletions(-) diff --git a/src/models/groq.js b/src/models/groq.js index 87ec163..08d2e1d 100644 --- a/src/models/groq.js +++ b/src/models/groq.js @@ -1,88 +1,115 @@ +// groq.js + import Groq from 'groq-sdk'; import { getKey } from '../utils/keys.js'; +import { log } from '../../logger.js'; -// Umbrella class for Mixtral, LLama, Gemma... +/** + * Umbrella class for Mixtral, LLama, Gemma... + */ export class GroqCloudAPI { - constructor(model_name, url, max_tokens=16384) { + constructor(model_name, url, max_tokens = 16384) { this.model_name = model_name; this.url = url; this.max_tokens = max_tokens; - // ReplicateAPI theft :3 + // Groq Cloud doesn't support custom URLs; warn if provided if (this.url) { console.warn("Groq Cloud has no implementation for custom URLs. Ignoring provided URL."); } + + // Initialize Groq SDK with the API key this.groq = new Groq({ apiKey: getKey('GROQCLOUD_API_KEY') }); } - async sendRequest(turns, systemMessage, stop_seq=null) { - // We'll do up to 5 attempts for partial mismatch if - // the model name includes "deepseek-r1". + /** + * Sends a chat completion request to the Groq Cloud endpoint. + * + * @param {Array} turns - An array of message objects, e.g., [{role: 'user', content: 'Hi'}]. + * @param {string} systemMessage - The system prompt or instruction. + * @param {string} stop_seq - A string that represents a stopping sequence, default '***'. + * @returns {Promise} - The content of the model's reply. + */ + async sendRequest(turns, systemMessage, stop_seq = '***') { + // Maximum number of attempts to handle partial tag mismatches 5 is a good value, I guess const maxAttempts = 5; let attempt = 0; let finalRes = null; - // Prepare the message array - let messages = [{ role: "system", content: systemMessage }].concat(turns); + // Prepare the input messages by prepending the system message + const messages = [{ role: 'system', content: systemMessage }, ...turns]; + console.log('Messages:', messages); while (attempt < maxAttempts) { attempt++; - console.log(`Awaiting Groq response... (attempt: ${attempt}/${maxAttempts})`); + console.log(`Awaiting Groq response... (model: ${this.model_name}, attempt: ${attempt})`); + + let res = null; - // Collect the streaming response - let temp_res = ""; try { - // Create the chat completion stream - let completion = await this.groq.chat.completions.create({ + // Create the chat completion request + const completion = await this.groq.chat.completions.create({ messages: messages, model: this.model_name || "mixtral-8x7b-32768", temperature: 0.2, - max_tokens: this.max_tokens, + max_tokens: this.max_tokens, top_p: 1, - stream: true, - stop: stop_seq // e.g. "***" + stream: false, + stop: stop_seq // "***" }); - // Read each streamed chunk - for await (const chunk of completion) { - temp_res += chunk.choices[0]?.delta?.content || ''; - } + // Extract the content from the response + res = completion?.choices?.[0]?.message?.content || ''; + console.log('Received response from Groq.'); } catch (err) { - console.error("Error while streaming from Groq:", err); - temp_res = "My brain just kinda stopped working. Try again."; - // We won't retry partial mismatch if a genuine error occurred here - finalRes = temp_res; - break; + // Handle context length exceeded by retrying with shorter context + if ( + err.message.toLowerCase().includes('context length') && + turns.length > 1 + ) { + console.log('Context length exceeded, trying again with a shorter context.'); + // Remove the earliest user turn and retry + return await this.sendRequest(turns.slice(1), systemMessage, stop_seq); + } else { + // Log other errors and return fallback message + console.log(err); + res = 'My brain disconnected, try again.'; + } } - // If the model name includes "deepseek-r1", apply logic + // If the model name includes "deepseek-r1", handle tags if (this.model_name && this.model_name.toLowerCase().includes("deepseek-r1")) { - const hasOpen = temp_res.includes(""); - const hasClose = temp_res.includes(""); + const hasOpenTag = res.includes(""); + const hasCloseTag = res.includes(""); - // If partial mismatch, retry - if ((hasOpen && !hasClose) || (!hasOpen && hasClose)) { - console.warn("Partial block detected. Retrying..."); - continue; + // Check for partial tag mismatches + if ((hasOpenTag && !hasCloseTag)) { + console.warn("Partial block detected. Re-generating Groq request..."); + // Retry the request by continuing the loop + continue; } - // If both and appear, remove the entire block - if (hasOpen && hasClose) { - // Remove everything from to - temp_res = temp_res.replace(/[\s\S]*?<\/think>/g, '').trim(); + // If is present but is not, prepend + if (hasCloseTag && !hasOpenTag) { + res = '' + res; } + // Trim the block from the response + res = res.replace(/[\s\S]*?<\/think>/g, '').trim(); } - // We either do not have deepseek-r1 or we have a correct scenario - finalRes = temp_res; - break; + // Assign the processed response and exit the loop + finalRes = res; + break; // Stop retrying } - // If, after max attempts, we never set finalRes (e.g., partial mismatch each time) + // If after all attempts, finalRes is still null, assign a fallback if (finalRes == null) { - console.warn("Could not obtain a valid or matched response after max attempts."); - finalRes = "Response incomplete, please try again."; + console.warn("Could not obtain a valid block or normal response after max attempts."); + finalRes = 'Response incomplete, please try again.'; } + + finalRes = finalRes.replace(/<\|separator\|>/g, '*no response*'); + return finalRes; }