From 4dd2f245ffdfb2fac5a169c97f9c825e7aa04828 Mon Sep 17 00:00:00 2001 From: vinodkiran Date: Fri, 29 Dec 2023 20:35:42 +0530 Subject: [PATCH] Compression Retriever: Reciprocal Rank Fusion --- .../retrievers/RRFRetriever/RRFRetriever.ts | 84 ++++++++++++++++ .../RRFRetriever/ReciprocalRankFusion.ts | 95 +++++++++++++++++++ .../RRFRetriever/compressionRetriever.svg | 7 ++ 3 files changed, 186 insertions(+) create mode 100644 packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts create mode 100644 packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts create mode 100644 packages/components/nodes/retrievers/RRFRetriever/compressionRetriever.svg diff --git a/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts b/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts new file mode 100644 index 00000000..8d6d9d6f --- /dev/null +++ b/packages/components/nodes/retrievers/RRFRetriever/RRFRetriever.ts @@ -0,0 +1,84 @@ +import { INode, INodeData, INodeParams } from '../../../src/Interface' +import { BaseLanguageModel } from 'langchain/base_language' +import { ContextualCompressionRetriever } from 'langchain/retrievers/contextual_compression' +import { BaseRetriever } from 'langchain/schema/retriever' +import { ReciprocalRankFusion } from './ReciprocalRankFusion' +import { VectorStoreRetriever } from 'langchain/vectorstores/base' + +class RRFRetriever_Retrievers implements INode { + label: string + name: string + version: number + description: string + type: string + icon: string + category: string + baseClasses: string[] + inputs: INodeParams[] + badge: string + + constructor() { + this.label = 'Reciprocal Rank Fusion Retriever' + this.name = 'RRFRetriever' + this.version = 2.0 + this.type = 'RRFRetriever' + this.badge = 'NEW' + this.icon = 'compressionRetriever.svg' + this.category = 'Retrievers' + this.description = 'Reciprocal Rank Fusion to re-rank search results by multiple query generation.' + this.baseClasses = [this.type, 'BaseRetriever'] + this.inputs = [ + { + label: 'Base Retriever', + name: 'baseRetriever', + type: 'VectorStoreRetriever' + }, + { + label: 'Language Model', + name: 'model', + type: 'BaseLanguageModel' + }, + { + label: 'Query Count', + name: 'queryCount', + description: 'Number of synthetic queries to generate. Default to 4', + placeholder: '4', + type: 'number', + default: 4, + additionalParams: true, + optional: true + }, + { + label: 'Top K', + name: 'topK', + description: 'Number of top results to fetch. Default to the TopK of the Base Retriever', + placeholder: '0', + type: 'number', + default: 0, + additionalParams: true, + optional: true + } + ] + } + + async init(nodeData: INodeData): Promise { + const llm = nodeData.inputs?.model as BaseLanguageModel + const baseRetriever = nodeData.inputs?.baseRetriever as BaseRetriever + const queryCount = nodeData.inputs?.queryCount as string + const q = queryCount ? parseFloat(queryCount) : 4 + const topK = nodeData.inputs?.topK as string + let k = topK ? parseFloat(topK) : 4 + + if (k <= 0) { + k = (baseRetriever as VectorStoreRetriever).k + } + + const ragFusion = new ReciprocalRankFusion(llm, baseRetriever as VectorStoreRetriever, q, k) + return new ContextualCompressionRetriever({ + baseCompressor: ragFusion, + baseRetriever: baseRetriever + }) + } +} + +module.exports = { nodeClass: RRFRetriever_Retrievers } diff --git a/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts b/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts new file mode 100644 index 00000000..134d7c8a --- /dev/null +++ b/packages/components/nodes/retrievers/RRFRetriever/ReciprocalRankFusion.ts @@ -0,0 +1,95 @@ +import { BaseDocumentCompressor } from 'langchain/retrievers/document_compressors' +import { Document } from 'langchain/document' +import { Callbacks } from 'langchain/callbacks' +import { BaseLanguageModel } from 'langchain/base_language' +import { ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate } from 'langchain/prompts' +import { LLMChain } from 'langchain/chains' +import { VectorStoreRetriever } from 'langchain/vectorstores/base' + +export class ReciprocalRankFusion extends BaseDocumentCompressor { + private readonly llm: BaseLanguageModel + private readonly queryCount: number + private readonly topK: number + private baseRetriever: VectorStoreRetriever + constructor(llm: BaseLanguageModel, baseRetriever: VectorStoreRetriever, queryCount: number, topK: number) { + super() + this.queryCount = queryCount + this.llm = llm + this.baseRetriever = baseRetriever + this.topK = topK + } + async compressDocuments( + documents: Document>[], + query: string, + _?: Callbacks | undefined + ): Promise>[]> { + // avoid empty api call + if (documents.length === 0) { + return [] + } + const chatPrompt = ChatPromptTemplate.fromMessages([ + SystemMessagePromptTemplate.fromTemplate( + 'You are a helpful assistant that generates multiple search queries based on a single input query.' + ), + HumanMessagePromptTemplate.fromTemplate( + 'Generate multiple search queries related to: {input}. Provide these alternative questions separated by newlines, do not add any numbers.' + ), + HumanMessagePromptTemplate.fromTemplate('OUTPUT (' + this.queryCount + ' queries):') + ]) + const llmChain = new LLMChain({ + llm: this.llm, + prompt: chatPrompt + }) + const multipleQueries = await llmChain.call({ input: query }) + const queries = [] + queries.push(query) + multipleQueries.text.split('\n').map((q: string) => { + queries.push(q) + }) + console.log(JSON.stringify(queries)) + const docList: Document>[][] = [] + for (let i = 0; i < queries.length; i++) { + const resultOne = await this.baseRetriever.vectorStore.similaritySearch(queries[i], 5) + const docs: any[] = [] + resultOne.forEach((doc) => { + docs.push(doc) + }) + docList.push(docs) + } + + return this.reciprocalRankFunction(docList, 60) + } + + reciprocalRankFunction(docList: Document>[][], k: number): Document>[] { + docList.forEach((docs: Document>[]) => { + docs.forEach((doc: any, index: number) => { + let rank = index + 1 + if (doc.metadata.relevancy_score) { + doc.metadata.relevancy_score += 1 / (rank + k) + } else { + doc.metadata.relevancy_score = 1 / (rank + k) + } + }) + }) + const scoreArray: any[] = [] + docList.forEach((docs: Document>[]) => { + docs.forEach((doc: any) => { + scoreArray.push(doc.metadata.relevancy_score) + }) + }) + scoreArray.sort((a, b) => b - a) + const rerankedDocuments: Document>[] = [] + const seenScores: any[] = [] + scoreArray.forEach((score) => { + docList.forEach((docs) => { + docs.forEach((doc: any) => { + if (doc.metadata.relevancy_score === score && seenScores.indexOf(score) === -1) { + rerankedDocuments.push(doc) + seenScores.push(doc.metadata.relevancy_score) + } + }) + }) + }) + return rerankedDocuments.splice(0, this.topK) + } +} diff --git a/packages/components/nodes/retrievers/RRFRetriever/compressionRetriever.svg b/packages/components/nodes/retrievers/RRFRetriever/compressionRetriever.svg new file mode 100644 index 00000000..23c52d25 --- /dev/null +++ b/packages/components/nodes/retrievers/RRFRetriever/compressionRetriever.svg @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file