安装
pip install langgraph-checkpoint-sqlite
异步 checkpoint 初始化:
import aiosqlite
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

# Create an in-memory aiosqlite connection and wrap it in the async checkpointer.
# NOTE(review): this connection is never closed here. Inside a coroutine, prefer
# `async with AsyncSqliteSaver.from_conn_string(":memory:") as memory:` so the
# connection is opened and closed together with the block.
conn = aiosqlite.connect(":memory:", check_same_thread=False)
memory = AsyncSqliteSaver(conn)
如果使用异步流式响应(如 `astream_events`),需要确保 llm 节点以及其他相关节点(如工具节点)也改为异步实现(`async def` 定义节点,并 `await` 模型/工具的 `ainvoke`):
# Excerpt: these two coroutines are methods of the Agent class and rely on the
# imports (AgentState, List, ToolCall, ToolMessage) of the complete listing.
async def llm(self, state: AgentState):
    """Call the bound model on the conversation, prefixed by the system messages.

    Returns a state update containing exactly one new message. If the model call
    raises, the error is reported as an AIMessage instead of being propagated.
    """
    llm_msgs = state['messages']
    if self.systemMessage:
        llm_msgs = self.systemMessage + state['messages']
    print(f'ask llm to handle request msg, msg: {llm_msgs}')
    try:
        # Await the async call so the node returns a message object, not a coroutine.
        msg = await self.model.ainvoke(llm_msgs)
    except Exception as e:
        print(f"Model invocation error: {e}")
        # Report the failure as a message (the state only holds Message objects).
        from langchain_core.messages import AIMessage
        return {'messages': [AIMessage(content=f"Error: {str(e)}")]}
    print(f'msg={msg}')
    return {'messages': [msg]}


async def take_action_tool(self, state: AgentState):
    """Execute every tool call requested by the latest message.

    Returns one ToolMessage per requested call, in request order. A tool name the
    agent does not know yields an error ToolMessage (so the model can retry)
    instead of a KeyError that would abort the whole run.
    """
    current_tools: List[ToolCall] = state['messages'][-1].tool_calls
    results = []
    for t in current_tools:
        tool = self.tools.get(t['name'])
        if tool is None:
            print(f"bad tool name: {t['name']}")
            tool_result = 'bad tool name, retry'
        else:
            tool_result = await tool.ainvoke(t['args'])
        results.append(
            ToolMessage(
                tool_call_id=t['id'],
                content=str(tool_result),
                name=t['name'],
            )
        )
    print('Back to model')
    return {'messages': results}
最后的完整代码如下:
import asyncio
import os
from typing import Annotated, List, TypedDict

import aiosqlite
from dotenv import load_dotenv
from langchain_community.chat_models import ChatTongyi
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, ToolMessage, ToolCall
from langchain_core.tools import BaseTool
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
from langgraph.constants import END, START
from langgraph.graph import add_messages, StateGraph

# Module-level aiosqlite connection to an in-memory SQLite database.
conn = aiosqlite.connect(":memory:", check_same_thread=False)

# Load settings such as `model`, `api_key` and `base_url` (read via os.getenv)
# into the process environment.
# NOTE(review): the path is resolved against the current working directory, not
# this file's location — run the script from its own directory or make it absolute.
load_dotenv(dotenv_path='../keys.env')
ts_tool = TavilySearchResults(max_results=2)


class AgentState(TypedDict):
    """Graph state: the running conversation."""

    # add_messages is the reducer, so the messages a node returns are merged into
    # this list rather than replacing it.
    messages: Annotated[List[AnyMessage], add_messages]


