Try to add support for new models.

This commit is contained in:
FateUnix29 2024-08-25 13:16:32 -07:00
parent 99a7158382
commit 3fe7a1542b
2 changed files with 99 additions and 5 deletions

View file

@ -10,7 +10,9 @@ import { GPT } from '../models/gpt.js';
import { Claude } from '../models/claude.js';
import { ReplicateAPI } from '../models/replicate.js';
import { Local } from '../models/local.js';
import { Mixtral } from '../models/groq.js';
import { Mixtral_Groq } from '../models/groq.js';
import { LLama3_70b_Groq } from '../models/groq.js';
import { Gemma2_9b_Groq } from '../models/groq.js';
export class Prompter {
@ -33,7 +35,11 @@ export class Prompter {
else if (chat.model.includes('meta/') || chat.model.includes('mistralai/') || chat.model.includes('replicate/'))
chat.api = 'replicate';
else if (chat.model.includes('mixtral'))
chat.api = 'groq';
chat.api = 'groq_mixtral';
else if (chat.model.includes('llama3-70b'))
chat.api = 'groq_llama3_70b';
else if (chat.model.includes('gemma2-9b'))
chat.api = 'groq_gemma2_9b';
else
chat.api = 'ollama';
}
@ -50,8 +56,12 @@ export class Prompter {
this.chat_model = new ReplicateAPI(chat.model, chat.url);
else if (chat.api == 'ollama')
this.chat_model = new Local(chat.model, chat.url);
else if (chat.api == 'groq')
this.chat_model = new Mixtral(chat.model, chat.url)
else if (chat.api == 'groq_mixtral')
this.chat_model = new Mixtral_Groq(chat.model, chat.url)
else if (chat.api == 'groq_llama3_70b')
this.chat_model = new LLama3_70b_Groq(chat.model, chat.url)
else if (chat.api == 'groq_gemma2_9b')
this.chat_model = new Gemma2_9b_Groq(chat.model, chat.url)
else
throw new Error('Unknown API:', api);

View file

@ -1,7 +1,7 @@
import Groq from 'groq-sdk'
import { getKey } from '../utils/keys.js';
export class Mixtral {
export class Mixtral_Groq {
constructor(model_name, url) {
this.model_name = model_name;
this.url = url;
@ -38,6 +38,90 @@ export class Mixtral {
return res;
}
async embed(text) {
console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text);
}
}
/**
 * Chat-model wrapper for Groq-hosted LLaMA 3 70B.
 * Mirrors the interface of the other model wrappers in this file:
 * constructor(model_name, url), sendRequest(...), embed(...).
 */
export class LLama3_70b_Groq {
    /**
     * @param {string} model_name - Groq model id; falls back to "llama3-70b-8192" when falsy.
     * @param {string} url - Stored for interface parity with the other wrappers.
     */
    constructor(model_name, url) {
        this.model_name = model_name;
        this.url = url;
        // NOTE(review): `url` is stored but never passed to the SDK — confirm whether
        // a custom endpoint should be supported here.
        this.groq = new Groq({ apiKey: getKey('GROQCLOUD_API_KEY') });
    }

    /**
     * Send a chat request and stream the completion into a single string.
     * @param {Array<{role: string, content: string}>} turns - Prior conversation turns.
     * @param {string} systemMessage - Prepended as the system-role message.
     * @param {string|null} [stop_seq=null] - Optional stop sequence forwarded to the API.
     * @returns {Promise<string>} Full response text, or a fallback sentence on any error.
     */
    async sendRequest(turns, systemMessage, stop_seq = null) {
        const messages = [{ role: 'system', content: systemMessage }].concat(turns);
        try {
            console.log("Awaiting Groq response...");
            const completion = await this.groq.chat.completions.create({
                messages,
                model: this.model_name || "llama3-70b-8192",
                temperature: 0.2,
                max_tokens: 8192, // maximum token limit
                top_p: 1,
                stream: true,
                stop: stop_seq, // e.g. "***"
            });
            // Accumulate the streamed delta chunks into one response string.
            let res = "";
            for await (const chunk of completion) {
                res += chunk.choices[0]?.delta?.content || '';
            }
            return res;
        }
        catch (err) {
            // Deliberately swallow API errors into a user-facing fallback string.
            console.log(err);
            return "My brain just kinda stopped working. Try again.";
        }
    }

    /** Embeddings are unsupported on Groq; logs the request text and resolves to undefined. */
    async embed(text) {
        console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text);
    }
}
export class Gemma2_9b_Groq {
constructor(model_name, url) {
this.model_name = model_name;
this.url = url;
this.groq = new Groq({ apiKey: getKey('GROQCLOUD_API_KEY')});
}
async sendRequest(turns, systemMessage, stop_seq=null) {
let messages = [{"role": "system", "content": systemMessage}].concat(turns);
let res = null;
try {
console.log("Awaiting Groq response...");
let completion = await this.groq.chat.completions.create({
"messages": messages,
"model": this.model_name || "gemma2-9b-it",
"temperature": 0.2,
"max_tokens": 8192, // maximum token limit
"top_p": 1,
"stream": true,
"stop": stop_seq // "***"
});
let temp_res = "";
for await (const chunk of completion) {
temp_res += chunk.choices[0]?.delta?.content || '';
}
res = temp_res;
}
catch(err) {
console.log(err);
res = "My brain just kinda stopped working. Try again.";
}
return res;
}
async embed(text) {
console.log("There is no support for embeddings in Groq support. However, the following text was provided: " + text);
}