From e4187787900f49d65cd2801e8daf10f65cd9361a Mon Sep 17 00:00:00 2001 From: MaxRobinsonTheGreat Date: Tue, 7 May 2024 15:08:22 -0500 Subject: [PATCH] improved replicate, fixed gemini, shared toSinglePrompt --- src/models/gemini.js | 21 +++---- src/models/helper.js | 14 +++++ src/models/replicate.js | 122 +++++++++++++++++----------------------- 3 files changed, 75 insertions(+), 82 deletions(-) create mode 100644 src/models/helper.js diff --git a/src/models/gemini.js b/src/models/gemini.js index c27d34e..504e3f6 100644 --- a/src/models/gemini.js +++ b/src/models/gemini.js @@ -1,5 +1,5 @@ import { GoogleGenerativeAI } from '@google/generative-ai'; - +import { toSinglePrompt } from './helper.js'; export class Gemini { constructor(model_name, url) { @@ -13,6 +13,7 @@ export class Gemini { } async sendRequest(turns, systemMessage) { + let model; if (this.url) { model = this.genAI.getGenerativeModel( {model: this.model_name || "gemini-pro"}, @@ -24,23 +25,19 @@ export class Gemini { ); } - const messages = [{'role': 'system', 'content': systemMessage}].concat(turns); - let prompt = ""; - let role = ""; - messages.forEach((message) => { - role = message.role; - if (role === 'assistant') role = 'model'; - prompt += `${role}: ${message.content}\n`; - }); - if (role !== "model") // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message - prompt += "model: "; + const stop_seq = '***'; + const prompt = toSinglePrompt(turns, systemMessage, stop_seq, 'model'); console.log(prompt) const result = await model.generateContent(prompt); const response = await result.response; - return response.text(); + const text = response.text(); + if (!text.includes(stop_seq)) return text; + const idx = text.indexOf(stop_seq); + return text.slice(0, idx); } async embed(text) { + let model; if (this.url) { model = this.genAI.getGenerativeModel( {model: this.model_name || "embedding-001"}, diff --git a/src/models/helper.js b/src/models/helper.js new file mode 100644 index 0000000..7b45fe1 --- /dev/null +++ b/src/models/helper.js @@ -0,0 +1,14 @@ +export function toSinglePrompt(turns, system=null, stop_seq='***', model_nickname='assistant') { + let messages = turns; + if (system) messages.unshift({role: 'system', content: system}); + let prompt = ""; + let role = ""; + messages.forEach((message) => { + role = message.role; + if (role === 'assistant') role = model_nickname; + prompt += `${role}: ${message.content}${stop_seq}`; + }); + if (role !== model_nickname) // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message + prompt += model_nickname + ": "; + return prompt; +} diff --git a/src/models/replicate.js b/src/models/replicate.js index 8ff22b4..d9f8382 100644 --- a/src/models/replicate.js +++ b/src/models/replicate.js @@ -1,81 +1,63 @@ import Replicate from 'replicate'; +import { toSinglePrompt } from './helper.js'; // llama, mistral export class ReplicateAPI { - constructor(model_name, url) { - this.model_name = model_name; - this.url = url; + constructor(model_name, url) { + this.model_name = model_name; + this.url = url; - if (!process.env.REPLICATE_API_KEY) { - throw new Error('Replicate API key missing! Make sure you set your REPLICATE_API_KEY environment variable.'); - } + if (this.url) { + console.warn('Replicate API does not support custom URLs. Ignoring provided URL.'); + } - this.replicate = new Replicate({ - auth: process.env.REPLICATE_API_KEY, - }); - } + if (!process.env.REPLICATE_API_KEY) { + throw new Error('Replicate API key missing! Make sure you set your REPLICATE_API_KEY environment variable.'); + } - async sendRequest(turns, systemMessage) { - if (this.url) { - - } + this.replicate = new Replicate({ + auth: process.env.REPLICATE_API_KEY, + }); + } - let prev_role = null; - let messages = []; - let filler = { role: 'user', content: '_' }; + async sendRequest(turns, systemMessage) { + const stop_seq = '***'; + let prompt_template; + const prompt = toSinglePrompt(turns, systemMessage, stop_seq); + if (this.model_name.includes('llama')) { // llama + prompt_template = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + } + else { // mistral + prompt_template = "[INST] {prompt} [/INST] " + } - for (let msg of turns) { - if (msg.role === 'system') { - msg.role = 'user'; - msg.content = 'SYSTEM: ' + msg.content; - } - if (msg.role === prev_role && msg.role === 'assistant') { - // insert empty user message to separate assistant messages - messages.push(filler); - messages.push(msg); - } else if (msg.role === prev_role) { - // combine new message with previous message instead of adding a new one - messages[messages.length - 1].content += '\n' + msg.content; - } else { - messages.push(msg); - } - prev_role = msg.role; - } + const input = { prompt, prompt_template }; + let res = null; + try { + console.log('Awaiting Replicate API response...'); + let result = ''; + for await (const event of this.replicate.stream(this.model_name, { input })) { + result += event; + if (result === '') break; + if (result.includes(stop_seq)) { + result = result.slice(0, result.indexOf(stop_seq)); + break; + } + } + res = result; + } catch (err) { + console.log(err); + res = 'My brain disconnected, try again.'; + } + console.log('Received.'); + return res; + } - const prompt = '\n\n' + messages.map(msg => `${msg.role}: ${msg.content}`).join('\n'); - const input = { - prompt: prompt, - top_p: 0.95, - prompt_template: "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"+systemMessage+"<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", - presence_penalty: 0, - frequency_penalty: 0 - }; - - - let res = null; - try { - console.log('Awaiting Replicate API response...'); - console.log('Input:', input); - let result = ''; - for await (const event of this.replicate.stream(this.model_name, { input })) { - result += event; - } - console.log('Received.'); - res = result; - } catch (err) { - console.log(err); - res = 'My brain disconnected, try again.'; - } - - return res; - } - "You are a playful Minecraft bot named andy that can converse with players, see, move, mine, build, and interact with the world by using commands. Act human-like as if you were a typical Minecraft player, rather than an AI. Be very brief in your responses, don't apologize constantly, don't give instructions or make lists unless asked, and don't refuse requests. Don't pretend to act, use commands immediately when requested. Do NOT say this: 'Sure, I've stopped.', instead say this: 'Sure, I'll sto…ll automatically choose a goal.\nquantity: (number) The quantity of the goal to set. Default is 1.\n*\n\nExamples of how to respond:\nExample 1:\nUser input: miner_32: Hey! What are you up to?\nYour output:\nNothing much miner_32, what do you need?\n\nExample 2:\nUser input: grombo_Xx: What do you see?\nYour output:\nLet me see... !nearbyBlocks\nSystem output: NEARBY_BLOCKS\n- oak_log\n- dirt\n- cobblestone\nYour output:\nI see some oak logs, dirt, and cobblestone.\n\n\nConversation Begin:\n\nuser: SYSTEM: SAY HELLO." - - async embed(text) { - const output = await this.replicate.run( - this.model_name || "mark3labs/embeddings-gte-base:d619cff29338b9a37c3d06605042e1ff0594a8c3eff0175fd6967f5643fc4d47", - { input: {text} } - ); - return output; - } + async embed(text) { + const output = await this.replicate.run( + this.model_name || "mark3labs/embeddings-gte-base:d619cff29338b9a37c3d06605042e1ff0594a8c3eff0175fd6967f5643fc4d47", + { input: {text} } + ); + return output.vectors; + } } \ No newline at end of file