From 14eff85120063697df43d43b614dc0c52e81a717 Mon Sep 17 00:00:00 2001
From: MaxRobinsonTheGreat
Date: Wed, 20 Aug 2025 18:04:00 -0500
Subject: [PATCH] dynamically load models

Replace the hand-maintained import list and apiMap literal in
_model_map.js with a directory scan performed at module load (top-level
await). Every .js file next to _model_map.js, except _model_map.js
itself and prompter.js, is imported; each exported function/class that
owns a non-empty string static `prefix` is registered in apiMap under
that prefix. A module that fails to import is reported with console.warn
and skipped.

Every model class gains a `static prefix` so that it is discovered.

selectAPI now applies the local->ollama compatibility rewrite first,
only infers the api when profile.api is not already set, and rejects an
api that has no registered class.

GPT.sendRequest and VLLM.sendRequest resolve the model name once into a
local, so the .includes() checks no longer dereference an unset
model_name. The GPT fallback model changes from gpt-3.5-turbo to
gpt-4o-mini. groq.js drops a debug console.log of the model name.
---
Review notes (ignored by git am):
 - The two `throw new Error(...)` lines added to selectAPI now
   interpolate the value into the message. The second argument of
   Error() is an options object, so the original calls silently dropped
   the model/api name.
 - The prefix-less model heuristics use `profile.model?.includes(...)`,
   matching the `profile.model?.startsWith(key)` lookup just above, so a
   profile with neither api nor model reaches the "Unknown model" error
   instead of a TypeError.
 - Only the content of `+` lines in _model_map.js was changed; hunk
   sizes and the diffstat are unaffected, but the post-image blob id on
   that file's `index` line is no longer exact.

 src/models/_model_map.js  | 123 +++++++++++++++++++-------------------
 src/models/claude.js      |   1 +
 src/models/deepseek.js    |   1 +
 src/models/gemini.js      |   1 +
 src/models/glhf.js        |   1 +
 src/models/gpt.js         |  11 ++--
 src/models/grok.js        |   1 +
 src/models/groq.js        |   2 +-
 src/models/huggingface.js |   1 +
 src/models/hyperbolic.js  |   1 +
 src/models/mistral.js     |   1 +
 src/models/novita.js      |   1 +
 src/models/ollama.js      |   1 +
 src/models/openrouter.js  |   1 +
 src/models/qwen.js        |   1 +
 src/models/replicate.js   |   1 +
 src/models/vllm.js        |   6 +-
 17 files changed, 85 insertions(+), 70 deletions(-)

diff --git a/src/models/_model_map.js b/src/models/_model_map.js
index 10fa893..be43893 100644
--- a/src/models/_model_map.js
+++ b/src/models/_model_map.js
@@ -1,73 +1,74 @@
-import { Gemini } from './gemini.js';
-import { GPT } from './gpt.js';
-import { Claude } from './claude.js';
-import { Mistral } from './mistral.js';
-import { ReplicateAPI } from './replicate.js';
-import { Ollama } from './ollama.js';
-import { Novita } from './novita.js';
-import { GroqCloudAPI } from './groq.js';
-import { HuggingFace } from './huggingface.js';
-import { Qwen } from "./qwen.js";
-import { Grok } from "./grok.js";
-import { DeepSeek } from './deepseek.js';
-import { Hyperbolic } from './hyperbolic.js';
-import { GLHF } from './glhf.js';
-import { OpenRouter } from './openrouter.js';
-import { VLLM } from './vllm.js';
+import { promises as fs } from 'fs';
+import path from 'path';
+import { fileURLToPath, pathToFileURL } from 'url';