class Agent:
    """Tool-using agent graph.

    An LLM node answers or requests tool calls. If tool calls are requested, a tool
    node runs them and loops back to the LLM. Otherwise the graph ends.
    """

    def __init__(
        self,
        model: BaseChatModel,
        systemMessage: List[SystemMessage],
        tools: List[BaseTool],
        memory,
    ):
        """Build and compile the agent graph.

        Args:
            model: chat model; ``tools`` are bound to it with ``bind_tools``.
            systemMessage: messages prepended to every LLM call (may be empty).
            tools: tools the model may call; looked up by their ``name``.
            memory: checkpointer passed to ``graph.compile``.

        Raises:
            TypeError: if any element of ``tools`` is not a ``BaseTool``.
        """
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not all(isinstance(t, BaseTool) for t in tools):
            raise TypeError('tools must be BaseTool instances')
        graph = StateGraph(AgentState)
        graph.add_node('llm', self.llm)
        graph.add_node('take_action_tool', self.take_action_tool)
        graph.add_conditional_edges(
            'llm',
            self.exist_action,
            {True: 'take_action_tool', False: END},
        )
        graph.add_edge(START, 'llm')
        graph.add_edge('take_action_tool', 'llm')
        self.app = graph.compile(checkpointer=memory)
        self.tools = {t.name: t for t in tools}
        self.model = model.bind_tools(tools)
        self.systemMessage = systemMessage

    def exist_action(self, state: AgentState):
        """Return True when the latest message requests at least one tool call."""
        tool_calls = getattr(state['messages'][-1], 'tool_calls', None) or []
        print(f'tool_calls size {len(tool_calls)}')
        return len(tool_calls) > 0

    async def llm(self, state: AgentState):
        """Call the bound model on the conversation, prefixed by the system messages.

        Returns a state update with exactly one new message. If the model call
        raises, the error is reported as an AIMessage without tool calls, so
        ``exist_action`` routes the graph to END.
        """
        llm_msgs = state['messages']
        if self.systemMessage:
            llm_msgs = self.systemMessage + state['messages']
        print(f'ask llm to handle request msg, msg: {llm_msgs}')
        try:
            # Await the async call so the node returns a message object, not a coroutine.
            msg = await self.model.ainvoke(llm_msgs)
        except Exception as e:
            print(f"Model invocation error: {e}")
            # Report the failure as a message (the state only holds Message objects).
            from langchain_core.messages import AIMessage
            return {'messages': [AIMessage(content=f"Error: {str(e)}")]}
        print(f'msg={msg}')
        return {'messages': [msg]}

    async def take_action_tool(self, state: AgentState):
        """Execute every tool call requested by the latest message.

        Returns one ToolMessage per requested call, in request order. A tool name
        the agent does not know yields an error ToolMessage (so the model can
        retry) instead of a KeyError that would abort the whole run.
        """
        current_tools: List[ToolCall] = state['messages'][-1].tool_calls
        results = []
        for t in current_tools:
            tool = self.tools.get(t['name'])
            if tool is None:
                print(f"bad tool name: {t['name']}")
                tool_result = 'bad tool name, retry'
            else:
                tool_result = await tool.ainvoke(t['args'])
            results.append(
                ToolMessage(
                    tool_call_id=t['id'],
                    content=str(tool_result),
                    name=t['name'],
                )
            )
        print('Back to model')
        return {'messages': results}


async def work():
    """Ask the agent one question and stream the model's tokens to stdout."""
    prompt = """You are a smart research assistant. Use the search engine to look up information. \
You are allowed to make multiple calls (either together or in sequence). \
Only look up information when you are sure of what you want. \
If you need to look up some information before asking a follow up question, you are allowed to do that!"""
    qwen_model = ChatTongyi(
        model=os.getenv('model'),
        api_key=os.getenv('api_key'),
        base_url=os.getenv('base_url'),
    )
    # Open the SQLite checkpointer inside the running event loop and close its
    # connection when the block exits (a never-closed aiosqlite connection can
    # keep the process from exiting).
    async with AsyncSqliteSaver.from_conn_string(":memory:") as memory:
        agent = Agent(
            model=qwen_model,
            tools=[ts_tool],
            systemMessage=[SystemMessage(content=prompt)],
            memory=memory,
        )
        messages = [HumanMessage("who is the popular football star in the world?")]
        configurable = {"configurable": {"thread_id": "1"}}
        async for event in agent.app.astream_events(
            {"messages": messages}, configurable, version="v2"
        ):
            if event["event"] == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                # Empty content in the context of OpenAI means
                # that the model is asking for a tool to be invoked.
                # So we only print non-empty content
                if content:
                    print(content, end="|")


if __name__ == '__main__':
    asyncio.run(work())