mirror of
https://github.com/kolbytn/mindcraft.git
synced 2025-08-26 09:03:43 +02:00
dynamically load models
This commit is contained in:
parent
2a38d310fc
commit
14eff85120
17 changed files with 85 additions and 70 deletions
|
@ -1,73 +1,74 @@
|
|||
import { Gemini } from './gemini.js';
|
||||
import { GPT } from './gpt.js';
|
||||
import { Claude } from './claude.js';
|
||||
import { Mistral } from './mistral.js';
|
||||
import { ReplicateAPI } from './replicate.js';
|
||||
import { Ollama } from './ollama.js';
|
||||
import { Novita } from './novita.js';
|
||||
import { GroqCloudAPI } from './groq.js';
|
||||
import { HuggingFace } from './huggingface.js';
|
||||
import { Qwen } from "./qwen.js";
|
||||
import { Grok } from "./grok.js";
|
||||
import { DeepSeek } from './deepseek.js';
|
||||
import { Hyperbolic } from './hyperbolic.js';
|
||||
import { GLHF } from './glhf.js';
|
||||
import { OpenRouter } from './openrouter.js';
|
||||
import { VLLM } from './vllm.js';
|
||||
import { promises as fs } from 'fs';
|
||||
import path from 'path';
|
||||
import { fileURLToPath, pathToFileURL } from 'url';
|
||||
|
||||
// Add new models here.
|
||||
// It maps api prefixes to model classes, eg 'openai/gpt-4o' -> GPT
|
||||
const apiMap = {
|
||||
'openai': GPT,
|
||||
'google': Gemini,
|
||||
'anthropic': Claude,
|
||||
'replicate': ReplicateAPI,
|
||||
'ollama': Ollama,
|
||||
'mistral': Mistral,
|
||||
'groq': GroqCloudAPI,
|
||||
'huggingface': HuggingFace,
|
||||
'novita': Novita,
|
||||
'qwen': Qwen,
|
||||
'grok': Grok,
|
||||
'deepseek': DeepSeek,
|
||||
'hyperbolic': Hyperbolic,
|
||||
'glhf': GLHF,
|
||||
'openrouter': OpenRouter,
|
||||
'vllm': VLLM,
|
||||
}
|
||||
const __filename = fileURLToPath(import.meta.url);
|
||||
const __dirname = path.dirname(__filename);
|
||||
|
||||
// Dynamically discover model classes in this directory.
|
||||
// Each model class must export a static `prefix` string.
|
||||
const apiMap = await (async () => {
|
||||
const map = {};
|
||||
const files = (await fs.readdir(__dirname))
|
||||
.filter(f => f.endsWith('.js') && f !== '_model_map.js' && f !== 'prompter.js');
|
||||
for (const file of files) {
|
||||
try {
|
||||
const moduleUrl = pathToFileURL(path.join(__dirname, file)).href;
|
||||
const mod = await import(moduleUrl);
|
||||
for (const exported of Object.values(mod)) {
|
||||
if (typeof exported === 'function' && Object.prototype.hasOwnProperty.call(exported, 'prefix')) {
|
||||
const prefix = exported.prefix;
|
||||
if (typeof prefix === 'string' && prefix.length > 0) {
|
||||
map[prefix] = exported;
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
console.warn('Failed to load model module:', file, e?.message || e);
|
||||
}
|
||||
}
|
||||
return map;
|
||||
})();
|
||||
|
||||
export function selectAPI(profile) {
|
||||
if (typeof profile === 'string' || profile instanceof String) {
|
||||
profile = {model: profile};
|
||||
}
|
||||
const api = Object.keys(apiMap).find(key => profile.model.startsWith(key));
|
||||
if (api) {
|
||||
profile.api = api;
|
||||
}
|
||||
else {
|
||||
// backwards compatibility with local->ollama
|
||||
if (profile.model.includes('local')) {
|
||||
profile.api = 'ollama';
|
||||
profile.model = profile.model.replace('local/', '');
|
||||
// backwards compatibility with local->ollama
|
||||
if (profile.api?.includes('local') || profile.model?.includes('local')) {
|
||||
profile.api = 'ollama';
|
||||
if (profile.model) {
|
||||
profile.model = profile.model.replace('local', 'ollama');
|
||||
}
|
||||
// check for some common models that do not require prefixes
|
||||
else if (profile.model.includes('gpt') || profile.model.includes('o1')|| profile.model.includes('o3'))
|
||||
profile.api = 'openai';
|
||||
else if (profile.model.includes('claude'))
|
||||
profile.api = 'anthropic';
|
||||
else if (profile.model.includes('gemini'))
|
||||
profile.api = "google";
|
||||
else if (profile.model.includes('grok'))
|
||||
profile.api = 'grok';
|
||||
else if (profile.model.includes('mistral'))
|
||||
profile.api = 'mistral';
|
||||
else if (profile.model.includes('deepseek'))
|
||||
profile.api = 'deepseek';
|
||||
else if (profile.model.includes('qwen'))
|
||||
profile.api = 'qwen';
|
||||
}
|
||||
if (!profile.api) {
|
||||
throw new Error('Unknown model:', profile.model);
|
||||
const api = Object.keys(apiMap).find(key => profile.model?.startsWith(key));
|
||||
if (api) {
|
||||
profile.api = api;
|
||||
}
|
||||
else {
|
||||
// check for some common models that do not require prefixes
|
||||
if (profile.model.includes('gpt') || profile.model.includes('o1')|| profile.model.includes('o3'))
|
||||
profile.api = 'openai';
|
||||
else if (profile.model.includes('claude'))
|
||||
profile.api = 'anthropic';
|
||||
else if (profile.model.includes('gemini'))
|
||||
profile.api = "google";
|
||||
else if (profile.model.includes('grok'))
|
||||
profile.api = 'grok';
|
||||
else if (profile.model.includes('mistral'))
|
||||
profile.api = 'mistral';
|
||||
else if (profile.model.includes('deepseek'))
|
||||
profile.api = 'deepseek';
|
||||
else if (profile.model.includes('qwen'))
|
||||
profile.api = 'qwen';
|
||||
}
|
||||
if (!profile.api) {
|
||||
throw new Error('Unknown model:', profile.model);
|
||||
}
|
||||
}
|
||||
if (!apiMap[profile.api]) {
|
||||
throw new Error('Unknown api:', profile.api);
|
||||
}
|
||||
let model_name = profile.model.replace(profile.api + '/', ''); // remove prefix
|
||||
profile.model = model_name === "" ? null : model_name; // if model is empty, set to null
|
||||
|
|
|
@ -3,6 +3,7 @@ import { strictFormat } from '../utils/text.js';
|
|||
import { getKey } from '../utils/keys.js';
|
||||
|
||||
export class Claude {
|
||||
static prefix = 'anthropic';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.params = params || {};
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class DeepSeek {
|
||||
static prefix = 'deepseek';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.params = params;
|
||||
|
|
|
@ -3,6 +3,7 @@ import { toSinglePrompt, strictFormat } from '../utils/text.js';
|
|||
import { getKey } from '../utils/keys.js';
|
||||
|
||||
export class Gemini {
|
||||
static prefix = 'google';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.params = params;
|
||||
|
|
|
@ -2,6 +2,7 @@ import OpenAIApi from 'openai';
|
|||
import { getKey } from '../utils/keys.js';
|
||||
|
||||
export class GLHF {
|
||||
static prefix = 'glhf';
|
||||
constructor(model_name, url) {
|
||||
this.model_name = model_name;
|
||||
const apiKey = getKey('GHLF_API_KEY');
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class GPT {
|
||||
static prefix = 'openai';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.params = params;
|
||||
|
@ -22,20 +23,21 @@ export class GPT {
|
|||
async sendRequest(turns, systemMessage, stop_seq='***') {
|
||||
let messages = [{'role': 'system', 'content': systemMessage}].concat(turns);
|
||||
messages = strictFormat(messages);
|
||||
let model = this.model_name || "gpt-4o-mini";
|
||||
const pack = {
|
||||
model: this.model_name || "gpt-3.5-turbo",
|
||||
model: model,
|
||||
messages,
|
||||
stop: stop_seq,
|
||||
...(this.params || {})
|
||||
};
|
||||
if (this.model_name.includes('o1') || this.model_name.includes('o3') || this.model_name.includes('5')) {
|
||||
if (model.includes('o1') || model.includes('o3') || model.includes('5')) {
|
||||
delete pack.stop;
|
||||
}
|
||||
|
||||
let res = null;
|
||||
|
||||
try {
|
||||
console.log('Awaiting openai api response from model', this.model_name)
|
||||
console.log('Awaiting openai api response from model', model)
|
||||
// console.log('Messages:', messages);
|
||||
let completion = await this.openai.chat.completions.create(pack);
|
||||
if (completion.choices[0].finish_reason == 'length')
|
||||
|
@ -88,6 +90,3 @@ export class GPT {
|
|||
}
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey } from '../utils/keys.js';
|
|||
|
||||
// xAI doesn't supply a SDK for their models, but fully supports OpenAI and Anthropic SDKs
|
||||
export class Grok {
|
||||
static prefix = 'grok';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.url = url;
|
||||
|
|
|
@ -6,6 +6,7 @@ import { getKey } from '../utils/keys.js';
|
|||
|
||||
// Umbrella class for everything under the sun... That GroqCloud provides, that is.
|
||||
export class GroqCloudAPI {
|
||||
static prefix = 'groq';
|
||||
|
||||
constructor(model_name, url, params) {
|
||||
|
||||
|
@ -63,7 +64,6 @@ export class GroqCloudAPI {
|
|||
if (err.message.includes("content must be a string")) {
|
||||
res = "Vision is only supported by certain models.";
|
||||
} else {
|
||||
console.log(this.model_name);
|
||||
res = "My brain disconnected, try again.";
|
||||
}
|
||||
console.log(err);
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey } from '../utils/keys.js';
|
|||
import { HfInference } from "@huggingface/inference";
|
||||
|
||||
export class HuggingFace {
|
||||
static prefix = 'huggingface';
|
||||
constructor(model_name, url, params) {
|
||||
// Remove 'huggingface/' prefix if present
|
||||
this.model_name = model_name.replace('huggingface/', '');
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import { getKey } from '../utils/keys.js';
|
||||
|
||||
export class Hyperbolic {
|
||||
static prefix = 'hyperbolic';
|
||||
constructor(modelName, apiUrl) {
|
||||
this.modelName = modelName || "deepseek-ai/DeepSeek-V3";
|
||||
this.apiUrl = apiUrl || "https://api.hyperbolic.xyz/v1/chat/completions";
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey } from '../utils/keys.js';
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class Mistral {
|
||||
static prefix = 'mistral';
|
||||
#client;
|
||||
|
||||
constructor(model_name, url, params) {
|
||||
|
|
|
@ -4,6 +4,7 @@ import { strictFormat } from '../utils/text.js';
|
|||
|
||||
// llama, mistral
|
||||
export class Novita {
|
||||
static prefix = 'novita';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name.replace('novita/', '');
|
||||
this.url = url || 'https://api.novita.ai/v3/openai';
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class Ollama {
|
||||
static prefix = 'ollama';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.params = params;
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class OpenRouter {
|
||||
static prefix = 'openrouter';
|
||||
constructor(model_name, url) {
|
||||
this.model_name = model_name;
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ import { getKey, hasKey } from '../utils/keys.js';
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class Qwen {
|
||||
static prefix = 'qwen';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.params = params;
|
||||
|
|
|
@ -4,6 +4,7 @@ import { getKey } from '../utils/keys.js';
|
|||
|
||||
// llama, mistral
|
||||
export class ReplicateAPI {
|
||||
static prefix = 'replicate';
|
||||
constructor(model_name, url, params) {
|
||||
this.model_name = model_name;
|
||||
this.url = url;
|
||||
|
|
|
@ -6,6 +6,7 @@ import { getKey, hasKey } from '../utils/keys.js';
|
|||
import { strictFormat } from '../utils/text.js';
|
||||
|
||||
export class VLLM {
|
||||
static prefix = 'vllm';
|
||||
constructor(model_name, url) {
|
||||
this.model_name = model_name;
|
||||
|
||||
|
@ -23,13 +24,14 @@ export class VLLM {
|
|||
|
||||
async sendRequest(turns, systemMessage, stop_seq = '***') {
|
||||
let messages = [{ 'role': 'system', 'content': systemMessage }].concat(turns);
|
||||
let model = this.model_name || "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B";
|
||||
|
||||
if (this.model_name.includes('deepseek') || this.model_name.includes('qwen')) {
|
||||
if (model.includes('deepseek') || model.includes('qwen')) {
|
||||
messages = strictFormat(messages);
|
||||
}
|
||||
|
||||
const pack = {
|
||||
model: this.model_name || "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
|
||||
model: model,
|
||||
messages,
|
||||
stop: stop_seq,
|
||||
};
|
||||
|
|
Loading…
Add table
Reference in a new issue