-// Add new models here.
-// It maps api prefixes to model classes, eg 'openai/gpt-4o' -> GPT
-const apiMap = {
-    'openai': GPT,
-    'google': Gemini,
-    'anthropic': Claude,
-    'replicate': ReplicateAPI,
-    'ollama': Ollama,
-    'mistral': Mistral,
-    'groq': GroqCloudAPI,
-    'huggingface': HuggingFace,
-    'novita': Novita,
-    'qwen': Qwen,
-    'grok': Grok,
-    'deepseek': DeepSeek,
-    'hyperbolic': Hyperbolic,
-    'glhf': GLHF,
-    'openrouter': OpenRouter,
-    'vllm': VLLM,
-}
+const __filename = fileURLToPath(import.meta.url);
+const __dirname = path.dirname(__filename);
+
+// Dynamically discover model classes in this directory.
+// Each model class must export a static `prefix` string.
+const apiMap = await (async () => {
+    const map = {};
+    const files = (await fs.readdir(__dirname))
+        .filter(f => f.endsWith('.js') && f !== '_model_map.js' && f !== 'prompter.js');
+    for (const file of files) {
+        try {
+            const moduleUrl = pathToFileURL(path.join(__dirname, file)).href;
+            const mod = await import(moduleUrl);
+            for (const exported of Object.values(mod)) {
+                if (typeof exported === 'function' && Object.prototype.hasOwnProperty.call(exported, 'prefix')) {
+                    const prefix = exported.prefix;
+                    if (typeof prefix === 'string' && prefix.length > 0) {
+                        map[prefix] = exported;
+                    }
+                }
+            }
+        } catch (e) {
+            console.warn('Failed to load model module:', file, e?.message || e);
+        }
+    }
+    return map;
+})();

 export function selectAPI(profile) {
     if (typeof profile === 'string' || profile instanceof String) {
         profile = {model: profile};
     }
-    const api = Object.keys(apiMap).find(key => profile.model.startsWith(key));
-    if (api) {
-        profile.api = api;
-    }
-    else {
-        // backwards compatibility with local->ollama
-        if (profile.model.includes('local')) {
-            profile.api = 'ollama';
-            profile.model = profile.model.replace('local/', '');
+    // backwards compatibility with local->ollama
+    if (profile.api?.includes('local') || profile.model?.includes('local')) {
+        profile.api = 'ollama';
+        if (profile.model) {
+            profile.model = profile.model.replace('local', 'ollama');
         }
-        // check for some common models that do not require prefixes
-        else if (profile.model.includes('gpt') || profile.model.includes('o1')|| profile.model.includes('o3'))
-            profile.api = 'openai';
-        else if (profile.model.includes('claude'))
-            profile.api = 'anthropic';
-        else if (profile.model.includes('gemini'))
-            profile.api = "google";
-        else if (profile.model.includes('grok'))
-            profile.api = 'grok';
-        else if (profile.model.includes('mistral'))
-            profile.api = 'mistral';
-        else if (profile.model.includes('deepseek'))
-            profile.api = 'deepseek';
-        else if (profile.model.includes('qwen'))
-            profile.api = 'qwen';
     }
     if (!profile.api) {
-        throw new Error('Unknown model:', profile.model);
+        const api = Object.keys(apiMap).find(key => profile.model?.startsWith(key));
+        if (api) {
+            profile.api = api;
+        }
+        else {
+            // check for some common models that do not require prefixes
+            if (profile.model?.includes('gpt') || profile.model?.includes('o1') || profile.model?.includes('o3'))
+                profile.api = 'openai';
+            else if (profile.model?.includes('claude'))
+                profile.api = 'anthropic';
+            else if (profile.model?.includes('gemini'))
+                profile.api = "google";
+            else if (profile.model?.includes('grok'))
+                profile.api = 'grok';
+            else if (profile.model?.includes('mistral'))
+                profile.api = 'mistral';
+            else if (profile.model?.includes('deepseek'))
+                profile.api = 'deepseek';
+            else if (profile.model?.includes('qwen'))
+                profile.api = 'qwen';
+        }
+        if (!profile.api) {
+            throw new Error(`Unknown model: ${profile.model}`);
+        }
+    }
+    if (!apiMap[profile.api]) {
+        throw new Error(`Unknown api: ${profile.api}`);
     }
     let model_name = profile.model.replace(profile.api + '/', ''); // remove prefix
     profile.model = model_name === "" ? null : model_name; // if model is empty, set to null
diff --git a/src/models/claude.js b/src/models/claude.js
index d6e48bc..c42d2e6 100644
--- a/src/models/claude.js
+++ b/src/models/claude.js
@@ -3,6 +3,7 @@ import { strictFormat } from '../utils/text.js';
 import { getKey } from '../utils/keys.js';

 export class Claude {
+    static prefix = 'anthropic';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.params = params || {};
diff --git a/src/models/deepseek.js b/src/models/deepseek.js
index da98ba2..5596fa8 100644
--- a/src/models/deepseek.js
+++ b/src/models/deepseek.js
@@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
 import { strictFormat } from '../utils/text.js';

 export class DeepSeek {
+    static prefix = 'deepseek';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.params = params;
diff --git a/src/models/gemini.js b/src/models/gemini.js
index 4e3af14..75a20e0 100644
--- a/src/models/gemini.js
+++ b/src/models/gemini.js
@@ -3,6 +3,7 @@ import { toSinglePrompt, strictFormat } from '../utils/text.js';
 import { getKey } from '../utils/keys.js';

 export class Gemini {
+    static prefix = 'google';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.params = params;
diff --git a/src/models/glhf.js b/src/models/glhf.js
index d41b843..b237c8d 100644
--- a/src/models/glhf.js
+++ b/src/models/glhf.js
@@ -2,6 +2,7 @@ import OpenAIApi from 'openai';
 import { getKey } from '../utils/keys.js';

 export class GLHF {
+    static prefix = 'glhf';
     constructor(model_name, url) {
         this.model_name = model_name;
         const apiKey = getKey('GHLF_API_KEY');
diff --git a/src/models/gpt.js b/src/models/gpt.js
index e8e5c5c..ea7d600 100644
--- a/src/models/gpt.js
+++ b/src/models/gpt.js
@@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
 import { strictFormat } from '../utils/text.js';

 export class GPT {
+    static prefix = 'openai';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.params = params;
@@ -22,20 +23,21 @@ export class GPT {
     async sendRequest(turns, systemMessage, stop_seq='***') {
         let messages = [{'role': 'system', 'content': systemMessage}].concat(turns);
         messages = strictFormat(messages);
+        let model = this.model_name || "gpt-4o-mini";
         const pack = {
-            model: this.model_name || "gpt-3.5-turbo",
+            model: model,
             messages,
             stop: stop_seq,
             ...(this.params || {})
         };
-        if (this.model_name.includes('o1') || this.model_name.includes('o3') || this.model_name.includes('5')) {
+        if (model.includes('o1') || model.includes('o3') || model.includes('5')) {
             delete pack.stop;
         }

         let res = null;

         try {
-            console.log('Awaiting openai api response from model', this.model_name)
+            console.log('Awaiting openai api response from model', model)
             // console.log('Messages:', messages);
             let completion = await this.openai.chat.completions.create(pack);
             if (completion.choices[0].finish_reason == 'length')
@@ -88,6 +90,3 @@ export class GPT {
     }

 }
-
-
-
diff --git a/src/models/grok.js b/src/models/grok.js
index 2878a10..0753f10 100644
--- a/src/models/grok.js
+++ b/src/models/grok.js
@@ -3,6 +3,7 @@ import { getKey } from '../utils/keys.js';

 // xAI doesn't supply a SDK for their models, but fully supports OpenAI and Anthropic SDKs
 export class Grok {
+    static prefix = 'grok';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.url = url;
diff --git a/src/models/groq.js b/src/models/groq.js
index e4e8f3b..9da88c7 100644
--- a/src/models/groq.js
+++ b/src/models/groq.js
@@ -6,6 +6,7 @@ import { getKey } from '../utils/keys.js';

 // Umbrella class for everything under the sun... That GroqCloud provides, that is.
 export class GroqCloudAPI {
+    static prefix = 'groq';

     constructor(model_name, url, params) {

@@ -63,7 +64,6 @@ export class GroqCloudAPI {
             if (err.message.includes("content must be a string")) {
                 res = "Vision is only supported by certain models.";
             } else {
-                console.log(this.model_name);
                 res = "My brain disconnected, try again.";
             }
             console.log(err);
diff --git a/src/models/huggingface.js b/src/models/huggingface.js
index 80c36e8..91fbdfd 100644
--- a/src/models/huggingface.js
+++ b/src/models/huggingface.js
@@ -3,6 +3,7 @@ import { getKey } from '../utils/keys.js';
 import { HfInference } from "@huggingface/inference";

 export class HuggingFace {
+    static prefix = 'huggingface';
     constructor(model_name, url, params) {
         // Remove 'huggingface/' prefix if present
         this.model_name = model_name.replace('huggingface/', '');
diff --git a/src/models/hyperbolic.js b/src/models/hyperbolic.js
index a2ccc48..f483b69 100644
--- a/src/models/hyperbolic.js
+++ b/src/models/hyperbolic.js
@@ -1,6 +1,7 @@
 import { getKey } from '../utils/keys.js';

 export class Hyperbolic {
+    static prefix = 'hyperbolic';
     constructor(modelName, apiUrl) {
         this.modelName = modelName || "deepseek-ai/DeepSeek-V3";
         this.apiUrl = apiUrl || "https://api.hyperbolic.xyz/v1/chat/completions";
diff --git a/src/models/mistral.js b/src/models/mistral.js
index 72448f1..536b386 100644
--- a/src/models/mistral.js
+++ b/src/models/mistral.js
@@ -3,6 +3,7 @@ import { getKey } from '../utils/keys.js';
 import { strictFormat } from '../utils/text.js';

 export class Mistral {
+    static prefix = 'mistral';
     #client;

     constructor(model_name, url, params) {
diff --git a/src/models/novita.js b/src/models/novita.js
index 8f2dd08..380fa4c 100644
--- a/src/models/novita.js
+++ b/src/models/novita.js
@@ -4,6 +4,7 @@ import { strictFormat } from '../utils/text.js';

 // llama, mistral
 export class Novita {
+    static prefix = 'novita';
     constructor(model_name, url, params) {
         this.model_name = model_name.replace('novita/', '');
         this.url = url || 'https://api.novita.ai/v3/openai';
diff --git a/src/models/ollama.js b/src/models/ollama.js
index 064d2ad..37d8557 100644
--- a/src/models/ollama.js
+++ b/src/models/ollama.js
@@ -1,6 +1,7 @@
 import { strictFormat } from '../utils/text.js';

 export class Ollama {
+    static prefix = 'ollama';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.params = params;
diff --git a/src/models/openrouter.js b/src/models/openrouter.js
index 5cbc090..ca0782b 100644
--- a/src/models/openrouter.js
+++ b/src/models/openrouter.js
@@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
 import { strictFormat } from '../utils/text.js';

 export class OpenRouter {
+    static prefix = 'openrouter';
     constructor(model_name, url) {
         this.model_name = model_name;

diff --git a/src/models/qwen.js b/src/models/qwen.js
index 4dfacfe..a768b5b 100644
--- a/src/models/qwen.js
+++ b/src/models/qwen.js
@@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
 import { strictFormat } from '../utils/text.js';

 export class Qwen {
+    static prefix = 'qwen';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.params = params;
diff --git a/src/models/replicate.js b/src/models/replicate.js
index c8c3ba3..aa296c5 100644
--- a/src/models/replicate.js
+++ b/src/models/replicate.js
@@ -4,6 +4,7 @@ import { getKey } from '../utils/keys.js';

 // llama, mistral
 export class ReplicateAPI {
+    static prefix = 'replicate';
     constructor(model_name, url, params) {
         this.model_name = model_name;
         this.url = url;
diff --git a/src/models/vllm.js b/src/models/vllm.js
index e9116ef..d821983 100644
--- a/src/models/vllm.js
+++ b/src/models/vllm.js
@@ -6,6 +6,7 @@ import { getKey, hasKey } from '../utils/keys.js';
 import { strictFormat } from '../utils/text.js';

 export class VLLM {
+    static prefix = 'vllm';
     constructor(model_name, url) {
         this.model_name = model_name;

@@ -23,13 +24,14 @@ export class VLLM {

     async sendRequest(turns, systemMessage, stop_seq = '***') {
         let messages = [{ 'role': 'system', 'content': systemMessage }].concat(turns);
+        let model = this.model_name || "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B";

-        if (this.model_name.includes('deepseek') || this.model_name.includes('qwen')) {
+        if (model.includes('deepseek') || model.includes('qwen')) {
             messages = strictFormat(messages);
         }

         const pack = {
-            model: this.model_name || "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
+            model: model,
             messages,
             stop: stop_seq,
         };