Merge pull request #58 from kolbytn/model-refactor

model refactor
This commit is contained in:
Kolby Nottingham 2024-04-28 14:15:08 -07:00 committed by GitHub
commit 037f58234c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 206 additions and 95 deletions

View file

@ -38,8 +38,51 @@ Run `node main.js`
You can configure the agent's name, model, and prompts in their profile like `andy.json`.
You can configure project details in `settings.json`.
You can configure project details in `settings.json`.
## Bot Profiles
Bot profiles are json files (such as `andy.json`) that define a bot's behavior in three ways:
1. Bot backend LLMs to use for chat and embeddings.
2. Prompts used to influence the bot's behavior.
3. Examples retrieved and provided to the bot to help it better perform tasks.
### Model Specifications
LLM backends can be specified as simply as `"model": "gpt-3.5-turbo"`. However, for both the chat model and the embedding model, the bot profile can specify the below attributes:
```
"model": {
"api": "openai",
"url": "https://api.openai.com/v1/",
"model": "gpt-3.5-turbo"
},
"embedding": {
"api": "openai",
"url": "https://api.openai.com/v1/",
"model": "text-embedding-ada-002"
}
```
The model parameter accepts either a string or object. If a string, it should specify the model to be used. The api and url will be assumed. If an object, the api field must be specified. Each api has a default model and url, so those fields are optional.
If the embedding field is not specified, then it will use the default embedding method for the chat model's api (Note that anthropic has no embedding model). The embedding parameter can also be a string or object. If a string, it should specify the embedding api and the default model and url will be used. If a valid embedding is not specified and cannot be assumed, then word overlap will be used to retrieve examples instead.
Thus, all the below specifications are equivalent to the above example:
```
"model": "gpt-3.5-turbo"
```
```
"model": {
"api": "openai"
}
```
```
"model": "gpt-3.5-turbo",
"embedding": "openai"
```
## Online Servers
To connect to online servers your bot will need an official Microsoft/Minecraft account. You can use your own personal one, but will need another account if you want to connect with it. Here is an example settings for this:

View file

