From 5bf147fc40d62c0ce362f7ba1722ce1b45618d53 Mon Sep 17 00:00:00 2001 From: MaxRobinsonTheGreat Date: Sun, 18 Feb 2024 22:56:38 -0600 Subject: [PATCH] refactored llm, added gemini --- package.json | 1 + settings.json | 2 ++ src/agent/agent.js | 10 +++--- src/agent/coder.js | 2 +- src/agent/history.js | 2 +- src/models/gemini.js | 37 +++++++++++++++++++++ src/models/gpt.js | 66 +++++++++++++++++++++++++++++++++++++ src/models/model.js | 19 +++++++++++ src/utils/examples.js | 3 +- src/utils/gpt.js | 75 ------------------------------------------- src/utils/math.js | 13 ++++++++ 11 files changed, 148 insertions(+), 82 deletions(-) create mode 100644 src/models/gemini.js create mode 100644 src/models/gpt.js create mode 100644 src/models/model.js delete mode 100644 src/utils/gpt.js create mode 100644 src/utils/math.js diff --git a/package.json b/package.json index ee33776..3b42257 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,7 @@ { "type": "module", "dependencies": { + "@google/generative-ai": "^0.2.1", "minecraft-data": "^3.46.2", "mineflayer": "^4.14.0", "mineflayer-armor-manager": "^2.0.1", diff --git a/settings.json b/settings.json index bfd07a4..4818718 100644 --- a/settings.json +++ b/settings.json @@ -3,5 +3,7 @@ "host": "localhost", "port": 55916, "auth": "offline", + + "model": "gemini-1.0-pro", "allow_insecure_coding": false } \ No newline at end of file diff --git a/src/agent/agent.js b/src/agent/agent.js index dfaf425..cf2cc93 100644 --- a/src/agent/agent.js +++ b/src/agent/agent.js @@ -3,7 +3,7 @@ import { Coder } from './coder.js'; import { initModes } from './modes.js'; import { Examples } from '../utils/examples.js'; import { initBot } from '../utils/mcdata.js'; -import { sendRequest } from '../utils/gpt.js'; +import { sendRequest } from '../models/model.js'; import { containsCommand, commandExists, executeCommand } from './commands/index.js'; @@ -102,7 +102,7 @@ export class Agent { let command_name = containsCommand(res); if (command_name) { // contains query or command - console.log('Command message:', res); + console.log(`""${res}""`) if (!commandExists(command_name)) { this.history.add('system', `Command ${command_name} does not exist. Use !newAction to perform custom actions.`); console.log('Agent hallucinated command:', command_name) @@ -110,8 +110,10 @@ export class Agent { } let pre_message = res.substring(0, res.indexOf(command_name)).trim(); - - this.cleanChat(`${pre_message} *used ${command_name.substring(1)}*`); + let message = `*used ${command_name.substring(1)}*`; + if (pre_message.length > 0) + message = `${pre_message} ${message}`; + this.cleanChat(message); let execute_res = await executeCommand(this, res); console.log('Agent executed:', command_name, 'and got:', execute_res); diff --git a/src/agent/coder.js b/src/agent/coder.js index b4ca023..9a09158 100644 --- a/src/agent/coder.js +++ b/src/agent/coder.js @@ -1,5 +1,5 @@ import { writeFile, readFile, mkdirSync } from 'fs'; -import { sendRequest } from '../utils/gpt.js'; +import { sendRequest } from '../models/model.js'; import { getSkillDocs } from './library/index.js'; import { Examples } from '../utils/examples.js'; diff --git a/src/agent/history.js b/src/agent/history.js index 6617812..9bffca7 100644 --- a/src/agent/history.js +++ b/src/agent/history.js @@ -1,6 +1,6 @@ import { writeFileSync, readFileSync, mkdirSync } from 'fs'; import { stringifyTurns } from '../utils/text.js'; -import { sendRequest } from '../utils/gpt.js'; +import { sendRequest } from '../models/model.js'; import { getCommandDocs } from './commands/index.js'; diff --git a/src/models/gemini.js b/src/models/gemini.js new file mode 100644 index 0000000..6f8429d --- /dev/null +++ b/src/models/gemini.js @@ -0,0 +1,37 @@ +import { GoogleGenerativeAI } from '@google/generative-ai'; +import settings from '../settings.js'; + +export class Gemini { + constructor() { + if (!process.env.GEMINI_API_KEY) { + console.error('Gemini API key missing! Make sure you set your GEMINI_API_KEY environment variable.'); + process.exit(1); + } + this.genAI = new GoogleGenerativeAI(process.env.GEMINI_API_KEY); + + this.model = this.genAI.getGenerativeModel({ model: settings.model }); + } + + async sendRequest(turns, systemMessage) { + const messages = [{'role': 'system', 'content': systemMessage}].concat(turns); + let prompt = ""; + let role = ""; + messages.forEach((message) => { + role = message.role; + if (role === 'assistant') role = 'model'; + prompt += `${role}: ${message.content}\n`; + }); + if (role !== "model") // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message + prompt += "model: "; + console.log(prompt) + const result = await this.model.generateContent(prompt); + const response = await result.response; + return response.text(); + } + + async embed(text) { + const model = this.genAI.getGenerativeModel({ model: "embedding-001"}); + const result = await model.embedContent(text); + return result.embedding; + } +} \ No newline at end of file diff --git a/src/models/gpt.js b/src/models/gpt.js new file mode 100644 index 0000000..e6fc120 --- /dev/null +++ b/src/models/gpt.js @@ -0,0 +1,66 @@ +import OpenAIApi from 'openai'; +import settings from '../settings.js'; + +export class GPT { + constructor() { + let openAiConfig = null; + if (process.env.OPENAI_ORG_ID) { + openAiConfig = { + organization: process.env.OPENAI_ORG_ID, + apiKey: process.env.OPENAI_API_KEY, + }; + } + else if (process.env.OPENAI_API_KEY) { + openAiConfig = { + apiKey: process.env.OPENAI_API_KEY, + }; + } + else { + console.error('OpenAI API key missing! Make sure you set OPENAI_API_KEY and OPENAI_ORG_ID (optional) environment variables.'); + process.exit(1); + } + + this.openai = new OpenAIApi(openAiConfig); + } + + async sendRequest(turns, systemMessage, stop_seq='***') { + + let messages = [{'role': 'system', 'content': systemMessage}].concat(turns); + + let res = null; + try { + console.log('Awaiting openai api response...') + let completion = await this.openai.chat.completions.create({ + model: settings.model, + messages: messages, + stop: stop_seq, + }); + if (completion.choices[0].finish_reason == 'length') + throw new Error('Context length exceeded'); + console.log('Received.') + res = completion.choices[0].message.content; + } + catch (err) { + if ((err.message == 'Context length exceeded' || err.code == 'context_length_exceeded') && turns.length > 1) { + console.log('Context length exceeded, trying again with shorter context.'); + return await sendRequest(turns.slice(1), systemMessage, stop_seq); + } else { + console.log(err); + res = 'My brain disconnected, try again.'; + } + } + return res; + } + + async embed(text) { + const embedding = await this.openai.embeddings.create({ + model: "text-embedding-ada-002", + input: text, + encoding_format: "float", + }); + return embedding.data[0].embedding; + } +} + + + diff --git a/src/models/model.js b/src/models/model.js new file mode 100644 index 0000000..06a26e6 --- /dev/null +++ b/src/models/model.js @@ -0,0 +1,19 @@ +import { GPT } from './gpt.js'; +import { Gemini } from './gemini.js'; +import settings from '../settings.js'; + +console.log('Initializing model...'); +let model = null; +if (settings.model.includes('gemini')) { + model = new Gemini(); +} else { + model = new GPT(); +} + +export async function sendRequest(turns, systemMessage) { + return await model.sendRequest(turns, systemMessage); +} + +export async function embed(text) { + return await model.embed(text); +} \ No newline at end of file diff --git a/src/utils/examples.js b/src/utils/examples.js index 3e6a52c..6739517 100644 --- a/src/utils/examples.js +++ b/src/utils/examples.js @@ -1,6 +1,7 @@ import { readFileSync } from 'fs'; -import { embed, cosineSimilarity } from './gpt.js'; +import { cosineSimilarity } from './math.js'; import { stringifyTurns } from './text.js'; +import { embed } from '../models/model.js'; export class Examples { diff --git a/src/utils/gpt.js b/src/utils/gpt.js deleted file mode 100644 index b5281c7..0000000 --- a/src/utils/gpt.js +++ /dev/null @@ -1,75 +0,0 @@ -import OpenAIApi from 'openai'; - - -let openAiConfig = null; -if (process.env.OPENAI_ORG_ID) { - openAiConfig = { - organization: process.env.OPENAI_ORG_ID, - apiKey: process.env.OPENAI_API_KEY, - }; -} -else if (process.env.OPENAI_API_KEY) { - openAiConfig = { - apiKey: process.env.OPENAI_API_KEY, - }; -} -else { - console.error('OpenAI API key missing! Make sure you set OPENAI_API_KEY and OPENAI_ORG_ID (optional) environment variables.'); - process.exit(1); -} - -const openai = new OpenAIApi(openAiConfig); - - -export async function sendRequest(turns, systemMessage, stop_seq='***') { - - let messages = [{'role': 'system', 'content': systemMessage}].concat(turns); - - let res = null; - try { - console.log('Awaiting openai api response...') - let completion = await openai.chat.completions.create({ - model: 'gpt-3.5-turbo', - messages: messages, - stop: stop_seq, - }); - if (completion.choices[0].finish_reason == 'length') - throw new Error('Context length exceeded'); - console.log('Received.') - res = completion.choices[0].message.content; - } - catch (err) { - if ((err.message == 'Context length exceeded' || err.code == 'context_length_exceeded') && turns.length > 1) { - console.log('Context length exceeded, trying again with shorter context.'); - return await sendRequest(turns.slice(1), systemMessage, stop_seq); - } else { - console.log(err); - res = 'My brain disconnected, try again.'; - } - } - return res; -} - - -export async function embed(text) { - const embedding = await openai.embeddings.create({ - model: "text-embedding-ada-002", - input: text, - encoding_format: "float", - }); - return embedding.data[0].embedding; -} - -export function cosineSimilarity(a, b) { - let dotProduct = 0; - let magnitudeA = 0; - let magnitudeB = 0; - for (let i = 0; i < a.length; i++) { - dotProduct += a[i] * b[i]; // calculate dot product - magnitudeA += Math.pow(a[i], 2); // calculate magnitude of a - magnitudeB += Math.pow(b[i], 2); // calculate magnitude of b - } - magnitudeA = Math.sqrt(magnitudeA); - magnitudeB = Math.sqrt(magnitudeB); - return dotProduct / (magnitudeA * magnitudeB); // calculate cosine similarity -} \ No newline at end of file diff --git a/src/utils/math.js b/src/utils/math.js new file mode 100644 index 0000000..6da44c3 --- /dev/null +++ b/src/utils/math.js @@ -0,0 +1,13 @@ +export function cosineSimilarity(a, b) { + let dotProduct = 0; + let magnitudeA = 0; + let magnitudeB = 0; + for (let i = 0; i < a.length; i++) { + dotProduct += a[i] * b[i]; // calculate dot product + magnitudeA += Math.pow(a[i], 2); // calculate magnitude of a + magnitudeB += Math.pow(b[i], 2); // calculate magnitude of b + } + magnitudeA = Math.sqrt(magnitudeA); + magnitudeB = Math.sqrt(magnitudeB); + return dotProduct / (magnitudeA * magnitudeB); // calculate cosine similarity +} \ No newline at end of file