refactored and added default replicate model

This commit is contained in:
MaxRobinsonTheGreat 2024-05-10 13:41:29 -05:00
parent e418778790
commit 0bd92f7521
4 changed files with 20 additions and 19 deletions

View file

@ -1,5 +1,5 @@
import { GoogleGenerativeAI } from '@google/generative-ai';
import { toSinglePrompt } from './helper.js';
import { toSinglePrompt } from '../utils/text.js';
export class Gemini {
constructor(model_name, url) {
@ -27,7 +27,6 @@ export class Gemini {
const stop_seq = '***';
const prompt = toSinglePrompt(turns, systemMessage, stop_seq, 'model');
console.log(prompt)
const result = await model.generateContent(prompt);
const response = await result.response;
const text = response.text();

View file

@ -1,14 +0,0 @@
/**
 * Flattens a list of chat turns into one prompt string.
 * Each turn is rendered as "role: content" followed by the stop sequence;
 * assistant turns are relabeled with `model_nickname`. If the final turn is
 * not from the model, a trailing "nickname: " cue is appended so the model
 * continues; otherwise the model's last message is left open for extension.
 * Note: when `system` is given it is unshifted onto the caller's array.
 */
export function toSinglePrompt(turns, system=null, stop_seq='***', model_nickname='assistant') {
    if (system)
        turns.unshift({role: 'system', content: system});
    let lastRole = "";
    let prompt = "";
    for (const turn of turns) {
        lastRole = turn.role === 'assistant' ? model_nickname : turn.role;
        prompt += `${lastRole}: ${turn.content}${stop_seq}`;
    }
    // Cue the model to respond unless we are extending its own last message.
    if (lastRole !== model_nickname)
        prompt += model_nickname + ": ";
    return prompt;
}

View file

@ -1,5 +1,5 @@
import Replicate from 'replicate';
import { toSinglePrompt } from './helper.js';
import { toSinglePrompt } from '../utils/text.js';
// llama, mistral
export class ReplicateAPI {
@ -24,7 +24,8 @@ export class ReplicateAPI {
const stop_seq = '***';
let prompt_template;
const prompt = toSinglePrompt(turns, systemMessage, stop_seq);
if (this.model_name.includes('llama')) { // llama
let model_name = this.model_name || 'meta/meta-llama-3-70b-instruct';
if (model_name.includes('llama')) { // llama
prompt_template = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
}
else { // mistral
@ -36,7 +37,7 @@ export class ReplicateAPI {
try {
console.log('Awaiting Replicate API response...');
let result = '';
for await (const event of this.replicate.stream(this.model_name, { input })) {
for await (const event of this.replicate.stream(model_name, { input })) {
result += event;
if (result === '') break;
if (result.includes(stop_seq)) {

View file

@ -11,4 +11,19 @@ export function stringifyTurns(turns) {
}
}
return res.trim();
}
/**
 * Flattens a list of chat turns into one prompt string.
 * Each turn is rendered as "role: content" followed by `stop_seq`;
 * assistant turns are relabeled with `model_nickname`. If the final turn is
 * not from the model, a trailing "nickname: " cue is appended so the model
 * responds; otherwise the model's last message is left open for extension.
 *
 * @param {Array<{role: string, content: string}>} turns - conversation turns (not mutated)
 * @param {?string} system - optional system message prepended to the turns
 * @param {string} stop_seq - separator appended after every message
 * @param {string} model_nickname - label used for assistant turns
 * @returns {string} the single flattened prompt
 */
export function toSinglePrompt(turns, system=null, stop_seq='***', model_nickname='assistant') {
    // Copy before prepending: the original unshifted onto the caller's array,
    // so repeated calls stacked duplicate system messages into the history.
    let messages = system ? [{role: 'system', content: system}, ...turns] : turns;
    let prompt = "";
    let role = "";
    messages.forEach((message) => {
        role = message.role;
        if (role === 'assistant') role = model_nickname;
        prompt += `${role}: ${message.content}${stop_seq}`;
    });
    if (role !== model_nickname) // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message
        prompt += model_nickname + ": ";
    return prompt;
}