@ -13,10 +13,57 @@ import { Local } from '../models/local.js';
export class Prompter {
constructor(agent, fp) {
this.prompts = JSON.parse(readFileSync(fp, 'utf8'));
let name = this.prompts.name;
this.agent = agent;
let model_name = this.prompts.model;
this.prompts = JSON.parse(readFileSync(fp, 'utf8'));
this.convo_examples = null;
this.coding_examples = null;
let name = this.prompts.name;
let chat = this.prompts.model;
if (typeof chat === 'string' || chat instanceof String) {
chat = {model: chat};
if (chat.model.includes('gemini'))
chat.api = 'google';
else if (chat.model.includes('gpt'))
chat.api = 'openai';
else if (chat.model.includes('claude'))
chat.api = 'anthropic';
else
chat.api = 'ollama';
}
console.log('Using chat settings:', chat);
if (chat.api == 'google')
this.chat_model = new Gemini(chat.model, chat.url);
else if (chat.api == 'openai')
this.chat_model = new GPT(chat.model, chat.url);
else if (chat.api == 'anthropic')
this.chat_model = new Claude(chat.model, chat.url);
else if (chat.api == 'ollama')
this.chat_model = new Local(chat.model, chat.url);
else
throw new Error('Unknown API:', api);
let embedding = this.prompts.embedding;
if (embedding === undefined)
embedding = {api: chat.api};
else if (typeof embedding === 'string' || embedding instanceof String)
embedding = {api: embedding};
console.log('Using embedding settings:', embedding);
if (embedding.api == 'google')
this.embedding_model = new Gemini(embedding.model, embedding.url);
else if (embedding.api == 'openai')
this.embedding_model = new GPT(embedding.model, embedding.url);
else if (embedding.api == 'ollama')
this.embedding_model = new Local(embedding.model, embedding.url);
else {
this.embedding_model = null;
console.log('Unknown embedding: ', embedding ? embedding.api : '[NOT SPECIFIED]', '. Using word overlap.');
}
mkdirSync(`./bots/${name}`, { recursive: true });
writeFileSync(`./bots/${name}/last_profile.json`, JSON.stringify(this.prompts, null, 4), (err) => {
if (err) {
@ -24,15 +71,6 @@ export class Prompter {
}
console.log("Copy profile saved.");
});
if (model_name.includes('gemini'))
this.model = new Gemini(model_name);
else if (model_name.includes('gpt'))
this.model = new GPT(model_name);
else if (model_name.includes('claude'))
this.model = new Claude(model_name);
else
this.model = new Local(model_name);
}
getName() {
@ -41,9 +79,9 @@ export class Prompter {
async initExamples() {
console.log('Loading examples...')
this.convo_examples = new Examples(this.model);
this.convo_examples = new Examples(this.embedding_model);
await this.convo_examples.load(this.prompts.conversation_examples);
this.coding_examples = new Examples(this.model);
this.coding_examples = new Examples(this.embedding_model);
await this.coding_examples.load(this.prompts.coding_examples);
console.log('Examples loaded.');
}
@ -102,19 +140,19 @@ export class Prompter {
async promptConvo(messages) {
let prompt = this.prompts.conversing;
prompt = await this.replaceStrings(prompt, messages, this.convo_examples);
return await this.model.sendRequest(messages, prompt);
return await this.chat_model.sendRequest(messages, prompt);
}
async promptCoding(messages) {
let prompt = this.prompts.coding;
prompt = await this.replaceStrings(prompt, messages, this.coding_examples);
return await this.model.sendRequest(messages, prompt);
return await this.chat_model.sendRequest(messages, prompt);
}
async promptMemSaving(prev_mem, to_summarize) {
let prompt = this.prompts.saving_memory;
prompt = await this.replaceStrings(prompt, null, null, prev_mem, to_summarize);
return await this.model.sendRequest([], prompt);
return await this.chat_model.sendRequest([], prompt);
}
async promptGoalSetting(messages, last_goals) {

View file

@ -1,23 +1,19 @@
import Anthropic from '@anthropic-ai/sdk';
import { GPT } from './gpt.js';
export class Claude {
constructor(model_name) {
constructor(model_name, url) {
this.model_name = model_name;
if (!process.env.ANTHROPIC_API_KEY) {
let config = {};
if (url)
config.baseURL = url;
if (process.env.ANTHROPIC_API_KEY)
config.apiKey = process.env["ANTHROPIC_API_KEY"];
else
throw new Error('Anthropic API key missing! Make sure you set your ANTHROPIC_API_KEY environment variable.');
}
this.anthropic = new Anthropic({
apiKey: process.env["ANTHROPIC_API_KEY"]
});
this.gpt = undefined;
try {
this.gpt = new GPT(); // use for embeddings, ignore model
} catch (err) {
console.warn('Claude uses the OpenAI API for embeddings, but no OPENAI_API_KEY env variable was found. Claude will still work, but performance will suffer.');
}
this.anthropic = new Anthropic(config);
}
async sendRequest(turns, systemMessage) {
@ -56,7 +52,7 @@ export class Claude {
console.log('Awaiting anthropic api response...')
console.log('Messages:', messages);
const resp = await this.anthropic.messages.create({
model: this.model_name,
model: this.model_name || "claude-3-sonnet-20240229",
system: systemMessage,
max_tokens: 2048,
messages: messages,
@ -72,11 +68,7 @@ export class Claude {
}
async embed(text) {
if (this.gpt) {
return await this.gpt.embed(text);
}
// if no gpt, just return random embedding
return Array(1).fill().map(() => Math.random());
throw new Error('Embeddings are not supported by Claude.');
}
}

View file

@ -1,18 +1,29 @@
import { GoogleGenerativeAI } from '@google/generative-ai';
export class Gemini {
constructor(model_name) {
constructor(model_name, url) {
this.model_name = model_name;
this.url = url;
if (!process.env.GEMINI_API_KEY) {
throw new Error('Gemini API key missing! Make sure you set your GEMINI_API_KEY environment variable.');
}
this.genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY);
this.llmModel = this.genAI.getGenerativeModel({ model: model_name });
this.embedModel = this.genAI.getGenerativeModel({ model: "embedding-001"});
}
async sendRequest(turns, systemMessage) {
if (this.url) {
model = this.genAI.getGenerativeModel(
{model: this.model_name || "gemini-pro"},
{baseUrl: this.url}
);
} else {
model = this.genAI.getGenerativeModel(
{model: this.model_name || "gemini-pro"}
);
}
const messages = [{'role': 'system', 'content': systemMessage}].concat(turns);
let prompt = "";
let role = "";
@ -24,13 +35,24 @@ export class Gemini {
if (role !== "model") // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message
prompt += "model: ";
console.log(prompt)
const result = await this.llmModel.generateContent(prompt);
const result = await model.generateContent(prompt);
const response = await result.response;
return response.text();
}
async embed(text) {
const result = await this.embedModel.embedContent(text);
if (this.url) {
model = this.genAI.getGenerativeModel(
{model: this.model_name || "embedding-001"},
{baseUrl: this.url}
);
} else {
model = this.genAI.getGenerativeModel(
{model: this.model_name || "embedding-001"}
);
}
const result = await model.embedContent(text);
return result.embedding;
}
}

View file

@ -1,25 +1,21 @@
import OpenAIApi from 'openai';
export class GPT {
constructor(model_name) {
this.model_name = model_name;
let openAiConfig = null;
if (process.env.OPENAI_ORG_ID) {
openAiConfig = {
organization: process.env.OPENAI_ORG_ID,
apiKey: process.env.OPENAI_API_KEY,
};
}
else if (process.env.OPENAI_API_KEY) {
openAiConfig = {
apiKey: process.env.OPENAI_API_KEY,
};
}
else {
throw new Error('OpenAI API key missing! Make sure you set your OPENAI_API_KEY environment variable.');
}
this.openai = new OpenAIApi(openAiConfig);
export class GPT {
constructor(model_name, url) {
this.model_name = model_name;
let config = {};
if (url)
config.baseURL = url;
if (process.env.OPENAI_ORG_ID)
config.organization = process.env.OPENAI_ORG_ID
if (process.env.OPENAI_API_KEY)
config.apiKey = process.env.OPENAI_API_KEY
else
throw new Error('OpenAI API key missing! Make sure you set your OPENAI_API_KEY environment variable.');
this.openai = new OpenAIApi(config);
}
async sendRequest(turns, systemMessage, stop_seq='***') {
@ -31,7 +27,7 @@ export class GPT {
console.log('Awaiting openai api response...')
console.log('Messages:', messages);
let completion = await this.openai.chat.completions.create({
model: this.model_name,
model: this.model_name || "gpt-3.5-turbo",
messages: messages,
stop: stop_seq,
});
@ -54,7 +50,7 @@ export class GPT {
async embed(text) {
const embedding = await this.openai.embeddings.create({
model: "text-embedding-ada-002",
model: this.model_name || "text-embedding-ada-002",
input: text,
encoding_format: "float",
});

View file

@ -1,20 +1,20 @@
export class Local {
constructor(model_name) {
constructor(model_name, url) {
this.model_name = model_name;
this.embedding_model = 'nomic-embed-text';
this.url = 'http://localhost:11434';
this.url = url || 'http://localhost:11434';
this.chat_endpoint = '/api/chat';
this.embedding_endpoint = '/api/embeddings';
}
async sendRequest(turns, systemMessage) {
let model = this.model_name || 'llama3';
let messages = [{'role': 'system', 'content': systemMessage}].concat(turns);
let res = null;
try {
console.log(`Awaiting local response... (model: ${this.model_name})`)
console.log(`Awaiting local response... (model: ${model})`)
console.log('Messages:', messages);
res = await this.send(this.chat_endpoint, {model: this.model_name, messages: messages, stream: false});
res = await this.send(this.chat_endpoint, {model: model, messages: messages, stream: false});
if (res)
res = res['message']['content'];
}
@ -31,12 +31,12 @@ export class Local {
}
async embed(text) {
let body = {model: this.embedding_model, prompt: text};
let model = this.model_name || 'nomic-embed-text';
let body = {model: model, prompt: text};
let res = await this.send(this.embedding_endpoint, body);
return res['embedding']
}
async send(endpoint, body) {
const url = new URL(endpoint, this.url);
let method = 'POST';

View file

@ -6,34 +6,54 @@ export class Examples {
this.examples = [];
this.model = model;
this.select_num = select_num;
this.embeddings = {};
}
turnsToText(turns) {
let messages = '';
for (let turn of turns) {
if (turn.role !== 'assistant')
messages += turn.content.substring(turn.content.indexOf(':')+1).trim() + '\n';
}
return messages.trim();
}
getWords(text) {
return text.replace(/[^a-zA-Z ]/g, '').toLowerCase().split(' ');
}
wordOverlapScore(text1, text2) {
const words1 = this.getWords(text1);
const words2 = this.getWords(text2);
const intersection = words1.filter(word => words2.includes(word));
return intersection.length / (words1.length + words2.length - intersection.length);
}
async load(examples) {
this.examples = [];
let promises = examples.map(async (example) => {
let messages = '';
for (let turn of example) {
if (turn.role === 'user')
messages += turn.content.substring(turn.content.indexOf(':')+1).trim() + '\n';
this.examples = examples;
if (this.model !== null) {
for (let example of this.examples) {
let turn_text = this.turnsToText(example);
this.embeddings[turn_text] = await this.model.embed(turn_text);
}
messages = messages.trim();
const embedding = await this.model.embed(messages);
return {'embedding': embedding, 'turns': example};
});
this.examples = await Promise.all(promises);
}
}
async getRelevant(turns) {
let messages = '';
for (let turn of turns) {
if (turn.role != 'assistant')
messages += turn.content.substring(turn.content.indexOf(':')+1).trim() + '\n';
let turn_text = this.turnsToText(turns);
if (this.model !== null) {
let embedding = await this.model.embed(turn_text);
this.examples.sort((a, b) =>
cosineSimilarity(embedding, this.embeddings[this.turnsToText(b)]) -
cosineSimilarity(embedding, this.embeddings[this.turnsToText(a)])
);
}
else {
this.examples.sort((a, b) =>
this.wordOverlapScore(turn_text, this.turnsToText(b)) -
this.wordOverlapScore(turn_text, this.turnsToText(a))
);
}
messages = messages.trim();
const embedding = await this.model.embed(messages);
this.examples.sort((a, b) => {
return cosineSimilarity(b.embedding, embedding) - cosineSimilarity(a.embedding, embedding);
});
let selected = this.examples.slice(0, this.select_num);
return JSON.parse(JSON.stringify(selected)); // deep copy
}
@ -43,13 +63,13 @@ export class Examples {
console.log('selected examples:');
for (let example of selected_examples) {
console.log(example.turns[0].content)
console.log(example[0].content)
}
let msg = 'Examples of how to respond:\n';
for (let i=0; i<selected_examples.length; i++) {
let example = selected_examples[i];
msg += `Example ${i+1}:\n${stringifyTurns(example.turns)}\n\n`;
msg += `Example ${i+1}:\n${stringifyTurns(example)}\n\n`;
}
return msg;
}