diff --git a/src/models/gemini.js b/src/models/gemini.js
index 504e3f6..61f4f1a 100644
--- a/src/models/gemini.js
+++ b/src/models/gemini.js
@@ -1,5 +1,5 @@
 import { GoogleGenerativeAI } from '@google/generative-ai';
-import { toSinglePrompt } from './helper.js';
+import { toSinglePrompt } from '../utils/text.js';
 
 export class Gemini {
     constructor(model_name, url) {
@@ -27,7 +27,6 @@
         const stop_seq = '***';
         const prompt = toSinglePrompt(turns, systemMessage, stop_seq, 'model');
-        console.log(prompt)
 
         const result = await model.generateContent(prompt);
         const response = await result.response;
         const text = response.text();
diff --git a/src/models/helper.js b/src/models/helper.js
deleted file mode 100644
index 7b45fe1..0000000
--- a/src/models/helper.js
+++ /dev/null
@@ -1,14 +0,0 @@
-export function toSinglePrompt(turns, system=null, stop_seq='***', model_nickname='assistant') {
-    let messages = turns;
-    if (system) messages.unshift({role: 'system', content: system});
-    let prompt = "";
-    let role = "";
-    messages.forEach((message) => {
-        role = message.role;
-        if (role === 'assistant') role = model_nickname;
-        prompt += `${role}: ${message.content}${stop_seq}`;
-    });
-    if (role !== model_nickname) // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message
-        prompt += model_nickname + ": ";
-    return prompt;
-}
diff --git a/src/models/replicate.js b/src/models/replicate.js
index d9f8382..4301448 100644
--- a/src/models/replicate.js
+++ b/src/models/replicate.js
@@ -1,5 +1,5 @@
 import Replicate from 'replicate';
-import { toSinglePrompt } from './helper.js';
+import { toSinglePrompt } from '../utils/text.js';
 
 // llama, mistral
 export class ReplicateAPI {
@@ -24,7 +24,8 @@
 
         const stop_seq = '***';
         let prompt_template;
         const prompt = toSinglePrompt(turns, systemMessage, stop_seq);
-        if (this.model_name.includes('llama')) { // llama
+        let model_name = this.model_name || 'meta/meta-llama-3-70b-instruct';
+        if (model_name.includes('llama')) { // llama
             prompt_template = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
         } else { // mistral
@@ -36,7 +37,7 @@
         try {
             console.log('Awaiting Replicate API response...');
             let result = '';
-            for await (const event of this.replicate.stream(this.model_name, { input })) {
+            for await (const event of this.replicate.stream(model_name, { input })) {
                 result += event;
                 if (result === '') break;
                 if (result.includes(stop_seq)) {
diff --git a/src/utils/text.js b/src/utils/text.js
index d06221a..c075d50 100644
--- a/src/utils/text.js
+++ b/src/utils/text.js
@@ -11,4 +11,19 @@ export function stringifyTurns(turns) {
         }
     }
     return res.trim();
-}
\ No newline at end of file
+}
+
+export function toSinglePrompt(turns, system=null, stop_seq='***', model_nickname='assistant') {
+    let messages = turns;
+    if (system) messages.unshift({role: 'system', content: system});
+    let prompt = "";
+    let role = "";
+    messages.forEach((message) => {
+        role = message.role;
+        if (role === 'assistant') role = model_nickname;
+        prompt += `${role}: ${message.content}${stop_seq}`;
+    });
+    if (role !== model_nickname) // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message
+        prompt += model_nickname + ": ";
+    return prompt;
+}
\ No newline at end of file
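
Note for reviewers: a minimal usage sketch of the relocated toSinglePrompt helper. The turns, system string, and 'model' nickname below are illustrative inputs, not taken from this diff.

import { toSinglePrompt } from './src/utils/text.js';

// Hypothetical conversation; 'assistant' turns are renamed to the model nickname.
const turns = [
    { role: 'user', content: 'Hello' },
    { role: 'assistant', content: 'Hi there' },
    { role: 'user', content: 'How are you?' },
];

// Prints:
// system: Be brief.***user: Hello***model: Hi there***user: How are you?***model: 
// The trailing "model: " cue is appended because the last turn was not the model's.
// Note that the helper unshifts the system turn into the caller's array in place.
console.log(toSinglePrompt(turns, 'Be brief.', '***', 'model'));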