From 3fe7a1542b4d840aa1c23d5488654bbab5855bd6 Mon Sep 17 00:00:00 2001
From: FateUnix29
Date: Sun, 25 Aug 2024 13:16:32 -0700
Subject: [PATCH] Try and add support for new models.

---
 src/agent/prompter.js | 18 +++++++--
 src/models/groq.js    | 86 ++++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/src/agent/prompter.js b/src/agent/prompter.js
index d9db2f8..62e44ce 100644
--- a/src/agent/prompter.js
+++ b/src/agent/prompter.js
@@ -10,7 +10,9 @@ import { GPT } from '../models/gpt.js';
 import { Claude } from '../models/claude.js';
 import { ReplicateAPI } from '../models/replicate.js';
 import { Local } from '../models/local.js';
-import { Mixtral } from '../models/groq.js';
+import { Mixtral_Groq } from '../models/groq.js';
+import { LLama3_70b_Groq } from '../models/groq.js';
+import { Gemma2_9b_Groq } from '../models/groq.js';
 
 
 export class Prompter {
@@ -33,7 +35,11 @@ export class Prompter {
         else if (chat.model.includes('meta/') || chat.model.includes('mistralai/') || chat.model.includes('replicate/'))
             chat.api = 'replicate';
         else if (chat.model.includes('mixtral'))
-            chat.api = 'groq';
+            chat.api = 'groq_mixtral';
+        else if (chat.model.includes('llama3-70b'))
+            chat.api = 'groq_llama3_70b';
+        else if (chat.model.includes('gemma2-9b'))
+            chat.api = 'groq_gemma2_9b';
         else
             chat.api = 'ollama';
     }
@@ -50,8 +56,12 @@
             this.chat_model = new ReplicateAPI(chat.model, chat.url);
         else if (chat.api == 'ollama')
             this.chat_model = new Local(chat.model, chat.url);
-        else if (chat.api == 'groq')
-            this.chat_model = new Mixtral(chat.model, chat.url)
+        else if (chat.api == 'groq_mixtral')
+            this.chat_model = new Mixtral_Groq(chat.model, chat.url)
+        else if (chat.api == 'groq_llama3_70b')
+            this.chat_model = new LLama3_70b_Groq(chat.model, chat.url)
+        else if (chat.api == 'groq_gemma2_9b')
+            this.chat_model = new Gemma2_9b_Groq(chat.model, chat.url)
         else
             throw new Error('Unknown API:', api);
 
diff --git a/src/models/groq.js b/src/models/groq.js
index 1ca6ece..da4d515 100644
--- a/src/models/groq.js
+++ b/src/models/groq.js
@@ -1,7 +1,7 @@
 import Groq from 'groq-sdk'
 import { getKey } from '../utils/keys.js';
 
-export class Mixtral {
+export class Mixtral_Groq {
     constructor(model_name, url) {
         this.model_name = model_name;
         this.url = url;
@@ -38,6 +38,90 @@
         return res;
     }
 
+    async embed(text) {
+        console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text);
+    }
+}
+
+export class LLama3_70b_Groq {
+    constructor(model_name, url) {
+        this.model_name = model_name;
+        this.url = url;
+        this.groq = new Groq({ apiKey: getKey('GROQCLOUD_API_KEY')});
+    }
+
+    async sendRequest(turns, systemMessage, stop_seq=null) {
+        let messages = [{"role": "system", "content": systemMessage}].concat(turns);
+        let res = null;
+        try {
+            console.log("Awaiting Groq response...");
+            let completion = await this.groq.chat.completions.create({
+                "messages": messages,
+                "model": this.model_name || "llama3-70b-8192",
+                "temperature": 0.2,
+                "max_tokens": 8192, // maximum token limit
+                "top_p": 1,
+                "stream": true,
+                "stop": stop_seq // "***"
+            });
+
+            let temp_res = "";
+            for await (const chunk of completion) {
+                temp_res += chunk.choices[0]?.delta?.content || '';
+            }
+
+            res = temp_res;
+
+        }
+        catch(err) {
+            console.log(err);
+            res = "My brain just kinda stopped working. Try again.";
+        }
+        return res;
+    }
+
+    async embed(text) {
+        console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text);
+    }
+}
+
+export class Gemma2_9b_Groq {
+    constructor(model_name, url) {
+        this.model_name = model_name;
+        this.url = url;
+        this.groq = new Groq({ apiKey: getKey('GROQCLOUD_API_KEY')});
+    }
+
+    async sendRequest(turns, systemMessage, stop_seq=null) {
+        let messages = [{"role": "system", "content": systemMessage}].concat(turns);
+        let res = null;
+        try {
+            console.log("Awaiting Groq response...");
+            let completion = await this.groq.chat.completions.create({
+                "messages": messages,
+                "model": this.model_name || "gemma2-9b-it",
+                "temperature": 0.2,
+                "max_tokens": 8192, // maximum token limit
+                "top_p": 1,
+                "stream": true,
+                "stop": stop_seq // "***"
+            });
+
+            let temp_res = "";
+            for await (const chunk of completion) {
+                temp_res += chunk.choices[0]?.delta?.content || '';
+            }
+
+            res = temp_res;
+
+        }
+        catch(err) {
+            console.log(err);
+            res = "My brain just kinda stopped working. Try again.";
+        }
+        return res;
+    }
+
     async embed(text) {
         console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text);
     }