From 15afb8a2ddb2ce92b47ff08cfdba3a087d55eff3 Mon Sep 17 00:00:00 2001 From: Henry Date: Tue, 20 Feb 2024 18:23:39 +0800 Subject: [PATCH] update mrkl agents --- .../agents/MRKLAgentChat/MRKLAgentChat.ts | 40 ++- .../nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts | 8 +- packages/components/src/agents.ts | 124 +++++++- .../marketplaces/chatflows/ReAct Agent.json | 278 ++++++++++++------ 4 files changed, 345 insertions(+), 105 deletions(-) diff --git a/packages/components/nodes/agents/MRKLAgentChat/MRKLAgentChat.ts b/packages/components/nodes/agents/MRKLAgentChat/MRKLAgentChat.ts index 9dce98af..7328b986 100644 --- a/packages/components/nodes/agents/MRKLAgentChat/MRKLAgentChat.ts +++ b/packages/components/nodes/agents/MRKLAgentChat/MRKLAgentChat.ts @@ -1,12 +1,13 @@ import { flatten } from 'lodash' -import { AgentExecutor, createReactAgent } from 'langchain/agents' +import { AgentExecutor } from 'langchain/agents' import { pull } from 'langchain/hub' import { Tool } from '@langchain/core/tools' import type { PromptTemplate } from '@langchain/core/prompts' import { BaseChatModel } from '@langchain/core/language_models/chat_models' import { additionalCallbacks } from '../../../src/handler' -import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { FlowiseMemory, ICommonObject, IMessage, INode, INodeData, INodeParams } from '../../../src/Interface' import { getBaseClasses } from '../../../src/utils' +import { createReactAgent } from '../../../src/agents' class MRKLAgentChat_Agents implements INode { label: string @@ -18,11 +19,12 @@ class MRKLAgentChat_Agents implements INode { category: string baseClasses: string[] inputs: INodeParams[] + sessionId?: string - constructor() { + constructor(fields?: { sessionId?: string }) { this.label = 'ReAct Agent for Chat Models' this.name = 'mrklAgentChat' - this.version = 2.0 + this.version = 3.0 this.type = 'AgentExecutor' this.category = 'Agents' this.icon = 'agent.svg' @@ -39,8 +41,14 @@ class MRKLAgentChat_Agents implements INode { label: 'Chat Model', name: 'model', type: 'BaseChatModel' + }, + { + label: 'Memory', + name: 'memory', + type: 'BaseChatMemory' } ] + this.sessionId = fields?.sessionId } async init(): Promise { @@ -48,6 +56,7 @@ class MRKLAgentChat_Agents implements INode { } async run(nodeData: INodeData, input: string, options: ICommonObject): Promise { + const memory = nodeData.inputs?.memory as FlowiseMemory const model = nodeData.inputs?.model as BaseChatModel let tools = nodeData.inputs?.tools as Tool[] tools = flatten(tools) @@ -68,10 +77,25 @@ class MRKLAgentChat_Agents implements INode { const callbacks = await additionalCallbacks(nodeData, options) - const result = await executor.invoke({ - input, - callbacks - }) + const prevChatHistory = options.chatHistory + const chatHistory = ((await memory.getChatMessages(this.sessionId, false, prevChatHistory)) as IMessage[]) ?? [] + const chatHistoryString = chatHistory.map((hist) => hist.message).join('\\n') + + const result = await executor.invoke({ input, chat_history: chatHistoryString }, { callbacks }) + + await memory.addChatMessages( + [ + { + text: input, + type: 'userMessage' + }, + { + text: result?.output, + type: 'apiMessage' + } + ], + this.sessionId + ) return result?.output } diff --git a/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts b/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts index 3bd6ba1e..596d20c5 100644 --- a/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts +++ b/packages/components/nodes/agents/MRKLAgentLLM/MRKLAgentLLM.ts @@ -1,5 +1,5 @@ import { flatten } from 'lodash' -import { AgentExecutor, createReactAgent } from 'langchain/agents' +import { AgentExecutor } from 'langchain/agents' import { pull } from 'langchain/hub' import { Tool } from '@langchain/core/tools' import type { PromptTemplate } from '@langchain/core/prompts' @@ -7,6 +7,7 @@ import { BaseLanguageModel } from 'langchain/base_language' import { additionalCallbacks } from '../../../src/handler' import { getBaseClasses } from '../../../src/utils' import { ICommonObject, INode, INodeData, INodeParams } from '../../../src/Interface' +import { createReactAgent } from '../../../src/agents' class MRKLAgentLLM_Agents implements INode { label: string @@ -68,10 +69,7 @@ class MRKLAgentLLM_Agents implements INode { const callbacks = await additionalCallbacks(nodeData, options) - const result = await executor.invoke({ - input, - callbacks - }) + const result = await executor.invoke({ input }, { callbacks }) return result?.output } diff --git a/packages/components/src/agents.ts b/packages/components/src/agents.ts index c795784d..5e4bd9c8 100644 --- a/packages/components/src/agents.ts +++ b/packages/components/src/agents.ts @@ -3,12 +3,23 @@ import { ChainValues } from '@langchain/core/utils/types' import { AgentStep, AgentAction } from '@langchain/core/agents' import { BaseMessage, FunctionMessage, AIMessage } from '@langchain/core/messages' import { OutputParserException } from '@langchain/core/output_parsers' +import { BaseLanguageModel } from '@langchain/core/language_models/base' import { CallbackManager, CallbackManagerForChainRun, Callbacks } from '@langchain/core/callbacks/manager' -import { ToolInputParsingException, Tool } from '@langchain/core/tools' -import { Runnable } from '@langchain/core/runnables' +import { ToolInputParsingException, Tool, StructuredToolInterface } from '@langchain/core/tools' +import { Runnable, RunnableSequence, RunnablePassthrough } from '@langchain/core/runnables' import { Serializable } from '@langchain/core/load/serializable' +import { renderTemplate } from '@langchain/core/prompts' import { BaseChain, SerializedLLMChain } from 'langchain/chains' -import { AgentExecutorInput, BaseSingleActionAgent, BaseMultiActionAgent, RunnableAgent, StoppingMethod } from 'langchain/agents' +import { + CreateReactAgentParams, + AgentExecutorInput, + AgentActionOutputParser, + BaseSingleActionAgent, + BaseMultiActionAgent, + RunnableAgent, + StoppingMethod +} from 'langchain/agents' +import { formatLogToString } from 'langchain/agents/format_scratchpad/log' export const SOURCE_DOCUMENTS_PREFIX = '\n\n----FLOWISE_SOURCE_DOCUMENTS----\n\n' type AgentFinish = { @@ -647,3 +658,110 @@ export const formatAgentSteps = (steps: AgentStep[]): BaseMessage[] => return [new AIMessage(action.log)] } }) + +const renderTextDescription = (tools: StructuredToolInterface[]): string => { + return tools.map((tool) => `${tool.name}: ${tool.description}`).join('\n') +} + +export const createReactAgent = async ({ llm, tools, prompt }: CreateReactAgentParams) => { + const missingVariables = ['tools', 'tool_names', 'agent_scratchpad'].filter((v) => !prompt.inputVariables.includes(v)) + if (missingVariables.length > 0) { + throw new Error(`Provided prompt is missing required input variables: ${JSON.stringify(missingVariables)}`) + } + const toolNames = tools.map((tool) => tool.name) + const partialedPrompt = await prompt.partial({ + tools: renderTextDescription(tools), + tool_names: toolNames.join(', ') + }) + // TODO: Add .bind to core runnable interface. + const llmWithStop = (llm as BaseLanguageModel).bind({ + stop: ['\nObservation:'] + }) + const agent = RunnableSequence.from([ + RunnablePassthrough.assign({ + //@ts-ignore + agent_scratchpad: (input: { steps: AgentStep[] }) => formatLogToString(input.steps) + }), + partialedPrompt, + llmWithStop, + new ReActSingleInputOutputParser({ + toolNames + }) + ]) + return agent +} + +class ReActSingleInputOutputParser extends AgentActionOutputParser { + lc_namespace = ['langchain', 'agents', 'react'] + + private toolNames: string[] + private FINAL_ANSWER_ACTION = 'Final Answer:' + private FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE = 'Parsing LLM output produced both a final answer and a parse-able action:' + private FORMAT_INSTRUCTIONS = `Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Action: the action to take, should be one of [{tool_names}] +Action Input: the input to the action +Observation: the result of the action +... (this Thought/Action/Action Input/Observation can repeat N times) +Thought: I now know the final answer +Final Answer: the final answer to the original input question` + + constructor(fields: { toolNames: string[] }) { + super(...arguments) + this.toolNames = fields.toolNames + } + + /** + * Parses the given text into an AgentAction or AgentFinish object. If an + * output fixing parser is defined, uses it to parse the text. + * @param text Text to parse. + * @returns Promise that resolves to an AgentAction or AgentFinish object. + */ + async parse(text: string): Promise { + const includesAnswer = text.includes(this.FINAL_ANSWER_ACTION) + const regex = /Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)/ + const actionMatch = text.match(regex) + if (actionMatch) { + if (includesAnswer) { + throw new Error(`${this.FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE}: ${text}`) + } + + const action = actionMatch[1] + const actionInput = actionMatch[2] + const toolInput = actionInput.trim().replace(/"/g, '') + + return { + tool: action, + toolInput, + log: text + } + } + + if (includesAnswer) { + const finalAnswerText = text.split(this.FINAL_ANSWER_ACTION)[1].trim() + return { + returnValues: { + output: finalAnswerText + }, + log: text + } + } + + // Instead of throwing Error, we return a AgentFinish object + return { returnValues: { output: text }, log: text } + } + + /** + * Returns the format instructions as a string. If the 'raw' option is + * true, returns the raw FORMAT_INSTRUCTIONS. + * @param options Options for getting the format instructions. + * @returns Format instructions as a string. + */ + getFormatInstructions(): string { + return renderTemplate(this.FORMAT_INSTRUCTIONS, 'f-string', { + tool_names: this.toolNames.join(', ') + }) + } +} diff --git a/packages/server/marketplaces/chatflows/ReAct Agent.json b/packages/server/marketplaces/chatflows/ReAct Agent.json index 151583a2..0afd0115 100644 --- a/packages/server/marketplaces/chatflows/ReAct Agent.json +++ b/packages/server/marketplaces/chatflows/ReAct Agent.json @@ -5,11 +5,11 @@ "nodes": [ { "width": 300, - "height": 143, + "height": 142, "id": "calculator_1", "position": { - "x": 664.1366474718458, - "y": 123.16419000640141 + "x": 466.86432329033937, + "y": 230.0825123205457 }, "type": "customNode", "data": { @@ -36,66 +36,171 @@ "selected": false }, "positionAbsolute": { - "x": 664.1366474718458, - "y": 123.16419000640141 + "x": 466.86432329033937, + "y": 230.0825123205457 }, "selected": false, "dragging": false }, { - "width": 300, - "height": 277, - "id": "serper_0", + "id": "mrklAgentChat_0", "position": { - "x": 330.964079024626, - "y": 109.83185250619351 + "x": 905.8535326018256, + "y": 388.58312223652564 }, "type": "customNode", "data": { - "id": "serper_0", - "label": "Serper", - "version": 1, - "name": "serper", - "type": "Serper", - "baseClasses": ["Serper", "Tool", "StructuredTool"], - "category": "Tools", - "description": "Wrapper around Serper.dev - Google Search API", - "inputParams": [ + "id": "mrklAgentChat_0", + "label": "ReAct Agent for Chat Models", + "version": 3, + "name": "mrklAgentChat", + "type": "AgentExecutor", + "baseClasses": ["AgentExecutor", "BaseChain", "Runnable"], + "category": "Agents", + "description": "Agent that uses the ReAct logic to decide what action to take, optimized to be used with Chat Models", + "inputParams": [], + "inputAnchors": [ { - "label": "Connect Credential", - "name": "credential", - "type": "credential", - "credentialNames": ["serperApi"], - "id": "serper_0-input-credential-credential" + "label": "Allowed Tools", + "name": "tools", + "type": "Tool", + "list": true, + "id": "mrklAgentChat_0-input-tools-Tool" + }, + { + "label": "Chat Model", + "name": "model", + "type": "BaseChatModel", + "id": "mrklAgentChat_0-input-model-BaseChatModel" + }, + { + "label": "Memory", + "name": "memory", + "type": "BaseChatMemory", + "id": "mrklAgentChat_0-input-memory-BaseChatMemory" } ], - "inputAnchors": [], - "inputs": {}, + "inputs": { + "tools": ["{{calculator_1.data.instance}}", "{{serper_0.data.instance}}"], + "model": "{{chatOpenAI_0.data.instance}}", + "memory": "{{RedisBackedChatMemory_0.data.instance}}" + }, "outputAnchors": [ { - "id": "serper_0-output-serper-Serper|Tool|StructuredTool", - "name": "serper", - "label": "Serper", - "type": "Serper | Tool | StructuredTool" + "id": "mrklAgentChat_0-output-mrklAgentChat-AgentExecutor|BaseChain|Runnable", + "name": "mrklAgentChat", + "label": "AgentExecutor", + "description": "Agent that uses the ReAct logic to decide what action to take, optimized to be used with Chat Models", + "type": "AgentExecutor | BaseChain | Runnable" } ], "outputs": {}, "selected": false }, + "width": 300, + "height": 330, "selected": false, "positionAbsolute": { - "x": 330.964079024626, - "y": 109.83185250619351 + "x": 905.8535326018256, + "y": 388.58312223652564 }, "dragging": false }, { + "id": "RedisBackedChatMemory_0", + "position": { + "x": 473.108799702029, + "y": 401.8098683245926 + }, + "type": "customNode", + "data": { + "id": "RedisBackedChatMemory_0", + "label": "Redis-Backed Chat Memory", + "version": 2, + "name": "RedisBackedChatMemory", + "type": "RedisBackedChatMemory", + "baseClasses": ["RedisBackedChatMemory", "BaseChatMemory", "BaseMemory"], + "category": "Memory", + "description": "Summarizes the conversation and stores the memory in Redis server", + "inputParams": [ + { + "label": "Connect Credential", + "name": "credential", + "type": "credential", + "optional": true, + "credentialNames": ["redisCacheApi", "redisCacheUrlApi"], + "id": "RedisBackedChatMemory_0-input-credential-credential" + }, + { + "label": "Session Id", + "name": "sessionId", + "type": "string", + "description": "If not specified, a random id will be used. Learn more", + "default": "", + "additionalParams": true, + "optional": true, + "id": "RedisBackedChatMemory_0-input-sessionId-string" + }, + { + "label": "Session Timeouts", + "name": "sessionTTL", + "type": "number", + "description": "Omit this parameter to make sessions never expire", + "additionalParams": true, + "optional": true, + "id": "RedisBackedChatMemory_0-input-sessionTTL-number" + }, + { + "label": "Memory Key", + "name": "memoryKey", + "type": "string", + "default": "chat_history", + "additionalParams": true, + "id": "RedisBackedChatMemory_0-input-memoryKey-string" + }, + { + "label": "Window Size", + "name": "windowSize", + "type": "number", + "description": "Window of size k to surface the last k back-and-forth to use as memory.", + "additionalParams": true, + "optional": true, + "id": "RedisBackedChatMemory_0-input-windowSize-number" + } + ], + "inputAnchors": [], + "inputs": { + "sessionId": "", + "sessionTTL": "", + "memoryKey": "chat_history", + "windowSize": "" + }, + "outputAnchors": [ + { + "id": "RedisBackedChatMemory_0-output-RedisBackedChatMemory-RedisBackedChatMemory|BaseChatMemory|BaseMemory", + "name": "RedisBackedChatMemory", + "label": "RedisBackedChatMemory", + "description": "Summarizes the conversation and stores the memory in Redis server", + "type": "RedisBackedChatMemory | BaseChatMemory | BaseMemory" + } + ], + "outputs": {}, + "selected": false + }, "width": 300, - "height": 574, + "height": 328, + "selected": false, + "positionAbsolute": { + "x": 473.108799702029, + "y": 401.8098683245926 + }, + "dragging": false + }, + { "id": "chatOpenAI_0", "position": { - "x": -27.71074046118335, - "y": 243.62715178281059 + "x": 81.2222202723384, + "y": 59.395597724017364 }, "type": "customNode", "data": { @@ -282,73 +387,69 @@ "id": "chatOpenAI_0-output-chatOpenAI-ChatOpenAI|BaseChatModel|BaseLanguageModel|Runnable", "name": "chatOpenAI", "label": "ChatOpenAI", + "description": "Wrapper around OpenAI large language models that use the Chat endpoint", "type": "ChatOpenAI | BaseChatModel | BaseLanguageModel | Runnable" } ], "outputs": {}, "selected": false }, + "width": 300, + "height": 573, "selected": false, "positionAbsolute": { - "x": -27.71074046118335, - "y": 243.62715178281059 + "x": 81.2222202723384, + "y": 59.395597724017364 }, "dragging": false }, { - "width": 300, - "height": 280, - "id": "mrklAgentChat_0", + "id": "serper_0", "position": { - "x": 1090.2058867451212, - "y": 423.2174695788541 + "x": 466.4499611299051, + "y": -67.74721119468873 }, "type": "customNode", "data": { - "id": "mrklAgentChat_0", - "label": "ReAct Agent for Chat Models", + "id": "serper_0", + "label": "Serper", "version": 1, - "name": "mrklAgentChat", - "type": "AgentExecutor", - "baseClasses": ["AgentExecutor", "BaseChain", "Runnable"], - "category": "Agents", - "description": "Agent that uses the ReAct logic to decide what action to take, optimized to be used with Chat Models", - "inputParams": [], - "inputAnchors": [ + "name": "serper", + "type": "Serper", + "baseClasses": ["Serper", "Tool", "StructuredTool", "Runnable"], + "category": "Tools", + "description": "Wrapper around Serper.dev - Google Search API", + "inputParams": [ { - "label": "Allowed Tools", - "name": "tools", - "type": "Tool", - "list": true, - "id": "mrklAgentChat_0-input-tools-Tool" - }, - { - "label": "Language Model", - "name": "model", - "type": "BaseLanguageModel", - "id": "mrklAgentChat_0-input-model-BaseLanguageModel" + "label": "Connect Credential", + "name": "credential", + "type": "credential", + "credentialNames": ["serperApi"], + "id": "serper_0-input-credential-credential" } ], - "inputs": { - "tools": ["{{calculator_1.data.instance}}", "{{serper_0.data.instance}}"], - "model": "{{chatOpenAI_0.data.instance}}" - }, + "inputAnchors": [], + "inputs": {}, "outputAnchors": [ { - "id": "mrklAgentChat_0-output-mrklAgentChat-AgentExecutor|BaseChain|Runnable", - "name": "mrklAgentChat", - "label": "AgentExecutor", - "type": "AgentExecutor | BaseChain | Runnable" + "id": "serper_0-output-serper-Serper|Tool|StructuredTool|Runnable", + "name": "serper", + "label": "Serper", + "description": "Wrapper around Serper.dev - Google Search API", + "type": "Serper | Tool | StructuredTool | Runnable" } ], "outputs": {}, "selected": false }, + "width": 300, + "height": 276, + "selected": false, "positionAbsolute": { - "x": 1090.2058867451212, - "y": 423.2174695788541 + "x": 466.4499611299051, + "y": -67.74721119468873 }, - "selected": false + "dragging": false } ], "edges": [ @@ -358,32 +459,31 @@ "target": "mrklAgentChat_0", "targetHandle": "mrklAgentChat_0-input-tools-Tool", "type": "buttonedge", - "id": "calculator_1-calculator_1-output-calculator-Calculator|Tool|StructuredTool|BaseLangChain-mrklAgentChat_0-mrklAgentChat_0-input-tools-Tool", - "data": { - "label": "" - } + "id": "calculator_1-calculator_1-output-calculator-Calculator|Tool|StructuredTool|BaseLangChain-mrklAgentChat_0-mrklAgentChat_0-input-tools-Tool" }, { - "source": "serper_0", - "sourceHandle": "serper_0-output-serper-Serper|Tool|StructuredTool", + "source": "RedisBackedChatMemory_0", + "sourceHandle": "RedisBackedChatMemory_0-output-RedisBackedChatMemory-RedisBackedChatMemory|BaseChatMemory|BaseMemory", "target": "mrklAgentChat_0", - "targetHandle": "mrklAgentChat_0-input-tools-Tool", + "targetHandle": "mrklAgentChat_0-input-memory-BaseChatMemory", "type": "buttonedge", - "id": "serper_0-serper_0-output-serper-Serper|Tool|StructuredTool-mrklAgentChat_0-mrklAgentChat_0-input-tools-Tool", - "data": { - "label": "" - } + "id": "RedisBackedChatMemory_0-RedisBackedChatMemory_0-output-RedisBackedChatMemory-RedisBackedChatMemory|BaseChatMemory|BaseMemory-mrklAgentChat_0-mrklAgentChat_0-input-memory-BaseChatMemory" }, { "source": "chatOpenAI_0", "sourceHandle": "chatOpenAI_0-output-chatOpenAI-ChatOpenAI|BaseChatModel|BaseLanguageModel|Runnable", "target": "mrklAgentChat_0", - "targetHandle": "mrklAgentChat_0-input-model-BaseLanguageModel", + "targetHandle": "mrklAgentChat_0-input-model-BaseChatModel", "type": "buttonedge", - "id": "chatOpenAI_0-chatOpenAI_0-output-chatOpenAI-ChatOpenAI|BaseChatModel|BaseLanguageModel|Runnable-mrklAgentChat_0-mrklAgentChat_0-input-model-BaseLanguageModel", - "data": { - "label": "" - } + "id": "chatOpenAI_0-chatOpenAI_0-output-chatOpenAI-ChatOpenAI|BaseChatModel|BaseLanguageModel|Runnable-mrklAgentChat_0-mrklAgentChat_0-input-model-BaseChatModel" + }, + { + "source": "serper_0", + "sourceHandle": "serper_0-output-serper-Serper|Tool|StructuredTool|Runnable", + "target": "mrklAgentChat_0", + "targetHandle": "mrklAgentChat_0-input-tools-Tool", + "type": "buttonedge", + "id": "serper_0-serper_0-output-serper-Serper|Tool|StructuredTool|Runnable-mrklAgentChat_0-mrklAgentChat_0-input-tools-Tool" } ] }