From 90df61d2be369192523e2522a804cb25f1363b7b Mon Sep 17 00:00:00 2001 From: Qu Yi Date: Fri, 1 Nov 2024 01:08:30 +0800 Subject: [PATCH] Sort docs by relevance to !newAction("task") --- src/agent/library/index.js | 19 ++++++------ src/agent/prompter.js | 60 +++++++++++++++++++++++++++----------- 2 files changed, 53 insertions(+), 26 deletions(-) diff --git a/src/agent/library/index.js b/src/agent/library/index.js index 677dc11..ae864b0 100644 --- a/src/agent/library/index.js +++ b/src/agent/library/index.js @@ -3,20 +3,21 @@ import * as world from './world.js'; export function docHelper(functions, module_name) { - let docstring = ''; + let docArray = []; for (let skillFunc of functions) { let str = skillFunc.toString(); - if (str.includes('/**')){ - docstring += module_name+'.'+skillFunc.name; - docstring += str.substring(str.indexOf('/**')+3, str.indexOf('**/')) + '\n'; + if (str.includes('/**')) { + let docEntry = `${module_name}.${skillFunc.name}\n`; + docEntry += str.substring(str.indexOf('/**') + 3, str.indexOf('**/')).trim(); + docArray.push(docEntry); } } - return docstring; + return docArray; } export function getSkillDocs() { - let docstring = "\n*SKILL DOCS\nThese skills are javascript functions that can be called when writing actions and skills.\n"; - docstring += docHelper(Object.values(skills), 'skills'); - docstring += docHelper(Object.values(world), 'world'); - return docstring + '*\n'; + let docArray = []; + docArray = docArray.concat(docHelper(Object.values(skills), 'skills')); + docArray = docArray.concat(docHelper(Object.values(world), 'world')); + return docArray; } diff --git a/src/agent/prompter.js b/src/agent/prompter.js index 114064a..3ba51dd 100644 --- a/src/agent/prompter.js +++ b/src/agent/prompter.js @@ -1,17 +1,17 @@ -import { readFileSync, mkdirSync, writeFileSync} from 'fs'; -import { Examples } from '../utils/examples.js'; -import { getCommandDocs } from './commands/index.js'; -import { getSkillDocs } from './library/index.js'; -import { stringifyTurns } from '../utils/text.js'; -import { getCommand } from './commands/index.js'; +import {mkdirSync, readFileSync, writeFileSync} from 'fs'; +import {Examples} from '../utils/examples.js'; +import {getCommand, getCommandDocs} from './commands/index.js'; +import {getSkillDocs} from './library/index.js'; +import {stringifyTurns} from '../utils/text.js'; +import {cosineSimilarity} from '../utils/math.js'; -import { Gemini } from '../models/gemini.js'; -import { GPT } from '../models/gpt.js'; -import { Claude } from '../models/claude.js'; -import { ReplicateAPI } from '../models/replicate.js'; -import { Local } from '../models/local.js'; -import { GroqCloudAPI } from '../models/groq.js'; -import { HuggingFace } from '../models/huggingface.js'; +import {Gemini} from '../models/gemini.js'; +import {GPT} from '../models/gpt.js'; +import {Claude} from '../models/claude.js'; +import {ReplicateAPI} from '../models/replicate.js'; +import {Local} from '../models/local.js'; +import {GroqCloudAPI} from '../models/groq.js'; +import {HuggingFace} from '../models/huggingface.js'; export class Prompter { constructor(agent, fp) { @@ -19,7 +19,8 @@ export class Prompter { this.profile = JSON.parse(readFileSync(fp, 'utf8')); this.convo_examples = null; this.coding_examples = null; - + this.skill_docs_embeddings = {}; + let name = this.profile.name; let chat = this.profile.model; this.cooldown = this.profile.cooldown ? this.profile.cooldown : 0; @@ -111,16 +112,41 @@ export class Prompter { async initExamples() { // Using Promise.all to implement concurrent processing - // Create Examples instances this.convo_examples = new Examples(this.embedding_model); this.coding_examples = new Examples(this.embedding_model); - // Use Promise.all to load examples concurrently + let skill_docs = getSkillDocs(); await Promise.all([ this.convo_examples.load(this.profile.conversation_examples), this.coding_examples.load(this.profile.coding_examples), + ...skill_docs.map(async (doc) => { + let func_name_desc = doc.split('\n').slice(0, 2).join(''); + this.skill_docs_embeddings[doc] = await this.embedding_model.embed([func_name_desc]); + }), ]); } + async getRelevantSkillDocs(messages, select_num) { + let latest_message_content = messages.slice().reverse().find(msg => msg.role !== 'system')?.content || ''; + let latest_message_embedding = await this.embedding_model.embed([latest_message_content]); + + let skill_doc_similarities = Object.keys(this.skill_docs_embeddings) + .map(doc_key => ({ + doc_key, + similarity_score: cosineSimilarity(latest_message_embedding, this.skill_docs_embeddings[doc_key]) + })) + .sort((a, b) => b.similarity_score - a.similarity_score); + + // select_num = -1 means select all + let selected_docs = skill_doc_similarities.slice(0, select_num === -1 ? skill_doc_similarities.length : select_num); + let message = '\nThe following recommended functions are listed in descending order of task relevance.\nSkillDocs:\n'; + message += selected_docs.map(doc => `${doc.doc_key}`).join('\n'); + return message; + } + + + + + async replaceStrings(prompt, messages, examples=null, to_summarize=[], last_goals=null) { prompt = prompt.replaceAll('$NAME', this.agent.name); @@ -135,7 +161,7 @@ export class Prompter { if (prompt.includes('$COMMAND_DOCS')) prompt = prompt.replaceAll('$COMMAND_DOCS', getCommandDocs()); if (prompt.includes('$CODE_DOCS')) - prompt = prompt.replaceAll('$CODE_DOCS', getSkillDocs()); + prompt = prompt.replaceAll('$CODE_DOCS', this.getRelevantSkillDocs(messages, -1)); if (prompt.includes('$EXAMPLES') && examples !== null) prompt = prompt.replaceAll('$EXAMPLES', await examples.createExampleMessage(messages)); if (prompt.includes('$MEMORY'))