fixed example sorting and empty embedding

This commit is contained in:
MaxRobinsonTheGreat 2024-04-27 18:48:53 -05:00
parent 40e067903e
commit 75a5072cdc

View file

@ -12,7 +12,7 @@ export class Examples {
turnsToText(turns) {
let messages = '';
for (let turn of turns) {
if (turn.role === 'user')
if (turn.role !== 'assistant')
messages += turn.content.substring(turn.content.indexOf(':')+1).trim() + '\n';
}
return messages.trim();
@ -22,29 +22,11 @@ export class Examples {
return text.replace(/[^a-zA-Z ]/g, '').toLowerCase().split(' ');
}
async getSimilarity(text1, text2) {
if (this.model !== null) {
let embeddings1 = null;
let embeddings2 = null;
if (this.embeddings[text1])
embeddings1 = this.embeddings[text1];
else
embeddings1 = await this.model.embed(text1);
if (this.embeddings[text2])
embeddings2 = this.embeddings[text2];
else
embeddings2 = await this.model.embed(text2);
return cosineSimilarity(embeddings1, embeddings2);
} else {
const words1 = this.getWords(text1);
const words2 = this.getWords(text2);
const intersection = words1.filter(word => words2.includes(word));
return intersection.length / (words1.length + words2.length - intersection.length);
}
wordOverlapScore(text1, text2) {
const words1 = this.getWords(text1);
const words2 = this.getWords(text2);
const intersection = words1.filter(word => words2.includes(word));
return intersection.length / (words1.length + words2.length - intersection.length);
}
async load(examples) {
@ -59,11 +41,20 @@ export class Examples {
async getRelevant(turns) {
let turn_text = this.turnsToText(turns);
this.examples.sort((a, b) =>
this.getSimilarity(turn_text, this.turnsToText(a)) -
this.getSimilarity(turn_text, this.turnsToText(b))
);
let selected = this.examples.slice(-this.select_num);
if (this.model !== null) {
let embedding = await this.model.embed(turn_text);
this.examples.sort((a, b) =>
cosineSimilarity(embedding, this.embeddings[this.turnsToText(b)]) -
cosineSimilarity(embedding, this.embeddings[this.turnsToText(a)])
);
}
else {
this.examples.sort((a, b) =>
this.wordOverlapScore(turn_text, this.turnsToText(b)) -
this.wordOverlapScore(turn_text, this.turnsToText(a))
);
}
let selected = this.examples.slice(0, this.select_num);
return JSON.parse(JSON.stringify(selected)); // deep copy
}