diff --git a/src/models/groq.js b/src/models/groq.js index e17f13d..87ec163 100644 --- a/src/models/groq.js +++ b/src/models/groq.js @@ -1,13 +1,13 @@ -import Groq from 'groq-sdk' +import Groq from 'groq-sdk'; import { getKey } from '../utils/keys.js'; - // Umbrella class for Mixtral, LLama, Gemma... export class GroqCloudAPI { constructor(model_name, url, max_tokens=16384) { this.model_name = model_name; this.url = url; this.max_tokens = max_tokens; + // ReplicateAPI theft :3 if (this.url) { console.warn("Groq Cloud has no implementation for custom URLs. Ignoring provided URL."); @@ -16,36 +16,77 @@ export class GroqCloudAPI { } async sendRequest(turns, systemMessage, stop_seq=null) { - let messages = [{"role": "system", "content": systemMessage}].concat(turns); - let res = null; - try { - console.log("Awaiting Groq response..."); - let completion = await this.groq.chat.completions.create({ - "messages": messages, - "model": this.model_name || "mixtral-8x7b-32768", - "temperature": 0.2, - "max_tokens": this.max_tokens, // maximum token limit, differs from model to model - "top_p": 1, - "stream": true, - "stop": stop_seq // "***" - }); + // We'll do up to 5 attempts for partial mismatch if + // the model name includes "deepseek-r1". + const maxAttempts = 5; + let attempt = 0; + let finalRes = null; + // Prepare the message array + let messages = [{ role: "system", content: systemMessage }].concat(turns); + + while (attempt < maxAttempts) { + attempt++; + console.log(`Awaiting Groq response... (attempt: ${attempt}/${maxAttempts})`); + + // Collect the streaming response let temp_res = ""; - for await (const chunk of completion) { - temp_res += chunk.choices[0]?.delta?.content || ''; + try { + // Create the chat completion stream + let completion = await this.groq.chat.completions.create({ + messages: messages, + model: this.model_name || "mixtral-8x7b-32768", + temperature: 0.2, + max_tokens: this.max_tokens, + top_p: 1, + stream: true, + stop: stop_seq // e.g. "***" + }); + + // Read each streamed chunk + for await (const chunk of completion) { + temp_res += chunk.choices[0]?.delta?.content || ''; + } + } catch (err) { + console.error("Error while streaming from Groq:", err); + temp_res = "My brain just kinda stopped working. Try again."; + // We won't retry partial mismatch if a genuine error occurred here + finalRes = temp_res; + break; } - res = temp_res; + // If the model name includes "deepseek-r1", apply logic + if (this.model_name && this.model_name.toLowerCase().includes("deepseek-r1")) { + const hasOpen = temp_res.includes(""); + const hasClose = temp_res.includes(""); + // If partial mismatch, retry + if ((hasOpen && !hasClose) || (!hasOpen && hasClose)) { + console.warn("Partial block detected. Retrying..."); + continue; + } + + // If both and appear, remove the entire block + if (hasOpen && hasClose) { + // Remove everything from to + temp_res = temp_res.replace(/[\s\S]*?<\/think>/g, '').trim(); + } + } + + // We either do not have deepseek-r1 or we have a correct scenario + finalRes = temp_res; + break; } - catch(err) { - console.log(err); - res = "My brain just kinda stopped working. Try again."; + + // If, after max attempts, we never set finalRes (e.g., partial mismatch each time) + if (finalRes == null) { + console.warn("Could not obtain a valid or matched response after max attempts."); + finalRes = "Response incomplete, please try again."; } - return res; + return finalRes; } async embed(text) { - console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text); + console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text); } -} \ No newline at end of file +}