Merge pull request #101 from kolbytn/fix-ollama

Fix ollama
This commit is contained in:
Max Robinson 2024-06-02 09:26:47 -05:00 committed by GitHub
commit bd73aa5a68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 51 additions and 37 deletions

View file

@@ -1,5 +1,5 @@
import { writeFile, readFile, mkdirSync } from 'fs'; import { writeFile, readFile, mkdirSync } from 'fs';
import settings from '../../settings.js';
export class Coder { export class Coder {
constructor(agent) { constructor(agent) {

View file

@@ -51,8 +51,12 @@ export class Prompter {
throw new Error('Unknown API:', api); throw new Error('Unknown API:', api);
let embedding = this.prompts.embedding; let embedding = this.prompts.embedding;
if (embedding === undefined) if (embedding === undefined) {
embedding = {api: chat.api}; if (chat.api !== 'ollama')
embedding = {api: chat.api};
else
embedding = {api: 'none'};
}
else if (typeof embedding === 'string' || embedding instanceof String) else if (typeof embedding === 'string' || embedding instanceof String)
embedding = {api: embedding}; embedding = {api: embedding};

View file

@@ -1,4 +1,5 @@
import Anthropic from '@anthropic-ai/sdk'; import Anthropic from '@anthropic-ai/sdk';
import { strictFormat } from '../utils/text.js';
import { getKey } from '../utils/keys.js'; import { getKey } from '../utils/keys.js';
export class Claude { export class Claude {
@@ -15,36 +16,7 @@ export class Claude {
} }
async sendRequest(turns, systemMessage) { async sendRequest(turns, systemMessage) {
let prev_role = null; const messages = strictFormat(turns);
let messages = [];
let filler = {role: 'user', content: '_'};
for (let msg of turns) {
if (msg.role === 'system') {
msg.role = 'user';
msg.content = 'SYSTEM: ' + msg.content;
}
if (msg.role === prev_role && msg.role === 'assistant') {
// insert empty user message to separate assistant messages
messages.push(filler);
messages.push(msg);
}
else if (msg.role === prev_role) {
// combine new message with previous message instead of adding a new one
messages[messages.length-1].content += '\n' + msg.content;
}
else {
messages.push(msg);
}
prev_role = msg.role;
}
if (messages.length > 0 && messages[0].role !== 'user') {
messages.unshift(filler); // anthropic requires user message to start
}
if (messages.length === 0) {
messages.push(filler);
}
let res = null; let res = null;
try { try {
console.log('Awaiting anthropic api response...') console.log('Awaiting anthropic api response...')

View file

@@ -25,9 +25,11 @@ export class Gemini {
const stop_seq = '***'; const stop_seq = '***';
const prompt = toSinglePrompt(turns, systemMessage, stop_seq, 'model'); const prompt = toSinglePrompt(turns, systemMessage, stop_seq, 'model');
console.log('Awaiting Google API response...');
const result = await model.generateContent(prompt); const result = await model.generateContent(prompt);
const response = await result.response; const response = await result.response;
const text = response.text(); const text = response.text();
console.log('Received.');
if (!text.includes(stop_seq)) return text; if (!text.includes(stop_seq)) return text;
const idx = text.indexOf(stop_seq); const idx = text.indexOf(stop_seq);
return text.slice(0, idx); return text.slice(0, idx);

View file

@@ -1,3 +1,5 @@
import { strictFormat } from '../utils/text.js';
export class Local { export class Local {
constructor(model_name, url) { constructor(model_name, url) {
this.model_name = model_name; this.model_name = model_name;
@@ -8,12 +10,11 @@ export class Local {
async sendRequest(turns, systemMessage) { async sendRequest(turns, systemMessage) {
let model = this.model_name || 'llama3'; let model = this.model_name || 'llama3';
let messages = [{'role': 'system', 'content': systemMessage}].concat(turns); let messages = strictFormat(turns);
messages.unshift({role: 'system', content: systemMessage});
let res = null; let res = null;
try { try {
console.log(`Awaiting local response... (model: ${model})`) console.log(`Awaiting local response... (model: ${model})`)
console.log('Messages:', messages);
res = await this.send(this.chat_endpoint, {model: model, messages: messages, stream: false}); res = await this.send(this.chat_endpoint, {model: model, messages: messages, stream: false});
if (res) if (res)
res = res['message']['content']; res = res['message']['content'];
@@ -56,4 +57,4 @@ export class Local {
} }
return data; return data;
} }
} }

View file

@@ -24,4 +24,39 @@ export function toSinglePrompt(turns, system=null, stop_seq='***', model_nickname) {
if (role !== model_nickname) // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message if (role !== model_nickname) // if the last message was from the user/system, add a prompt for the model. otherwise, pretend we are extending the model's own message
prompt += model_nickname + ": "; prompt += model_nickname + ": ";
return prompt; return prompt;
}
// ensures stricter turn order for anthropic/llama models
// combines repeated messages from the same role, separates repeat assistant messages with filler user messages
/**
 * Normalizes a conversation into the strict turn order required by
 * anthropic/llama-style chat APIs:
 *  - system turns become user turns prefixed with 'SYSTEM: '
 *  - consecutive turns with the same role are merged into one message
 *  - consecutive assistant turns are separated by a filler user message
 *  - the result always starts with a user message and is never empty
 *
 * @param {{role: string, content: string}[]} turns - conversation turns
 * @returns {{role: string, content: string}[]} a new normalized list;
 *          the input array and its objects are never mutated
 */
export function strictFormat(turns) {
    let prev_role = null;
    const messages = [];
    for (const turn of turns) {
        // copy the turn so we never mutate the caller's objects
        const msg = { role: turn.role, content: turn.content };
        if (msg.role === 'system') {
            msg.role = 'user';
            msg.content = 'SYSTEM: ' + msg.content;
        }
        if (msg.role === prev_role && msg.role === 'assistant') {
            // insert empty user message to separate assistant messages
            messages.push({ role: 'user', content: '_' });
            messages.push(msg);
        }
        else if (msg.role === prev_role) {
            // combine new message with previous message instead of adding a new one
            messages[messages.length - 1].content += '\n' + msg.content;
        }
        else {
            messages.push(msg);
        }
        prev_role = msg.role;
    }
    if (messages.length > 0 && messages[0].role !== 'user') {
        messages.unshift({ role: 'user', content: '_' }); // anthropic requires user message to start
    }
    if (messages.length === 0) {
        messages.push({ role: 'user', content: '_' });
    }
    return messages;
}