Update groq.js

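Added deepseek-r1 support: for models whose name contains "deepseek-r1", strip the `<think>…</think>` reasoning block from the response, and retry the request (up to 5 attempts) when only one of the two tags is present.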
This commit is contained in:
Sweaterdog 2025-01-27 16:13:09 -08:00 committed by GitHub
parent b907e77609
commit 2b3ca165e8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,13 +1,13 @@
import Groq from 'groq-sdk' import Groq from 'groq-sdk';
import { getKey } from '../utils/keys.js'; import { getKey } from '../utils/keys.js';
// Umbrella class for Mixtral, LLama, Gemma... // Umbrella class for Mixtral, LLama, Gemma...
export class GroqCloudAPI { export class GroqCloudAPI {
constructor(model_name, url, max_tokens=16384) { constructor(model_name, url, max_tokens=16384) {
this.model_name = model_name; this.model_name = model_name;
this.url = url; this.url = url;
this.max_tokens = max_tokens; this.max_tokens = max_tokens;
// ReplicateAPI theft :3 // ReplicateAPI theft :3
if (this.url) { if (this.url) {
console.warn("Groq Cloud has no implementation for custom URLs. Ignoring provided URL."); console.warn("Groq Cloud has no implementation for custom URLs. Ignoring provided URL.");
@ -16,33 +16,74 @@ export class GroqCloudAPI {
} }
async sendRequest(turns, systemMessage, stop_seq=null) { async sendRequest(turns, systemMessage, stop_seq=null) {
let messages = [{"role": "system", "content": systemMessage}].concat(turns); // We'll do up to 5 attempts for partial <think> mismatch if
let res = null; // the model name includes "deepseek-r1".
const maxAttempts = 5;
let attempt = 0;
let finalRes = null;
// Prepare the message array
let messages = [{ role: "system", content: systemMessage }].concat(turns);
while (attempt < maxAttempts) {
attempt++;
console.log(`Awaiting Groq response... (attempt: ${attempt}/${maxAttempts})`);
// Collect the streaming response
let temp_res = "";
try { try {
console.log("Awaiting Groq response..."); // Create the chat completion stream
let completion = await this.groq.chat.completions.create({ let completion = await this.groq.chat.completions.create({
"messages": messages, messages: messages,
"model": this.model_name || "mixtral-8x7b-32768", model: this.model_name || "mixtral-8x7b-32768",
"temperature": 0.2, temperature: 0.2,
"max_tokens": this.max_tokens, // maximum token limit, differs from model to model max_tokens: this.max_tokens,
"top_p": 1, top_p: 1,
"stream": true, stream: true,
"stop": stop_seq // "***" stop: stop_seq // e.g. "***"
}); });
let temp_res = ""; // Read each streamed chunk
for await (const chunk of completion) { for await (const chunk of completion) {
temp_res += chunk.choices[0]?.delta?.content || ''; temp_res += chunk.choices[0]?.delta?.content || '';
} }
} catch (err) {
res = temp_res; console.error("Error while streaming from Groq:", err);
temp_res = "My brain just kinda stopped working. Try again.";
// We won't retry partial mismatch if a genuine error occurred here
finalRes = temp_res;
break;
} }
catch(err) {
console.log(err); // If the model name includes "deepseek-r1", apply <think> logic
res = "My brain just kinda stopped working. Try again."; if (this.model_name && this.model_name.toLowerCase().includes("deepseek-r1")) {
const hasOpen = temp_res.includes("<think>");
const hasClose = temp_res.includes("</think>");
// If partial mismatch, retry
if ((hasOpen && !hasClose) || (!hasOpen && hasClose)) {
console.warn("Partial <think> block detected. Retrying...");
continue;
} }
return res;
// If both <think> and </think> appear, remove the entire block
if (hasOpen && hasClose) {
// Remove everything from <think> to </think>
temp_res = temp_res.replace(/<think>[\s\S]*?<\/think>/g, '').trim();
}
}
// We either do not have deepseek-r1 or we have a correct <think> scenario
finalRes = temp_res;
break;
}
// If, after max attempts, we never set finalRes (e.g., partial mismatch each time)
if (finalRes == null) {
console.warn("Could not obtain a valid or matched <think> response after max attempts.");
finalRes = "Response incomplete, please try again.";
}
return finalRes;
} }
async embed(text) { async embed(text) {