From 2a38d310fcd26a49976da6c0d2bfddaa97ff2a3b Mon Sep 17 00:00:00 2001 From: MaxRobinsonTheGreat Date: Wed, 20 Aug 2025 13:08:59 -0500 Subject: [PATCH] refactor models for better modularity, use sweaterdog ollama/local --- src/models/_model_map.js | 88 +++++++++++++++ src/models/gemini.js | 6 +- src/models/{local.js => ollama.js} | 69 ++++++------ src/models/prompter.js | 166 ++++------------------------- 4 files changed, 150 insertions(+), 179 deletions(-) create mode 100644 src/models/_model_map.js rename src/models/{local.js => ollama.js} (60%) diff --git a/src/models/_model_map.js b/src/models/_model_map.js new file mode 100644 index 0000000..10fa893 --- /dev/null +++ b/src/models/_model_map.js @@ -0,0 +1,88 @@ +import { Gemini } from './gemini.js'; +import { GPT } from './gpt.js'; +import { Claude } from './claude.js'; +import { Mistral } from './mistral.js'; +import { ReplicateAPI } from './replicate.js'; +import { Ollama } from './ollama.js'; +import { Novita } from './novita.js'; +import { GroqCloudAPI } from './groq.js'; +import { HuggingFace } from './huggingface.js'; +import { Qwen } from "./qwen.js"; +import { Grok } from "./grok.js"; +import { DeepSeek } from './deepseek.js'; +import { Hyperbolic } from './hyperbolic.js'; +import { GLHF } from './glhf.js'; +import { OpenRouter } from './openrouter.js'; +import { VLLM } from './vllm.js'; + +// Add new models here. 
+// It maps api prefixes to model classes, eg 'openai/gpt-4o' -> GPT +const apiMap = { + 'openai': GPT, + 'google': Gemini, + 'anthropic': Claude, + 'replicate': ReplicateAPI, + 'ollama': Ollama, + 'mistral': Mistral, + 'groq': GroqCloudAPI, + 'huggingface': HuggingFace, + 'novita': Novita, + 'qwen': Qwen, + 'grok': Grok, + 'deepseek': DeepSeek, + 'hyperbolic': Hyperbolic, + 'glhf': GLHF, + 'openrouter': OpenRouter, + 'vllm': VLLM, +} + +export function selectAPI(profile) { + if (typeof profile === 'string' || profile instanceof String) { + profile = {model: profile}; + } + const api = Object.keys(apiMap).find(key => profile.model.startsWith(key)); + if (api) { + profile.api = api; + } + else { + // backwards compatibility with local->ollama + if (profile.model.includes('local')) { + profile.api = 'ollama'; + profile.model = profile.model.replace('local/', ''); + } + // check for some common models that do not require prefixes + else if (profile.model.includes('gpt') || profile.model.includes('o1')|| profile.model.includes('o3')) + profile.api = 'openai'; + else if (profile.model.includes('claude')) + profile.api = 'anthropic'; + else if (profile.model.includes('gemini')) + profile.api = "google"; + else if (profile.model.includes('grok')) + profile.api = 'grok'; + else if (profile.model.includes('mistral')) + profile.api = 'mistral'; + else if (profile.model.includes('deepseek')) + profile.api = 'deepseek'; + else if (profile.model.includes('qwen')) + profile.api = 'qwen'; + } + if (!profile.api) { + throw new Error('Unknown model:', profile.model); + } + let model_name = profile.model.replace(profile.api + '/', ''); // remove prefix + profile.model = model_name === "" ? 
null : model_name; // if model is empty, set to null + return profile; +} + +export function createModel(profile) { + if (!!apiMap[profile.model]) { + // if the model value is an api (instead of a specific model name) + // then set model to null so it uses the default model for that api + profile.model = null; + } + if (!apiMap[profile.api]) { + throw new Error('Unknown api:', profile.api); + } + const model = new apiMap[profile.api](profile.model, profile.url, profile.params); + return model; +} \ No newline at end of file diff --git a/src/models/gemini.js b/src/models/gemini.js index 4d24c93..4e3af14 100644 --- a/src/models/gemini.js +++ b/src/models/gemini.js @@ -142,15 +142,15 @@ export class Gemini { } async embed(text) { - let model; + let model = this.model_name || "text-embedding-004"; if (this.url) { model = this.genAI.getGenerativeModel( - { model: "text-embedding-004" }, + { model }, { baseUrl: this.url } ); } else { model = this.genAI.getGenerativeModel( - { model: "text-embedding-004" } + { model } ); } diff --git a/src/models/local.js b/src/models/ollama.js similarity index 60% rename from src/models/local.js rename to src/models/ollama.js index e51bcf8..064d2ad 100644 --- a/src/models/local.js +++ b/src/models/ollama.js @@ -1,6 +1,6 @@ import { strictFormat } from '../utils/text.js'; -export class Local { +export class Ollama { constructor(model_name, url, params) { this.model_name = model_name; this.params = params; @@ -10,11 +10,9 @@ export class Local { } async sendRequest(turns, systemMessage) { - let model = this.model_name || 'llama3.1'; // Updated to llama3.1, as it is more performant than llama3 + let model = this.model_name || 'sweaterdog/andy-4:micro-q5_k_m'; let messages = strictFormat(turns); messages.unshift({ role: 'system', content: systemMessage }); - - // We'll attempt up to 5 times for models with deepseek-r1-esk reasoning if the tags are mismatched. 
const maxAttempts = 5; let attempt = 0; let finalRes = null; @@ -24,14 +22,14 @@ console.log(`Awaiting local response... (model: ${model}, attempt: ${attempt})`); let res = null; try { - res = await this.send(this.chat_endpoint, { + let apiResponse = await this.send(this.chat_endpoint, { model: model, messages: messages, stream: false, ...(this.params || {}) }); - if (res) { - res = res['message']['content']; + if (apiResponse) { + res = apiResponse['message']['content']; } else { res = 'No response data.'; } @@ -43,36 +41,27 @@ console.log(err); res = 'My brain disconnected, try again.'; } - } - // If the model name includes "deepseek-r1" or "Andy-3.5-reasoning", then handle the <think> block. - const hasOpenTag = res.includes("<think>"); - const hasCloseTag = res.includes("</think>"); - - // If there's a partial mismatch, retry to get a complete response. - if ((hasOpenTag && !hasCloseTag)) { - console.warn("Partial <think> block detected. Re-generating..."); - continue; - } - - // If </think> is present but <think> is not, prepend <think> - if (hasCloseTag && !hasOpenTag) { - res = '<think>' + res; - } - // Changed this so if the model reasons, using <think> and </think> but doesn't start the message with <think>, <think> ges prepended to the message so no error occur. - - // If both tags appear, remove them (and everything inside). - if (hasOpenTag && hasCloseTag) { - res = res.replace(/<think>[\s\S]*?<\/think>/g, ''); - } + const hasOpenTag = res.includes("<think>"); + const hasCloseTag = res.includes("</think>"); + if ((hasOpenTag && !hasCloseTag)) { + console.warn("Partial <think> block detected. Re-generating..."); + if (attempt < maxAttempts) continue; + } + if (hasCloseTag && !hasOpenTag) { + res = '<think>' + res; + } + if (hasOpenTag && hasCloseTag) { + res = res.replace(/<think>[\s\S]*?<\/think>/g, '').trim(); + } finalRes = res; - break; // Exit the loop if we got a valid response. 
+ break; } if (finalRes == null) { - console.warn("Could not get a valid <think> block or normal response after max attempts."); + console.warn("Could not get a valid response after max attempts."); finalRes = 'I thought too hard, sorry, try again.'; } return finalRes; @@ -104,4 +93,22 @@ } return data; } -} + + async sendVisionRequest(messages, systemMessage, imageBuffer) { + const imageMessages = [...messages]; + imageMessages.push({ + role: "user", + content: [ + { type: "text", text: systemMessage }, + { + type: "image_url", + image_url: { + url: `data:image/jpeg;base64,${imageBuffer.toString('base64')}` + } + } + ] + }); + + return this.sendRequest(imageMessages, systemMessage); + } +} \ No newline at end of file diff --git a/src/models/prompter.js b/src/models/prompter.js index 89d5fe9..a8c4db7 100644 --- a/src/models/prompter.js +++ b/src/models/prompter.js @@ -5,26 +5,10 @@ import { SkillLibrary } from "../agent/library/skill_library.js"; import { stringifyTurns } from '../utils/text.js'; import { getCommand } from '../agent/commands/index.js'; import settings from '../agent/settings.js'; - -import { Gemini } from './gemini.js'; -import { GPT } from './gpt.js'; -import { Claude } from './claude.js'; -import { Mistral } from './mistral.js'; -import { ReplicateAPI } from './replicate.js'; -import { Local } from './local.js'; -import { Novita } from './novita.js'; -import { GroqCloudAPI } from './groq.js'; -import { HuggingFace } from './huggingface.js'; -import { Qwen } from "./qwen.js"; -import { Grok } from "./grok.js"; -import { DeepSeek } from './deepseek.js'; -import { Hyperbolic } from './hyperbolic.js'; -import { GLHF } from './glhf.js'; -import { OpenRouter } from './openrouter.js'; -import { VLLM } from './vllm.js'; import { promises as fs } from 'fs'; import path from 'path'; import { fileURLToPath } from 'url'; +import { selectAPI, createModel } from './_model_map.js'; const __filename = fileURLToPath(import.meta.url); const __dirname 
= path.dirname(__filename); @@ -66,70 +50,46 @@ export class Prompter { this.last_prompt_time = 0; this.awaiting_coding = false; - // try to get "max_tokens" parameter, else null + // for backwards compatibility, move max_tokens to params let max_tokens = null; if (this.profile.max_tokens) max_tokens = this.profile.max_tokens; - let chat_model_profile = this._selectAPI(this.profile.model); - this.chat_model = this._createModel(chat_model_profile); + let chat_model_profile = selectAPI(this.profile.model); + this.chat_model = createModel(chat_model_profile); if (this.profile.code_model) { - let code_model_profile = this._selectAPI(this.profile.code_model); - this.code_model = this._createModel(code_model_profile); + let code_model_profile = selectAPI(this.profile.code_model); + this.code_model = createModel(code_model_profile); } else { this.code_model = this.chat_model; } if (this.profile.vision_model) { - let vision_model_profile = this._selectAPI(this.profile.vision_model); - this.vision_model = this._createModel(vision_model_profile); + let vision_model_profile = selectAPI(this.profile.vision_model); + this.vision_model = createModel(vision_model_profile); } else { this.vision_model = this.chat_model; } - let embedding = this.profile.embedding; - if (embedding === undefined) { - if (chat_model_profile.api !== 'ollama') - embedding = {api: chat_model_profile.api}; - else - embedding = {api: 'none'}; - } - else if (typeof embedding === 'string' || embedding instanceof String) - embedding = {api: embedding}; - - console.log('Using embedding settings:', embedding); - - try { - if (embedding.api === 'google') - this.embedding_model = new Gemini(embedding.model, embedding.url); - else if (embedding.api === 'openai') - this.embedding_model = new GPT(embedding.model, embedding.url); - else if (embedding.api === 'replicate') - this.embedding_model = new ReplicateAPI(embedding.model, embedding.url); - else if (embedding.api === 'ollama') - this.embedding_model = new 
Local(embedding.model, embedding.url); - else if (embedding.api === 'qwen') - this.embedding_model = new Qwen(embedding.model, embedding.url); - else if (embedding.api === 'mistral') - this.embedding_model = new Mistral(embedding.model, embedding.url); - else if (embedding.api === 'huggingface') - this.embedding_model = new HuggingFace(embedding.model, embedding.url); - else if (embedding.api === 'novita') - this.embedding_model = new Novita(embedding.model, embedding.url); - else { - this.embedding_model = null; - let embedding_name = embedding ? embedding.api : '[NOT SPECIFIED]' - console.warn('Unsupported embedding: ' + embedding_name + '. Using word-overlap instead, expect reduced performance. Recommend using a supported embedding model. See Readme.'); + + let embedding_model_profile = null; + if (this.profile.embedding) { + try { + embedding_model_profile = selectAPI(this.profile.embedding); + } catch (e) { + embedding_model_profile = null; } } - catch (err) { - console.warn('Warning: Failed to initialize embedding model:', err.message); - console.log('Continuing anyway, using word-overlap instead.'); - this.embedding_model = null; + if (embedding_model_profile) { + this.embedding_model = createModel(embedding_model_profile); } + else { + this.embedding_model = createModel({api: chat_model_profile.api}); + } + this.skill_libary = new SkillLibrary(agent, this.embedding_model); mkdirSync(`./bots/${name}`, { recursive: true }); writeFileSync(`./bots/${name}/last_profile.json`, JSON.stringify(this.profile, null, 4), (err) => { @@ -140,88 +100,6 @@ export class Prompter { }); } - _selectAPI(profile) { - if (typeof profile === 'string' || profile instanceof String) { - profile = {model: profile}; - } - if (!profile.api) { - if (profile.model.includes('openrouter/')) - profile.api = 'openrouter'; // must do first because shares names with other models - else if (profile.model.includes('ollama/')) - profile.api = 'ollama'; // also must do early because shares names 
with other models - else if (profile.model.includes('gemini')) - profile.api = 'google'; - else if (profile.model.includes('vllm/')) - profile.api = 'vllm'; - else if (profile.model.includes('gpt') || profile.model.includes('o1')|| profile.model.includes('o3')) - profile.api = 'openai'; - else if (profile.model.includes('claude')) - profile.api = 'anthropic'; - else if (profile.model.includes('huggingface/')) - profile.api = "huggingface"; - else if (profile.model.includes('replicate/')) - profile.api = 'replicate'; - else if (profile.model.includes('mistralai/') || profile.model.includes("mistral/")) - model_profile.api = 'mistral'; - else if (profile.model.includes("groq/") || profile.model.includes("groqcloud/")) - profile.api = 'groq'; - else if (profile.model.includes("glhf/")) - profile.api = 'glhf'; - else if (profile.model.includes("hyperbolic/")) - profile.api = 'hyperbolic'; - else if (profile.model.includes('novita/')) - profile.api = 'novita'; - else if (profile.model.includes('qwen')) - profile.api = 'qwen'; - else if (profile.model.includes('grok')) - profile.api = 'xai'; - else if (profile.model.includes('deepseek')) - profile.api = 'deepseek'; - else if (profile.model.includes('mistral')) - profile.api = 'mistral'; - else - throw new Error('Unknown model:', profile.model); - } - return profile; - } - _createModel(profile) { - let model = null; - if (profile.api === 'google') - model = new Gemini(profile.model, profile.url, profile.params); - else if (profile.api === 'openai') - model = new GPT(profile.model, profile.url, profile.params); - else if (profile.api === 'anthropic') - model = new Claude(profile.model, profile.url, profile.params); - else if (profile.api === 'replicate') - model = new ReplicateAPI(profile.model.replace('replicate/', ''), profile.url, profile.params); - else if (profile.api === 'ollama') - model = new Local(profile.model.replace('ollama/', ''), profile.url, profile.params); - else if (profile.api === 'mistral') - model = 
new Mistral(profile.model, profile.url, profile.params); - else if (profile.api === 'groq') - model = new GroqCloudAPI(profile.model.replace('groq/', '').replace('groqcloud/', ''), profile.url, profile.params); - else if (profile.api === 'huggingface') - model = new HuggingFace(profile.model, profile.url, profile.params); - else if (profile.api === 'glhf') - model = new GLHF(profile.model.replace('glhf/', ''), profile.url, profile.params); - else if (profile.api === 'hyperbolic') - model = new Hyperbolic(profile.model.replace('hyperbolic/', ''), profile.url, profile.params); - else if (profile.api === 'novita') - model = new Novita(profile.model.replace('novita/', ''), profile.url, profile.params); - else if (profile.api === 'qwen') - model = new Qwen(profile.model, profile.url, profile.params); - else if (profile.api === 'xai') - model = new Grok(profile.model, profile.url, profile.params); - else if (profile.api === 'deepseek') - model = new DeepSeek(profile.model, profile.url, profile.params); - else if (profile.api === 'openrouter') - model = new OpenRouter(profile.model.replace('openrouter/', ''), profile.url, profile.params); - else if (profile.api === 'vllm') - model = new VLLM(profile.model.replace('vllm/', ''), profile.url, profile.params); - else - throw new Error('Unknown API:', profile.api); - return model; - } getName() { return this.profile.name; } @@ -482,6 +360,4 @@ export class Prompter { logFile = path.join(logDir, logFile); await fs.appendFile(logFile, String(logEntry), 'utf-8'); } - - }