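LangChain source code study: learning from the design of deepagents
DeepAgents is a ready-to-run agent framework that ships with preset middleware, tools, and context management. An agent can be created quickly with create_deep_agent, which supports model configuration, custom system prompts, subagent calls, skill loading, and more. The framework has built-in middleware for todo lists, file operations, and long-context summarization, and provides state persistence, human-in-the-loop intervention, and a debug mode. Under the hood it builds on LangChain's create_agent, which simplifies agent development.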
https://github.com/langchain-ai/deepagents
what
What is a deep agent?
Deep Agents is an agent harness. An opinionated, ready-to-run agent out of the box. Instead of wiring up prompts, tools, and context management yourself, you get a working agent immediately and customize what you need.
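- According to the official GitHub description, it is an agent harness: a working agent that runs out of the box and lets you customize the prompts, tools, and context management as needed.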
Source code analysis
1. Initialization
1.1 create_deep_agent
- Let's start with the create_deep_agent function.
def create_deep_agent( # noqa: C901, PLR0912 # Complex graph assembly logic with many conditional branches
model: str | BaseChatModel | None = None,
tools: Sequence[BaseTool | Callable | dict[str, Any]] | None = None,
*,
system_prompt: str | SystemMessage | None = None,
middleware: Sequence[AgentMiddleware] = (),
subagents: list[SubAgent | CompiledSubAgent] | None = None,
skills: list[str] | None = None,
memory: list[str] | None = None,
response_format: ResponseFormat | None = None,
context_schema: type[Any] | None = None,
checkpointer: Checkpointer | None = None,
store: BaseStore | None = None,
backend: BackendProtocol | BackendFactory | None = None,
interrupt_on: dict[str, bool | InterruptOnConfig] | None = None,
debug: bool = False,
name: str | None = None,
cache: BaseCache | None = None,
) -> CompiledStateGraph:
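Start with the parameters. model is the model to use, tools are the tools the agent can call, and system_prompt is the agent's system prompt (its persona, for example).
- middleware is the list of extra middleware. It is applied after the base middleware and before AnthropicPromptCachingMiddleware and MemoryMiddleware. The framework provides several base middleware: TodoListMiddleware lets the model write a todo list before it starts on a task; FilesystemMiddleware provides file-system tools (its execute tool is only available when the backend supports sandboxed execution); create_summarization_middleware builds the long-context summarization middleware; PatchToolCallsMiddleware repairs dangling tool calls; AnthropicPromptCachingMiddleware handles prompt caching; MemoryMiddleware handles memory.
- subagents are the specs of the subagents the agent can call. It supports the declarative, synchronous SubAgent, the precompiled runnable CompiledSubAgent, and the asynchronous/background AsyncSubAgent.
- skills loads additional skill files. Paths use POSIX format (forward slashes) and are relative to the backend's root directory.
For example: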
from deepagents import create_deep_agent

agent = create_deep_agent(
    skills=[
        "/skills/user/",     # user-defined skills
        "/skills/project/",  # project skills
    ]
)
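- memory is a list of paths to memory files (AGENTS.md files). They are loaded when the agent starts and added to the system prompt.
- response_format is the structured-output format of the response; it can be a custom JSON schema or another specific format.
- context_schema defines the schema of the context passed to the deep agent.
- checkpointer persists the agent's state, so state can be saved and restored.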
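# For example:
from deepagents import create_deep_agent
from langgraph.checkpoint.memory import MemorySaver

checkpointer = MemorySaver()
agent = create_deep_agent(
    checkpointer=checkpointer
)
# State is saved and restored per thread_id
config = {"configurable": {"thread_id": "session-1"}}
result = agent.invoke({"messages": [{"role": "user", "content": "Hello"}]}, config=config)
# A later call with the same thread_id continues the session
result = agent.invoke({"messages": [{"role": "user", "content": "Go on"}]}, config=config)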
- store is persistent storage that can hold key-value data.
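from deepagents import create_deep_agent
from deepagents.backends import StoreBackend
from langgraph.store.memory import InMemoryStore

# BaseStore is abstract and cannot be instantiated directly;
# InMemoryStore is a concrete in-memory implementation.
store = InMemoryStore()
agent = create_deep_agent(
    backend=StoreBackend(store=store),
    store=store
)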
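- backend is the file storage and execution backend. The default is StateBackend (in memory). Backends that implement SandboxBackendProtocol also support command execution, which is discussed further below.
- interrupt_on pauses execution when the specified tools are called, so a human can approve or modify the call, like this: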
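# Pause before every edit_file call
# (interrupts need a checkpointer so the run can be resumed)
agent = create_deep_agent(
    interrupt_on={
        "edit_file": True,
    }
)
# Pause before a specific tool, with a config that limits
# which decisions the reviewer may make
agent = create_deep_agent(
    interrupt_on={
        "execute": {
            "allowed_decisions": ["approve", "edit"],
        }
    }
)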
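- debug controls whether debug mode is enabled. When set to True, the framework prints detailed logs, roughly along the lines of the following illustration: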
[DEBUG] Graph execution started
[DEBUG] Node 'agent' - before_agent middleware executing...
[DEBUG] PatchToolCallsMiddleware: checking for dangling tool calls
[DEBUG] SummarizationMiddleware: token count = 15000/200000 (7.5%)
[DEBUG] Node 'agent' - model request prepared
[DEBUG] System prompt: <agent_memory>...</agent_memory>
[DEBUG] Messages: [HumanMessage("Write a function for me")]
[DEBUG] Node 'agent' - model response received
[DEBUG] AIMessage(content="Sure, let me help...")
[DEBUG] Node 'agent' - after_agent middleware executing...
[DEBUG] Node 'agent' completed
- cache is the cache used by the agent.
Next, the actual initialization logic:
...
model = get_default_model() if model is None else resolve_model(model)
backend = backend if backend is not None else (StateBackend)
# Build general-purpose subagent with default middleware stack
# Registers the preset tools here, e.g. the todo-list tool and the file-system tools (execute needs a sandbox-capable backend)
gp_middleware: list[AgentMiddleware[Any, Any, Any]] = [
TodoListMiddleware(),
FilesystemMiddleware(backend=backend),
create_summarization_middleware(model, backend),
PatchToolCallsMiddleware(),
]
if skills is not None:
gp_middleware.append(SkillsMiddleware(backend=backend, sources=skills))
gp_middleware.append(AnthropicPromptCachingMiddleware(unsupported_model_behavior="ignore"))
if interrupt_on is not None:
gp_middleware.append(HumanInTheLoopMiddleware(interrupt_on=interrupt_on))
...
# At the bottom, LangChain's create_agent is called (explained later)
return create_agent(
model,
system_prompt=final_system_prompt,
tools=tools,
middleware=deepagent_middleware,
response_format=response_format,
context_schema=context_schema,
checkpointer=checkpointer,
store=store,
debug=debug,
name=name,
cache=cache,
).with_config(
{
"recursion_limit": 10_000,
"metadata": {
"ls_integration": "deepagents",
"versions": {"deepagents": __version__},
"lc_agent_name": name,
},
}
)
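As the code shows, initialization ends with a call to create_agent, and by that point subagents, skills, and the rest have all been turned into entries of deepagent_middleware. How middleware works is covered later.
At the core of TodoListMiddleware is the tool defined below. description is the tool's usage description, and it is a good model for writing tool descriptions: first, the scenarios in which the tool should be used; second, how to use it; third, when not to use it; and finally a summary. For the todo-list tool, the description also explains how to mark task status, and says that when a task needs only a few tool calls and the steps are clear, the tool should not be used. See the source for details.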
self.tools = [
StructuredTool.from_function(
name="write_todos",
description=tool_description,
func=_write_todos,
coroutine=_awrite_todos,
args_schema=WriteTodosInput,
infer_schema=False,
)
]
...
# The core of the tool's execution is the following
return Command(
update={
"todos": todos,
"messages": [
ToolMessage(f"Updated todo list to {todos}", tool_call_id=runtime.tool_call_id)
],
}
)
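- The Command above updates the todos list in the agent state shown below. The ToolMessage is appended to the message history that is sent to the model, which is how the model learns about the update.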
# The agent's full state
agent_state = {
    "messages": [  # conversation history
        HumanMessage("Help me refactor the codebase"),
        AIMessage("I'll help you refactor the codebase"),
        ToolMessage("Updated todo list to [...]", tool_call_id="call_abc123")
    ],
    "todos": [  # the todos field lives here
        {
            "content": "Analyze the existing code structure",
            "status": "in_progress"
        },
        {
            "content": "Refactor the core modules",
            "status": "pending"
        },
        {
            "content": "Update the test cases",
            "status": "pending"
        }
    ],
    "files": {},  # file-system state
    "memory": []  # memory state
}
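To make the update semantics concrete, here is a minimal sketch that uses plain dicts instead of LangGraph's Command and message classes. The helper names apply_update and write_todos_update are hypothetical, and the rule "messages are appended, other keys are overwritten" is a simplification assumed for illustration.

```python
from typing import Any


def apply_update(state: dict[str, Any], update: dict[str, Any]) -> dict[str, Any]:
    """Merge a tool's update into the agent state (simplified reducer semantics)."""
    new_state = dict(state)
    for key, value in update.items():
        if key == "messages":
            # Messages are appended, so the model sees the tool message on its next turn.
            new_state["messages"] = [*state.get("messages", []), *value]
        else:
            # Other keys, such as "todos", are overwritten wholesale.
            new_state[key] = value
    return new_state


def write_todos_update(todos: list[dict[str, str]], tool_call_id: str) -> dict[str, Any]:
    """Hypothetical stand-in for the update carried by the Command above."""
    return {
        "todos": todos,
        "messages": [
            {
                "type": "tool",
                "content": f"Updated todo list to {todos}",
                "tool_call_id": tool_call_id,
            }
        ],
    }


state = {"messages": [{"type": "human", "content": "Refactor the codebase"}], "todos": []}
update = write_todos_update(
    [{"content": "Analyze the code structure", "status": "in_progress"}],
    "call_abc123",
)
state = apply_update(state, update)
```

The point of the design is that the todo list is ordinary state: the tool does not do any work itself, it only records the plan where later model calls can read it.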
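FilesystemMiddleware provides tools for working with the file system. For brevity we only look at the execute tool. It requires a backend with sandbox support; for safety, the agent is not allowed to run arbitrary shell commands by default. The logic is in the code below: if the backend does not support sandboxed execution, the tool is left out of the final tool list.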
async def awrap_model_call(
self,
request: ModelRequest[ContextT],
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
) -> ModelResponse[ResponseT] | ExtendedModelResponse:
backend_supports_execution = False
if has_execute_tool:
# Resolve backend to check execution support
backend = self._get_backend(request.runtime) # ty: ignore[invalid-argument-type]
backend_supports_execution = _supports_execution(backend)
# If execute tool exists but backend doesn't support it, filter it out
if not backend_supports_execution:
filtered_tools = [tool for tool in request.tools if (tool.name if hasattr(tool, "name") else tool.get("name")) != "execute"]
request = request.override(tools=filtered_tools)
has_execute_tool = False
...
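The filtering step can be isolated into a small standalone sketch. FakeTool and filter_execute are hypothetical names introduced here; only the "tool is either an object with .name or a dict with a name key" handling mirrors the excerpt above.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeTool:
    """Hypothetical stand-in for a tool object exposing a ``name`` attribute."""

    name: str


def tool_name(tool: Any) -> Optional[str]:
    """Tools may be objects with ``.name`` or plain dicts with a "name" key."""
    return tool.name if hasattr(tool, "name") else tool.get("name")


def filter_execute(tools: list[Any], backend_supports_execution: bool) -> list[Any]:
    """Drop the ``execute`` tool unless the backend can run commands."""
    if backend_supports_execution:
        return list(tools)
    return [t for t in tools if tool_name(t) != "execute"]


tools = [FakeTool("read_file"), FakeTool("execute"), {"name": "web_search"}]
visible = [tool_name(t) for t in filter_execute(tools, backend_supports_execution=False)]
```

Because the filter runs on the model request rather than at construction time, the model never sees a tool it could not actually use.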
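create_summarization_middleware builds the conversation summarization middleware. It keeps counting the tokens in the message history and, once a threshold is exceeded, triggers compaction automatically: an LLM summarizes the older messages, the full history is stored in the backend, and the old messages are replaced by a summary message. An excerpt: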
# Get effective messages based on previous summarization events
effective_messages = self._get_effective_messages(request)
# Step 1: Truncate args if configured
truncated_messages, _ = self._truncate_args(
effective_messages,
request.system_message,
request.tools,
)
# Step 2: Check if summarization should happen
counted_messages = [request.system_message, *truncated_messages] if request.system_message is not None else truncated_messages
try:
total_tokens = self.token_counter(counted_messages, tools=request.tools) # ty: ignore[unknown-argument]
except TypeError:
total_tokens = self.token_counter(counted_messages)
should_summarize = self._should_summarize(truncated_messages, total_tokens)
# If no summarization needed, return with truncated messages
if not should_summarize:
try:
return handler(request.override(messages=truncated_messages))
except ContextOverflowError:
pass
# Fallback to summarization on context overflow
# Step 3: Perform summarization
cutoff_index = self._determine_cutoff_index(truncated_messages)
if cutoff_index <= 0:
# Can't summarize, return truncated messages
return handler(request.override(messages=truncated_messages))
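The threshold-then-cutoff idea can be sketched without any framework. Everything here is an assumption made for illustration: the four-characters-per-token estimate, the keep_last policy, and the fake_summarize function that stands in for the LLM call.

```python
from typing import Callable

Message = dict[str, str]


def count_tokens(messages: list[Message]) -> int:
    """Crude token estimate: roughly one token per four characters."""
    return sum(len(m["content"]) // 4 + 1 for m in messages)


def compact_history(
    messages: list[Message],
    max_tokens: int,
    keep_last: int,
    summarize: Callable[[list[Message]], str],
) -> list[Message]:
    """Replace older messages with one summary once the budget is exceeded."""
    if count_tokens(messages) <= max_tokens:
        return messages
    cutoff = len(messages) - keep_last
    if cutoff <= 0:
        # Nothing old enough to summarize; keep the history as it is.
        return messages
    old, recent = messages[:cutoff], messages[cutoff:]
    summary = {"role": "system", "content": summarize(old)}
    return [summary, *recent]


def fake_summarize(old: list[Message]) -> str:
    """Stand-in for the LLM call that would write the real summary."""
    return f"Summary of {len(old)} earlier messages."


history = [{"role": "user", "content": "x" * 400} for _ in range(10)]
compacted = compact_history(history, max_tokens=500, keep_last=2, summarize=fake_summarize)
```

The real middleware additionally truncates large tool arguments first and falls back to summarization when the model reports a context overflow, as the excerpt shows.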
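PatchToolCallsMiddleware repairs "dangling" tool calls. It walks through all tool_calls and checks whether each has a matching ToolMessage; if not, it creates a ToolMessage that marks the call as cancelled. This guarantees that every tool_call has a corresponding ToolMessage, keeps the message history consistent, and supports human-in-the-loop interruptions.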
class PatchToolCallsMiddleware(AgentMiddleware):
"""Middleware to patch dangling tool calls in the messages history."""
def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None: # noqa: ARG002
"""Before the agent runs, handle dangling tool calls from any AIMessage."""
messages = state["messages"]
if not messages or len(messages) == 0:
return None
patched_messages = []
# Iterate over the messages and add any dangling tool calls
for i, msg in enumerate(messages):
patched_messages.append(msg)
if isinstance(msg, AIMessage) and msg.tool_calls:
for tool_call in msg.tool_calls:
corresponding_tool_msg = next(
(msg for msg in messages[i:] if msg.type == "tool" and msg.tool_call_id == tool_call["id"]), # ty: ignore[unresolved-attribute]
None,
)
if corresponding_tool_msg is None:
# We have a dangling tool call which needs a ToolMessage
tool_msg = (
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
"cancelled - another message came in before it could be completed."
)
patched_messages.append(
ToolMessage(
content=tool_msg,
name=tool_call["name"],
tool_call_id=tool_call["id"],
)
)
return {"messages": Overwrite(patched_messages)}
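The same patching logic, reduced to plain dicts so it can be run in isolation. The dict-based message shape is an assumption of this sketch, not the LangChain message classes.

```python
from typing import Any


def patch_dangling_tool_calls(messages: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Insert a cancellation tool message after any AI tool call that has no result."""
    patched: list[dict[str, Any]] = []
    for i, msg in enumerate(messages):
        patched.append(msg)
        if msg.get("type") != "ai":
            continue
        for call in msg.get("tool_calls", []):
            answered = any(
                later.get("type") == "tool" and later.get("tool_call_id") == call["id"]
                for later in messages[i + 1:]
            )
            if not answered:
                patched.append(
                    {
                        "type": "tool",
                        "name": call["name"],
                        "tool_call_id": call["id"],
                        "content": f"Tool call {call['name']} with id {call['id']} was cancelled.",
                    }
                )
    return patched


history = [
    {"type": "human", "content": "Delete the temp files"},
    {"type": "ai", "content": "", "tool_calls": [{"id": "call_1", "name": "execute"}]},
    {"type": "human", "content": "Actually, wait."},
]
patched = patch_dangling_tool_calls(history)
```

Running the function a second time on its own output changes nothing, which is what makes it safe to apply before every agent run.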
AnthropicPromptCachingMiddleware uses Anthropic's prompt caching feature to cache the system prompt, reducing the cost of resending the same tokens on every call.
1.2 create_agent
- Below is the initialization logic of LangChain's create_agent, which ultimately turns the whole agent into a LangGraph graph.
def create_agent(...) -> CompiledStateGraph:
# Step 1: initialize the chat model
if isinstance(model, str):
model = init_chat_model(model)
# Step 2: convert system_prompt into a SystemMessage
system_message: SystemMessage | None = None
if system_prompt is not None:
if isinstance(system_prompt, SystemMessage):
system_message = system_prompt
else:
system_message = SystemMessage(content=system_prompt)
# Step 3: normalize tools
if tools is None:
tools = []
# Step 4: handle response_format
initial_response_format = ...
structured_output_tools = {}
if tool_strategy_for_setup:
for response_schema in tool_strategy_for_setup.schema_specs:
structured_tool_info = OutputToolBinding.from_schema_spec(response_schema)
structured_output_tools[structured_tool_info.tool.name] = structured_tool_info
# Step 5: collect the tools provided by middleware
middleware_tools = [t for m in middleware for t in getattr(m, "tools", [])]
# Step 6: collect the middleware that implement the wrap_tool_call hook
middleware_w_wrap_tool_call = [
m for m in middleware
if m.__class__.wrap_tool_call is not AgentMiddleware.wrap_tool_call
or m.__class__.awrap_tool_call is not AgentMiddleware.awrap_tool_call
]
# Step 7: chain all wrap_tool_call handlers together
wrap_tool_call_wrapper = None
if middleware_w_wrap_tool_call:
wrappers = [
traceable(name=f"{m.name}.wrap_tool_call", process_inputs=_scrub_inputs)(
m.wrap_tool_call
)
for m in middleware_w_wrap_tool_call
]
wrap_tool_call_wrapper = _chain_tool_call_wrappers(wrappers)
# Step 8: create the ToolNode
built_in_tools = [t for t in tools if isinstance(t, dict)]
regular_tools = [t for t in tools if not isinstance(t, dict)]
available_tools = middleware_tools + regular_tools
tool_node = (
ToolNode(
tools=available_tools,
wrap_tool_call=wrap_tool_call_wrapper,
awrap_tool_call=awrap_tool_call_wrapper,
)
if available_tools or wrap_tool_call_wrapper or awrap_tool_call_wrapper
else None
)
# Step 9: validate the middleware (names must be unique)
if len({m.name for m in middleware}) != len(middleware):
raise AssertionError("Please remove duplicate middleware instances.")
# Step 10: group middleware by the hooks they implement
middleware_w_before_agent = [
m for m in middleware
if m.__class__.before_agent is not AgentMiddleware.before_agent
or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
]
middleware_w_before_model = [
m for m in middleware
if m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model
]
middleware_w_after_model = [
m for m in middleware
if m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model
]
middleware_w_after_agent = [
m for m in middleware
if m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
]
# Step 11: chain the wrap_model_call handlers together
wrap_model_call_handler = None
if middleware_w_wrap_model_call:
sync_handlers = [
traceable(name=f"{m.name}.wrap_model_call", process_inputs=_scrub_inputs)(
m.wrap_model_call
)
for m in middleware_w_wrap_model_call
]
wrap_model_call_handler = _chain_model_call_handlers(sync_handlers)
# Step 12: resolve the state schema
state_schemas = {m.state_schema for m in middleware}
base_state = state_schema if state_schema is not None else AgentState
state_schemas.add(base_state)
resolved_state_schema = _resolve_schema(state_schemas, "StateSchema", None)
...
# Step 13: create the StateGraph
graph = StateGraph(
state_schema=resolved_state_schema,
input_schema=input_schema,
output_schema=output_schema,
context_schema=context_schema,
)
# Step 14: add the model node
graph.add_node("model", RunnableCallable(model_node, amodel_node, trace=False))
# Step 15: add the tools node (if there is one)
if tool_node is not None:
graph.add_node("tools", tool_node)
# Step 16: add the middleware nodes
for m in middleware:
# Add the before_agent node
if (m.__class__.before_agent is not AgentMiddleware.before_agent
or m.__class__.abefore_agent is not AgentMiddleware.abefore_agent):
sync_before_agent = (
m.before_agent
if m.__class__.before_agent is not AgentMiddleware.before_agent
else None
)
async_before_agent = (
m.abefore_agent
if m.__class__.abefore_agent is not AgentMiddleware.abefore_agent
else None
)
before_agent_node = RunnableCallable(sync_before_agent, async_before_agent, trace=False)
graph.add_node(
f"{m.name}.before_agent",
before_agent_node,
input_schema=resolved_state_schema
)
# Add the before_model node
if (m.__class__.before_model is not AgentMiddleware.before_model
or m.__class__.abefore_model is not AgentMiddleware.abefore_model):
sync_before = (
m.before_model
if m.__class__.before_model is not AgentMiddleware.before_model
else None
)
async_before = (
m.abefore_model
if m.__class__.abefore_model is not AgentMiddleware.abefore_model
else None
)
before_node = RunnableCallable(sync_before, async_before, trace=False)
graph.add_node(
f"{m.name}.before_model",
before_node,
input_schema=resolved_state_schema
)
# Add the after_model node
if (m.__class__.after_model is not AgentMiddleware.after_model
or m.__class__.aafter_model is not AgentMiddleware.aafter_model):
sync_after = (
m.after_model
if m.__class__.after_model is not AgentMiddleware.after_model
else None
)
async_after = (
m.aafter_model
if m.__class__.aafter_model is not AgentMiddleware.aafter_model
else None
)
after_node = RunnableCallable(sync_after, async_after, trace=False)
graph.add_node(
f"{m.name}.after_model",
after_node,
input_schema=resolved_state_schema
)
# Add the after_agent node
if (m.__class__.after_agent is not AgentMiddleware.after_agent
or m.__class__.aafter_agent is not AgentMiddleware.aafter_agent):
sync_after_agent = (
m.after_agent
if m.__class__.after_agent is not AgentMiddleware.after_agent
else None
)
async_after_agent = (
m.aafter_agent
if m.__class__.aafter_agent is not AgentMiddleware.aafter_agent
else None
)
after_agent_node = RunnableCallable(sync_after_agent, async_after_agent, trace=False)
graph.add_node(
f"{m.name}.after_agent",
after_agent_node,
input_schema=resolved_state_schema
)
# Step 17: determine the entry node
if middleware_w_before_agent:
entry_node = f"{middleware_w_before_agent[0].name}.before_agent"
elif middleware_w_before_model:
entry_node = f"{middleware_w_before_model[0].name}.before_model"
else:
entry_node = "model"
# Step 18: determine the loop entry node
if middleware_w_before_model:
loop_entry_node = f"{middleware_w_before_model[0].name}.before_model"
else:
loop_entry_node = "model"
# Step 19: determine the loop exit node
if middleware_w_after_model:
loop_exit_node = f"{middleware_w_after_model[0].name}.after_model"
else:
loop_exit_node = "model"
# Step 20: determine the exit node
if middleware_w_after_agent:
exit_node = f"{middleware_w_after_agent[-1].name}.after_agent"
else:
exit_node = END
# Step 21: add the edges
graph.add_edge(START, entry_node)
# Step 22: add the conditional edges (if there are tools)
if tool_node is not None:
# Conditional edge: tools -> model
graph.add_conditional_edges(
"tools",
RunnableCallable(_make_tools_to_model_edge(...), trace=False),
tools_to_model_destinations,
)
# Conditional edge: model -> tools
graph.add_conditional_edges(
loop_exit_node,
RunnableCallable(_make_model_to_tools_edge(...), trace=False),
model_to_tools_destinations,
)
# Step 23: add the middleware edges
# before_agent chain
if middleware_w_before_agent:
for m1, m2 in itertools.pairwise(middleware_w_before_agent):
_add_middleware_edge(
graph,
name=f"{m1.name}.before_agent",
default_destination=f"{m2.name}.before_agent",
model_destination=loop_entry_node,
end_destination=exit_node,
can_jump_to=_get_can_jump_to(m1, "before_agent"),
)
# Step 24: compile the graph
return graph.compile(
checkpointer=checkpointer,
store=store,
interrupt_before=interrupt_before,
interrupt_after=interrupt_after,
debug=debug,
name=name,
cache=cache,
).with_config(config)
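Two techniques in the listing above are worth isolating: detecting whether a middleware overrides a hook by comparing the class attribute with the base class's, and composing wrap_model_call handlers into a single callable. The sketch below uses a hypothetical Middleware base class, and it assumes the first middleware in the list ends up outermost; it is not the LangChain implementation.

```python
from typing import Any, Callable


class Middleware:
    """Hypothetical base class; every hook is a no-op by default."""

    def before_model(self, state: dict) -> Any:
        return None

    def wrap_model_call(self, request: str, handler: Callable[[str], str]) -> str:
        return handler(request)


def overrides(m: Middleware, hook: str) -> bool:
    """True when the middleware's class replaces the base implementation of ``hook``."""
    return getattr(type(m), hook) is not getattr(Middleware, hook)


def _bind(m: Middleware, inner: Callable[[str], str]) -> Callable[[str], str]:
    def wrapped(request: str) -> str:
        return m.wrap_model_call(request, inner)

    return wrapped


def chain_model_call_handlers(
    middlewares: list[Middleware], base_handler: Callable[[str], str]
) -> Callable[[str], str]:
    """Compose handlers so the first middleware in the list is the outermost layer."""
    handler = base_handler
    for m in reversed(middlewares):
        handler = _bind(m, handler)
    return handler


class Tagger(Middleware):
    def __init__(self, tag: str) -> None:
        self.tag = tag

    def wrap_model_call(self, request: str, handler: Callable[[str], str]) -> str:
        return f"{self.tag}({handler(request)})"


class Logger(Middleware):
    def before_model(self, state: dict) -> Any:
        return {"seen": True}


stack = [Tagger("outer"), Logger(), Tagger("inner")]
wrapping = [m for m in stack if overrides(m, "wrap_model_call")]
call = chain_model_call_handlers(wrapping, lambda request: f"model:{request}")
result = call("hi")
```

Checking the class attribute rather than calling the hook means a middleware that does not implement a hook costs nothing at runtime: no node or wrapper is created for it.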
How the edges are built:
START → first before_agent node
before_agent chain → first before_model node
before_model chain → model node
model → first after_model node
after_model chain → tools or before_model (loop)
tools → before_model (loop)
last after_agent → END
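Ignoring middleware nodes, the edges above reduce to a model/tools loop. The following sketch walks that loop with a scripted stand-in for the model; run_agent_loop, scripted_model, and the dict-based message shape are all hypothetical and only illustrate the control flow.

```python
from typing import Any, Callable


def run_agent_loop(
    model: Callable[[list[dict[str, Any]]], dict[str, Any]],
    tools: dict[str, Callable[..., Any]],
    messages: list[dict[str, Any]],
    max_steps: int = 10,
) -> list[str]:
    """Walk START -> model -> tools -> model -> ... -> END and record the visited nodes."""
    trace = ["START"]
    for _ in range(max_steps):
        trace.append("model")
        ai = model(messages)
        messages.append(ai)
        if not ai.get("tool_calls"):
            break  # no tool calls: leave the loop
        trace.append("tools")
        for call in ai["tool_calls"]:
            output = tools[call["name"]](**call["args"])
            messages.append(
                {"type": "tool", "tool_call_id": call["id"], "content": str(output)}
            )
    trace.append("END")
    return trace


def scripted_model(messages: list[dict[str, Any]]) -> dict[str, Any]:
    """Stand-in for an LLM: request one tool call, then answer with its result."""
    if not any(m["type"] == "tool" for m in messages):
        return {
            "type": "ai",
            "content": "",
            "tool_calls": [{"id": "c1", "name": "add", "args": {"a": 2, "b": 3}}],
        }
    return {"type": "ai", "content": f"The answer is {messages[-1]['content']}", "tool_calls": []}


messages = [{"type": "human", "content": "What is 2 + 3?"}]
trace = run_agent_loop(scripted_model, {"add": lambda a, b: a + b}, messages)
```

In the real graph the loop bound comes from recursion_limit rather than a max_steps argument, and the before_model and after_model nodes sit on either side of the model node.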
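2. Agent execution flow
- The agent's execution chain is shown below. Note that, as Step 16 above shows, a hook node is only added for middleware that actually overrides that hook, so in a real run some of the nodes listed here may be absent.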
User code
↓
agent.invoke({"messages": [HumanMessage(...)]})
↓
[LangGraph] CompiledStateGraph.invoke()
↓
[LangGraph] Pregel.invoke()
├─ Initialize the execution context
├─ Compile the graph
└─ Execute the graph
↓
[LangGraph] Graph execution engine
↓
START
↓
[LangGraph] Iterate over the nodes
↓
Node 1: TodoListMiddleware.before_agent
├─ Call: TodoListMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 2: SkillsMiddleware.before_agent (if present)
├─ Call: SkillsMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 3: FilesystemMiddleware.before_agent
├─ Call: FilesystemMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 4: SubAgentMiddleware.before_agent
├─ Call: SubAgentMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 5: SummarizationMiddleware.before_agent
├─ Call: SummarizationMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 6: PatchToolCallsMiddleware.before_agent
├─ Call: PatchToolCallsMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 7: [user-defined middleware].before_agent (if present)
├─ Call: [user middleware].before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 8: AnthropicPromptCachingMiddleware.before_agent
├─ Call: AnthropicPromptCachingMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 9: MemoryMiddleware.before_agent (if present)
├─ Call: MemoryMiddleware.before_agent(state, runtime)
└─ Returns: dict | None
↓
Node 10: [middleware implementing before_model].before_model
├─ Call: [middleware].before_model(state, runtime)
└─ Returns: dict | None
↓
Node 11: model node
├─ Call: model_node(state, runtime)
│ ├─ Create the ModelRequest
│ ├─ Wrap it with wrap_model_call_handler
│ │ ├─ AnthropicPromptCachingMiddleware.wrap_model_call()
│ │ └─ [other middleware implementing wrap_model_call].wrap_model_call()
│ │ ↓
│ │ _execute_model_sync(request)
│ │ ├─ _get_bound_model(request)
│ │ │ ├─ model.bind_tools(final_tools, ...)
│ │ │ └─ Returns (bound_model, effective_response_format)
│ │ ├─ model_.invoke(messages)
│ │ │ ↓
│ │ │ [LangChain Core] BaseChatModel.invoke()
│ │ │ ↓
│ │ │ [LangChain Anthropic] ChatAnthropic._generate()
│ │ │ ↓
│ │ │ [Anthropic SDK] client.messages.create()
│ │ │ ↓
│ │ │ [Anthropic API] HTTP request
│ │ │ ↓
│ │ │ Returns AIMessage
│ │ ├─ _handle_model_output(output, effective_response_format)
│ │ └─ Returns ModelResponse
│ └─ _build_commands(model_response)
└─ Returns: list[Command]
↓
Node 12: [middleware implementing after_model].after_model
├─ Call: [middleware].after_model(state, runtime)
└─ Returns: dict | None
↓
[LangGraph] Conditional edge
├─ _make_model_to_tools_edge(state)
│ ├─ Check state.get("jump_to")
│ ├─ Check last_ai_message.tool_calls
│ ├─ Check pending_tool_calls
│ ├─ Check structured_response
│ └─ Returns: "tools" | "model" | "end"
↓
If it returns "tools":
↓
Node 13: tools node
├─ Call: ToolNode.invoke(state, config)
│ ├─ Iterate over tool_calls
│ │ ├─ wrap_tool_call_wrapper(tool, args, config)
│ │ │ ├─ [middleware implementing wrap_tool_call].wrap_tool_call()
│ │ │ └─ tool.invoke(args, config)
│ │ │ ↓
│ │ │ The tool's own implementation
│ │ │ ↓
│ │ │ Returns the tool output
│ │ └─ Create the ToolMessage
│ └─ Returns: {"messages": [ToolMessage(...)]}
↓
[LangGraph] Conditional edge
├─ _make_tools_to_model_edge(state)
├─ Returns: "model" (continue the loop)
↓
Back to node 10 (before_model)
↓
If it returns "model":
↓
Back to node 10 (before_model)
↓
If it returns "end":
↓
Node 14: [middleware implementing after_agent].after_agent
├─ Call: [middleware].after_agent(state, runtime)
└─ Returns: dict | None
↓
END
↓
Return the final state
↓
Return result
↓
User code
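2.1 How tools work
- Tools come from two sources: the tools passed in directly by the user, and the tools that middleware expose through their tools attribute, such as the todo-list tool discussed above.
- In the execution chain above, LangGraph calls the conditional-edge function to decide which node to go to next, as shown below.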
def _route(
self,
input: Any, # output of the current node
config: RunnableConfig,
*,
reader: Callable[[RunnableConfig], Any] | None,
writer: _Writer,
) -> Runnable:
# 1. Read the current state
if reader:
value = reader(config)
# Merge the state
if isinstance(value, dict) and isinstance(input, dict):
value = {**input, **value}
else:
value = input
# 2. Call the conditional-edge function (here, the function returned by _make_model_to_tools_edge)
result = self.path.invoke(value, config)
# 3. Process the result and write the next node(s) to run
return self._finish(writer, input, result, config)
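When the LLM decides that a tool is needed, the tool_calls field of its response names the tool, and the framework then executes the function behind that tool. The routing function looks like this: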
def model_to_tools(state: dict[str, Any]) -> str | list[Send] | None:
# Condition 1: check jump_to first (highest priority)
if jump_to := state.get("jump_to"):
return _resolve_jump(...)
# Condition 2: exit if there is no AI message
if last_ai_message is None:
return end_destination
# Condition 3: exit if the model did not call any tools
if len(last_ai_message.tool_calls) == 0:
return end_destination
# Condition 4: pending tool calls exist → run them concurrently
if pending_tool_calls:
return [Send(...) for tool_call in pending_tool_calls]
# Condition 5: exit if there is a structured response
if "structured_response" in state:
return end_destination
# Condition 6: otherwise go back to the model node
return model_destination
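The same decision order can be exercised as a standalone function over dict-based state. The function name route_after_model, the string destinations, and the ("tools", call) tuples that stand in for Send objects are assumptions of this sketch.

```python
from typing import Any, Union


def route_after_model(state: dict[str, Any]) -> Union[str, list[tuple[str, dict]]]:
    """Decide where to go after the model node, in the same priority order as above."""
    # 1. An explicit jump request wins over everything else.
    if state.get("jump_to"):
        return state["jump_to"]
    messages = state.get("messages", [])
    ai_messages = [m for m in messages if m["type"] == "ai"]
    # 2. No AI message yet: nothing to route on.
    if not ai_messages:
        return "end"
    last_ai = ai_messages[-1]
    # 3. The model did not ask for any tool.
    if not last_ai.get("tool_calls"):
        return "end"
    # 4. Tool calls without a result are dispatched, one task per call.
    answered = {m["tool_call_id"] for m in messages if m["type"] == "tool"}
    pending = [c for c in last_ai["tool_calls"] if c["id"] not in answered]
    if pending:
        return [("tools", call) for call in pending]
    # 5. A structured response means the run is complete.
    if "structured_response" in state:
        return "end"
    # 6. Otherwise hand control back to the model.
    return "model"


human = {"type": "human", "content": "List the files"}
ai = {"type": "ai", "content": "", "tool_calls": [{"id": "c1", "name": "ls"}]}
tool = {"type": "tool", "tool_call_id": "c1", "content": "a.py"}
```

Returning one entry per pending call is what lets several tool calls from a single model turn run in parallel.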
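2.2 How skills work
The SKILL specification is in the official documentation at https://agentskills.io/specification. A SKILL.md file must contain YAML frontmatter, followed by Markdown content:
The SKILL.md file must contain YAML frontmatter followed by Markdown content.
The format looks like this: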
---
name: web-research
description: Structured approach to conducting thorough web research
license: MIT
---
# Web Research Skill
## When to Use
- User asks you to research a topic
...
- Skills are then loaded as follows:
# File: libs/deepagents/deepagents/middleware/skills.py
def _list_skills(backend: BackendProtocol, source_path: str) -> list[SkillMetadata]:
"""Load all skills from a backend source."""
skills: list[SkillMetadata] = []
# Step 1: list everything under source_path
ls_result = backend.ls(source_path)
items = ls_result.entries if isinstance(ls_result, LsResult) else ls_result
# Step 2: collect the sub-directories (candidate skill directories)
skill_dirs = []
for item in items or []:
if not item.get("is_dir"):
continue
skill_dirs.append(item["path"])
# Step 3: download every SKILL.md file
skill_md_paths = []
for skill_dir_path in skill_dirs:
skill_dir = PurePosixPath(skill_dir_path)
skill_md_path = str(skill_dir / "SKILL.md")
skill_md_paths.append((skill_dir_path, skill_md_path))
paths_to_download = [skill_md_path for _, skill_md_path in skill_md_paths]
responses = backend.download_files(paths_to_download)
# Step 4: parse the YAML frontmatter of each SKILL.md
for (skill_dir_path, skill_md_path), response in zip(skill_md_paths, responses, strict=True):
if response.error:
continue
content = response.content.decode("utf-8")
directory_name = PurePosixPath(skill_dir_path).name
# Parse the metadata
skill_metadata = _parse_skill_metadata(
content=content,
skill_path=skill_md_path,
directory_name=directory_name,
)
if skill_metadata:
skills.append(skill_metadata)
return skills
def _parse_skill_metadata(
content: str,
skill_path: str,
directory_name: str,
) -> SkillMetadata | None:
"""Parse the YAML frontmatter of a SKILL.md file."""
# Match the YAML frontmatter between the --- delimiters
frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n"
match = re.match(frontmatter_pattern, content, re.DOTALL)
if not match:
return None
frontmatter_str = match.group(1)
# Parse the YAML
frontmatter_data = yaml.safe_load(frontmatter_str)
# Extract the metadata
name = str(frontmatter_data.get("name", "")).strip()
description = str(frontmatter_data.get("description", "")).strip()
license = str(frontmatter_data.get("license", "")).strip() or None
compatibility = str(frontmatter_data.get("compatibility", "")).strip() or None
allowed_tools = ...
return SkillMetadata(
name=name,
description=description,
path=skill_path,
license=license,
compatibility=compatibility,
allowed_tools=allowed_tools,
metadata=...,
)
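The frontmatter step can be tried on its own with the same regular expression. To stay dependency-free, this sketch replaces yaml.safe_load with a naive "key: value" line parser, which is an assumption made for illustration and only handles flat string fields.

```python
import re
from typing import Optional

# Same pattern as in the excerpt: the text between the leading --- delimiters.
FRONTMATTER = re.compile(r"^---\s*\n(.*?)\n---\s*\n", re.DOTALL)


def parse_frontmatter(content: str) -> Optional[dict[str, str]]:
    """Return the flat key/value pairs of the frontmatter, or None if there is none."""
    match = FRONTMATTER.match(content)
    if not match:
        return None
    fields: dict[str, str] = {}
    for line in match.group(1).splitlines():
        # Naive stand-in for a YAML parser: split each line on the first colon.
        key, sep, value = line.partition(":")
        if sep and key.strip():
            fields[key.strip()] = value.strip()
    return fields


skill_md = """---
name: web-research
description: Structured approach to conducting thorough web research
license: MIT
---
# Web Research Skill
"""
meta = parse_frontmatter(skill_md)
```

Only the frontmatter is parsed at load time; the Markdown body after the second delimiter is left untouched until the model decides to read the file.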
- As shown, the metadata at the top of SKILL.md is extracted with a regular expression. Next, let's see how skills are injected into the system prompt.
# libs/deepagents/deepagents/middleware/skills.py
...
class SkillsMiddleware(AgentMiddleware[Any, ContextT, ResponseT]):
"""Skills middleware for loading and exposing agent skills to the system prompt."""
def __init__(
self,
backend: BackendProtocol | BackendFactory,
sources: list[str],
) -> None:
"""Initialize SkillsMiddleware."""
super().__init__()
self.backend = backend
self.sources = sources
def before_agent(
self,
state: AgentState[ResponseT],
runtime: Runtime[ContextT],
) -> dict[str, Any] | None:
"""Load skills before the agent runs and inject them into the system prompt."""
# Step 1: resolve the backend
resolved_backend = (
self.backend(runtime) if callable(self.backend) else self.backend
)
# Step 2: load skills from all sources
all_skills: list[SkillMetadata] = []
for source in self.sources:
skills = _list_skills(resolved_backend, source)
all_skills.extend(skills)
# Step 3: build the skills list strings
skills_list = "\n".join(
f"- **{skill['name']}**: {skill['description']}"
for skill in all_skills
)
skills_locations = "\n".join(
f"- {skill['name']}: {skill['path']}"
for skill in all_skills
)
# Step 4: build the system prompt
skills_system_prompt = SKILLS_SYSTEM_PROMPT.format(
skills_locations=skills_locations,
skills_list=skills_list,
)
# Step 5: return the state update
return {
"skills_metadata": all_skills,
"system_message": skills_system_prompt,
}
# The SKILL system prompt; skills_locations and skills_list are filled in before the agent runs
SKILLS_SYSTEM_PROMPT = """
## Skills System
You have access to a skills library that provides specialized capabilities and domain knowledge.
{skills_locations}
**Available Skills:**
{skills_list}
**How to Use Skills (Progressive Disclosure):**
Skills follow a **progressive disclosure** pattern - you see their name and description above, but only read full instructions when needed:
1. **Recognize when a skill applies**: Check if the user's task matches a skill's description
2. **Read the skill's full instructions**: Use the path shown in the skill list above
3. **Follow the skill's instructions**: SKILL.md contains step-by-step workflows, best practices, and examples
4. **Access supporting files**: Skills may include helper scripts, configs, or reference docs - use absolute paths
**When to Use Skills:**
- User's request matches a skill's domain (e.g., "research X" -> web-research skill)
- You need specialized knowledge or structured workflows
- A skill provides proven patterns for complex tasks
**Executing Skill Scripts:**
Skills may contain Python scripts or other executable files. Always use absolute paths from the skill list.
**Example Workflow:**
User: "Can you research the latest developments in quantum computing?"
1. Check available skills -> See "web-research" skill with its path
2. Read the skill using the path shown
3. Follow the skill's research workflow (search -> organize -> synthesize)
4. Use any helper scripts with absolute paths
Remember: Skills make you more capable and consistent. When in doubt, check if a skill exists for the task!
"""
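The skill injection flow is summarized below. Having the model read a SKILL's content only when it needs it is what progressive disclosure means; a skill is essentially a progressively disclosed file system that can supply knowledge to the model on demand.
1. SkillsMiddleware.before_agent is called
↓
2. All skills are loaded from the backend's sources
├─ List the directories
├─ Download SKILL.md
└─ Parse the YAML frontmatter
↓
3. The skills list strings are built
↓
4. The skills information is injected into the system prompt
↓
5. The model can see each skill's description and path
↓
6. The model decides from the description whether to use a skill
↓
7. If it does, the model reads the full content of SKILL.md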
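Progressive disclosure can be demonstrated with a few lines: the prompt carries only names, descriptions, and paths, and the full file is read on demand. RecordingBackend and build_skills_section are hypothetical names, and the prompt layout is simplified compared with SKILLS_SYSTEM_PROMPT.

```python
class RecordingBackend:
    """Hypothetical file backend that records which paths were read."""

    def __init__(self, files: dict[str, str]) -> None:
        self.files = files
        self.reads: list[str] = []

    def read(self, path: str) -> str:
        self.reads.append(path)
        return self.files[path]


def build_skills_section(skills: list[dict[str, str]]) -> str:
    """Build the short prompt section: metadata only, never the skill body."""
    lines = ["Available Skills:"]
    for skill in skills:
        lines.append(f"- {skill['name']}: {skill['description']} (path: {skill['path']})")
    return "\n".join(lines)


skill_path = "/skills/user/web-research/SKILL.md"
backend = RecordingBackend(
    {skill_path: "---\nname: web-research\n---\n# Full workflow: search, organize, synthesize\n"}
)
skills = [
    {
        "name": "web-research",
        "description": "Structured approach to web research",
        "path": skill_path,
    }
]

# Stage 1: only metadata goes into the system prompt.
prompt_section = build_skills_section(skills)
# Stage 2: the body is read only once the model decides the skill applies.
full_text = backend.read(skills[0]["path"])
```

The benefit is that the cost in prompt tokens grows with the number of skills, not with the size of their instructions.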
2.3 How subagents work
- Look at the core SubAgentMiddleware class, which turns subagents into a task tool. In the LangChain family of frameworks, a subagent can be viewed as a special kind of tool, the task tool.
class SubAgentMiddleware(AgentMiddleware[Any, ContextT, ResponseT]):
def __init__(
self,
*,
backend: BackendProtocol | BackendFactory | None = None,
subagents: Sequence[SubAgent | CompiledSubAgent] | None = None,
system_prompt: str | None = TASK_SYSTEM_PROMPT,
task_description: str | None = None,
**deprecated_kwargs: Unpack[_DeprecatedKwargs],
) -> None:
...
task_tool = _build_task_tool(subagent_specs, task_description)
# Build system prompt with available agents
if system_prompt and subagent_specs:
agents_desc = "\n".join(f"- {s['name']}: {s['description']}" for s in subagent_specs)
self.system_prompt = system_prompt + "\n\nAvailable subagent types:\n" + agents_desc
else:
self.system_prompt = system_prompt
self.tools = [task_tool]
def _build_task_tool( # noqa: C901
subagents: list[_SubagentSpec],
task_description: str | None = None,
) -> BaseTool:
...
def task(
description: Annotated[
str,
"A detailed description of the task for the subagent to perform autonomously. Include all necessary context and specify the expected output format.", # noqa: E501
],
subagent_type: Annotated[str, "The type of subagent to use. Must be one of the available agent types listed in the tool description."],
runtime: ToolRuntime,
) -> str | Command:
if subagent_type not in subagent_graphs:
allowed_types = ", ".join([f"`{k}`" for k in subagent_graphs])
return f"We cannot invoke subagent {subagent_type} because it does not exist, the only allowed types are {allowed_types}"
if not runtime.tool_call_id:
value_error_msg = "Tool call ID is required for subagent invocation"
raise ValueError(value_error_msg)
# Use the user-defined runnable object
subagent, subagent_state = _validate_and_prepare_state(subagent_type, description, runtime)
result = subagent.invoke(subagent_state)
return _return_command_with_state_update(result, runtime.tool_call_id)
...
# What gets built in the end is a callable tool whose function is task
return StructuredTool.from_function(
name="task",
func=task,
coroutine=atask,
description=description,
)
- So the logic is: the main agent's LLM issues a function call that selects the task tool, and the task tool then runs the subagent, which is itself a graph, much like the main agent.
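A dependency-free sketch of the dispatch: the task tool looks up the subagent by type, runs it on a fresh state, and returns its final message. build_task_tool, researcher, and the dict-based state are hypothetical; returning only the last message is an assumption of this sketch.

```python
from typing import Callable

State = dict[str, list[dict[str, str]]]


def build_task_tool(subagents: dict[str, Callable[[State], State]]) -> Callable[[str, str], str]:
    """Wrap a registry of subagents in a single ``task`` function."""

    def task(description: str, subagent_type: str) -> str:
        if subagent_type not in subagents:
            allowed = ", ".join(f"`{k}`" for k in subagents)
            return f"Cannot invoke subagent {subagent_type}; allowed types are {allowed}"
        # The subagent starts from a fresh state that holds only the task description.
        sub_state: State = {"messages": [{"type": "human", "content": description}]}
        result = subagents[subagent_type](sub_state)
        # Only the final message is handed back to the parent agent.
        return result["messages"][-1]["content"]

    return task


def researcher(state: State) -> State:
    """Stand-in for a compiled subagent graph."""
    question = state["messages"][-1]["content"]
    answer = {"type": "ai", "content": f"Findings for: {question}"}
    return {"messages": [*state["messages"], answer]}


task = build_task_tool({"researcher": researcher})
ok = task("quantum computing", "researcher")
missing = task("fix the bug", "coder")
```

Returning an error string for an unknown type, instead of raising, gives the parent model a chance to correct its own call.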
3. checkpointer
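- Literally "checkpoint". Checkpointer is the LangGraph type that controls checkpointing behavior for a graph or subgraph; the docstring is quoted below. If an instance is provided, it serves as the graph's short-term memory. If set to None, the graph may inherit the parent graph's checkpointer when it is used as a subgraph. If set to False, no checkpointer is used or inherited, so the framework does not manage short-term memory. When a checkpointer is enabled, a thread_id should be passed in the config to isolate sessions.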
checkpointer: A checkpoint saver object or flag.
If provided, this `Checkpointer` serves as a fully versioned "short-term memory" for the graph,
allowing it to be paused, resumed, and replayed from any point.
If `None`, it may inherit the parent graph's checkpointer when used as a subgraph.
If `False`, it will not use or inherit any checkpointer.
**Important**: When a checkpointer is enabled, you should pass a `thread_id`
in the config when invoking the graph:
Checkpointer = None | bool | BaseCheckpointSaver
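A checkpointer is essentially a persistent state store that implements full CRUD operations. LangGraph currently provides the following checkpointer implementations:
- InMemorySaver: in-memory storage (for testing)
- PostgresSaver: PostgreSQL database
- SqliteSaver: SQLite database
- RedisSaver: Redis database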
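To illustrate what "versioned short-term memory keyed by thread_id" means, here is a toy in-memory store. The class and method names are hypothetical and much simpler than LangGraph's BaseCheckpointSaver interface.

```python
import copy
from collections import defaultdict
from typing import Any, Optional


class ToyCheckpointStore:
    """Hypothetical minimal checkpoint store: one version history per thread_id."""

    def __init__(self) -> None:
        self._threads: dict[str, list[dict[str, Any]]] = defaultdict(list)

    def put(self, thread_id: str, state: dict[str, Any]) -> int:
        """Save a snapshot and return its version number."""
        self._threads[thread_id].append(copy.deepcopy(state))
        return len(self._threads[thread_id]) - 1

    def get_latest(self, thread_id: str) -> Optional[dict[str, Any]]:
        """Return the newest snapshot for the thread, or None if there is none."""
        history = self._threads.get(thread_id)
        return copy.deepcopy(history[-1]) if history else None

    def history(self, thread_id: str) -> list[dict[str, Any]]:
        """Return all snapshots, oldest first, so a run can be replayed from any point."""
        return copy.deepcopy(self._threads.get(thread_id, []))

    def delete_thread(self, thread_id: str) -> None:
        self._threads.pop(thread_id, None)


store = ToyCheckpointStore()
store.put("session-1", {"messages": ["Hello"]})
store.put("session-1", {"messages": ["Hello", "Go on"]})
store.put("session-2", {"messages": ["Unrelated"]})
```

Keeping every version rather than only the latest is what allows pausing, resuming, and replaying, and keying by thread_id is what keeps sessions isolated from one another.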
Tracing the source
This code lives in langgraph; we start from the top.
invoke is actually a wrapper around stream, controlled through stream_mode, as shown below:
# libs/langgraph/langgraph/pregel/main.py:3235
def invoke(
    self,
    input: InputT | Command | None,
    config: RunnableConfig | None = None,
    *,
    context: ContextT | None = None,
    stream_mode: StreamMode = "values",
    print_mode: StreamMode | Sequence[StreamMode] = (),
    output_keys: str | Sequence[str] | None = None,
    interrupt_before: All | Sequence[str] | None = None,
    interrupt_after: All | Sequence[str] | None = None,
    durability: Durability | None = None,
    version: Literal["v1", "v2"] = "v1",
    **kwargs: Any,
) -> dict[str, Any] | Any:
    if version == "v2":
        # v2: values stream parts carry interrupts directly
        for chunk in self.stream(
            input,
            config,
            context=context,
            stream_mode="values" if stream_mode == "values" else stream_mode,
            print_mode=print_mode,
            output_keys=output_keys,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            durability=durability,
            version=version,
            **kwargs,
        ):
            if stream_mode == "values":
                latest = chunk["data"]
                if chunk_ints := chunk.get("interrupts", ()):
                    interrupts.extend(chunk_ints)  # type: ignore[arg-type]
            else:
                chunks.append(chunk)
    else:
        # v1: collect interrupts from updates stream
        for chunk in self.stream(
            input,
            config,
            context=context,
            stream_mode=(
                ["updates", "values"] if stream_mode == "values" else stream_mode
            ),
            print_mode=print_mode,
            output_keys=output_keys,
            interrupt_before=interrupt_before,
            interrupt_after=interrupt_after,
            durability=durability,
            **kwargs,
        ):
            if stream_mode == "values":
                if len(chunk) == 2:
                    mode, payload = cast(tuple[StreamMode, Any], chunk)
                else:
                    _, mode, payload = cast(
                        tuple[tuple[str, ...], StreamMode, Any], chunk
                    )
                if (
                    mode == "updates"
                    and isinstance(payload, dict)
                    and (ints := payload.get(INTERRUPT)) is not None
                ):
                    interrupts.extend(ints)
                elif mode == "values":
                    latest = payload
            else:
                chunks.append(chunk)
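The "invoke wraps stream" relationship can be illustrated with a small toy model. This is only a sketch under simplifying assumptions: the `stream` and `invoke` functions below are hypothetical stand-ins written for this article, not LangGraph's own methods, and they ignore interrupts, stream modes other than "values", and configuration.

```python
from typing import Any, Dict, Iterator, List, Optional


def stream(steps: List[Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Toy stand-in for a "values" stream: yield the full state after each step."""
    state: Dict[str, Any] = {}
    for update in steps:
        state = {**state, **update}  # merge this step's writes into the state
        yield state


def invoke(steps: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Toy stand-in for invoke(): drain the stream and keep only the last value."""
    latest: Optional[Dict[str, Any]] = None
    for chunk in stream(steps):
        latest = chunk
    return latest


result = invoke([{"messages": ["hi"]}, {"answer": "hello"}])
print(result)
```

The point of the design is that there is a single execution path: streaming is the primitive, and a blocking call is just a consumer that discards the intermediate values.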
def stream(
    self,
    input: InputT | Command | None,
    config: RunnableConfig | None = None,
    *,
    context: ContextT | None = None,
    stream_mode: StreamMode | Sequence[StreamMode] | None = None,
    print_mode: StreamMode | Sequence[StreamMode] = (),
    output_keys: str | Sequence[str] | None = None,
    interrupt_before: All | Sequence[str] | None = None,
    interrupt_after: All | Sequence[str] | None = None,
    durability: Durability | None = None,
    subgraphs: bool = False,
    debug: bool | None = None,
    version: Literal["v1", "v2"] = "v1",
    **kwargs: Unpack[DeprecatedKwargs],
) -> Iterator[dict[str, Any] | Any]:
    (
        stream_modes,
        output_keys,
        interrupt_before_,
        interrupt_after_,
        checkpointer,  # ← key point: obtain the checkpointer
        store,
        cache,
        durability_,
    ) = self._defaults(
        config,
        stream_mode=stream_mode,
        print_mode=print_mode,
        output_keys=output_keys,
        interrupt_before=interrupt_before,
        interrupt_after=interrupt_after,
        durability=durability,
    )
    # First, thread_id is extracted from config
    # If there is a thread_id and a checkpointer is configured, state persistence is enabled
    # The checkpointer instance is returned
    ...
    # Create the core execution loop
    with SyncPregelLoop(
        input,
        stream=StreamProtocol(stream.put, stream_modes),
        config=config,
        store=store,
        cache=cache,
        checkpointer=checkpointer,  # ← pass the checkpointer in
        nodes=self.nodes,
        specs=self.channels,
        output_keys=output_keys,
        input_keys=self.input_channels,
        stream_keys=self.stream_channels_asis,
        interrupt_before=interrupt_before_,
        interrupt_after=interrupt_after_,
        manager=run_manager,
        durability=durability_,
        trigger_to_nodes=self.trigger_to_nodes,
        migrate_checkpoint=self._migrate_checkpoint,
        retry_policy=self.retry_policy,
        cache_policy=self.cache_policy,
    ) as loop:
        while loop.tick():  # each iteration triggers one step
            for task in loop.match_cached_writes():
                loop.output_writes(task.id, task.writes, cached=True)
            for _ in runner.tick(
                [t for t in loop.tasks.values() if not t.writes],
                timeout=self.step_timeout,
                get_waiter=get_waiter,
                schedule_task=loop.accept_push,
            ):
                # emit output
                yield from _output(
                    stream_mode,
                    print_mode,
                    subgraphs,
                    stream.get,
                    queue.Empty,
                    version,
                    _output_mapper,
                    _state_mapper,
                )
            loop.after_tick()
    ...
def tick(self) -> bool:
    """Execute a single iteration of the Pregel loop."""
    # Check the iteration limit
    if self.step > self.stop:
        self.status = "out_of_steps"
        return False
    # Prepare the next tasks
    self.tasks = prepare_next_tasks(
        self.checkpoint,  # ← uses the current checkpoint
        self.checkpoint_pending_writes,
        self.nodes,
        self.channels,
        self.managed,
        self.config,
        self.step,
        self.stop,
        for_execution=True,
        manager=self.manager,
        store=self.store,
        checkpointer=self.checkpointer,  # ← pass the checkpointer in
        trigger_to_nodes=self.trigger_to_nodes,
        updated_channels=self.updated_channels,
        retry_policy=self.retry_policy,
        cache_policy=self.cache_policy,
    )
    # If there are no more tasks, we are done
    if not self.tasks:
        self.status = "done"
        return False
    # Apply pending writes
    if not self.is_replaying and self.checkpoint_pending_writes:
        self._match_writes(self.tasks)
    # Check whether we should interrupt
    if self.interrupt_before and should_interrupt(...):
        self.status = "interrupt_before"
        raise GraphInterrupt()
    return True
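- The get_tuple logic below fetches the saved checkpoint for the same thread_id from the stored history and loads it back into the graph state (including the message history), where it serves as short-term memory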
def __enter__(self) -> Self:
    if not self.checkpointer:
        saved = None
    elif self.checkpoint_config[CONF].get(CONFIG_KEY_CHECKPOINT_ID):
        # Explicit checkpoint_id requested: fetch that exact checkpoint.
        # This covers both normal replay and subgraphs resolved via
        # checkpoint_map during time-travel.
        saved = self.checkpointer.get_tuple(self.checkpoint_config)
    elif replay_state := self.config[CONF].get(CONFIG_KEY_REPLAY_STATE):
        # Subgraph replay: the parent is replaying and passed us a
        # replay_state with its checkpoint_id. Look up our checkpoint
        # from the parent's checkpoint_map instead of fetching latest.
        saved = replay_state.get_checkpoint(
            self.config[CONF].get(CONFIG_KEY_CHECKPOINT_NS, ""),
            self.checkpointer,
            self.checkpoint_config,
        )
        # Clear RESUMING so _first re-applies input instead of resuming.
        # This recreates ephemeral routing channels so nodes trigger
        # naturally via version comparison.
        self.config[CONF].pop(CONFIG_KEY_RESUMING, None)
    else:
        # Normal case: fetch the most recent checkpoint for this
        # graph/thread. Returns None on first invocation.
        saved = self.checkpointer.get_tuple(self.checkpoint_config)
    ...
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
    """Get a checkpoint tuple from the database asynchronously.

    This method retrieves a checkpoint tuple from the Postgres database based on the
    provided config. If the config contains a `checkpoint_id` key, the checkpoint with
    the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint
    for the given thread ID is retrieved.

    Args:
        config: The config to use for retrieving the checkpoint.

    Returns:
        The retrieved checkpoint tuple, or None if no matching checkpoint was found.
    """
    thread_id = config["configurable"]["thread_id"]
    checkpoint_id = get_checkpoint_id(config)
    checkpoint_ns = config["configurable"].get("checkpoint_ns", "")
    if checkpoint_id:
        args: tuple[Any, ...] = (thread_id, checkpoint_ns, checkpoint_id)
        where = "WHERE thread_id = %s AND checkpoint_ns = %s AND checkpoint_id = %s"
    else:
        args = (thread_id, checkpoint_ns)
        where = "WHERE thread_id = %s AND checkpoint_ns = %s ORDER BY checkpoint_id DESC LIMIT 1"
    async with self._cursor() as cur:
        await cur.execute(
            self.SELECT_SQL + where,
            args,
            binary=True,
        )
        value = await cur.fetchone()
        if value is None:
            return None
        # migrate pending sends if necessary
        if value["checkpoint"]["v"] < 4 and value["parent_checkpoint_id"]:
            await cur.execute(
                self.SELECT_PENDING_SENDS_SQL,
                (thread_id, [value["parent_checkpoint_id"]]),
            )
            if sends := await cur.fetchone():
                if value["channel_values"] is None:
                    value["channel_values"] = []
                self._migrate_pending_sends(
                    sends["sends"],
                    value["checkpoint"],
                    value["channel_values"],
                )
        return await self._load_checkpoint_tuple(value)
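The lookup rule used by aget_tuple (an exact match when checkpoint_id is given, otherwise the newest checkpoint for the thread) can be mimicked with a dictionary. The sketch below is a toy written for this article: `ToySaver` and its `put` signature are hypothetical and are not LangGraph's `BaseCheckpointSaver` interface, and it assumes checkpoint ids sort lexicographically in creation order, which is what the `ORDER BY checkpoint_id DESC` query relies on.

```python
from typing import Any, Dict, Optional, Tuple


class ToySaver:
    """Hypothetical in-memory saver that mirrors the WHERE clauses above."""

    def __init__(self) -> None:
        # (thread_id, checkpoint_ns) -> {checkpoint_id: checkpoint}
        self._rows: Dict[Tuple[str, str], Dict[str, Dict[str, Any]]] = {}

    @staticmethod
    def _key(config: Dict[str, Any]) -> Tuple[str, str]:
        conf = config["configurable"]
        return conf["thread_id"], conf.get("checkpoint_ns", "")

    def put(self, config: Dict[str, Any], checkpoint_id: str, checkpoint: Dict[str, Any]) -> None:
        self._rows.setdefault(self._key(config), {})[checkpoint_id] = checkpoint

    def get_tuple(self, config: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        rows = self._rows.get(self._key(config))
        if not rows:
            return None  # first invocation on this thread: nothing saved yet
        checkpoint_id = config["configurable"].get("checkpoint_id")
        if checkpoint_id:
            return rows.get(checkpoint_id)  # exact checkpoint requested
        return rows[max(rows)]  # newest one: ORDER BY checkpoint_id DESC LIMIT 1


saver = ToySaver()
cfg = {"configurable": {"thread_id": "t1"}}
saver.put(cfg, "0001", {"messages": ["hi"]})
saver.put(cfg, "0002", {"messages": ["hi", "hello"]})

latest = saver.get_tuple(cfg)
pinned = saver.get_tuple({"configurable": {"thread_id": "t1", "checkpoint_id": "0001"}})
missing = saver.get_tuple({"configurable": {"thread_id": "t2"}})
```

Here `latest` is the second snapshot, `pinned` is the first one, and `missing` is `None` because thread "t2" has never been written.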
def after_tick(self) -> None:
    # Finish the superstep
    writes = [w for t in self.tasks.values() for w in t.writes]
    # Apply all writes to the checkpoint
    self.updated_channels = apply_writes(
        self.checkpoint,
        self.channels,
        self.tasks.values(),
        self.checkpointer_get_next_version,
        self.trigger_to_nodes,
    )
    # Clear the pending writes
    self.checkpoint_pending_writes.clear()
    # Completed tasks are only replayed on the first tick
    self.is_replaying = False
    # Save the checkpoint ← key point!
    self._put_checkpoint({"source": "loop"})
    # After execution, check whether we should interrupt
    if self.interrupt_after and should_interrupt(...):
        self.status = "interrupt_after"
        raise GraphInterrupt()
    # Clear the resuming flag
    self.config[CONF].pop(CONFIG_KEY_RESUMING, None)
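Complete execution flow
The user calls graph.invoke(input, config)
↓
1. invoke() calls stream()
↓
2. stream() initialization:
   - Read thread_id from config
   - Obtain the checkpointer instance
   - Create SyncPregelLoop
↓
3. SyncPregelLoop initialization:
   - If a checkpointer is configured, call checkpointer.get_tuple() to fetch the saved checkpoint for the thread_id
   - If there is no history, create an empty checkpoint
   - Restore the channels state from the checkpoint
↓
4. Execution loop, while loop.tick():
   a. tick(): prepare tasks
      - Use the current checkpoint state
      - Decide which tasks to run
      - Apply pending writes
      - Check interrupt conditions
   ↓
   b. runner.tick(): run tasks
      - Run the node functions
      - Collect writes
   ↓
   c. Stream the output
      - Emit events according to the stream mode
   ↓
   d. after_tick():
      - Apply writes to the checkpoint
      - Call _put_checkpoint() to save the state
      - Check interrupt conditions
↓
5. The loop ends and the final output is returned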
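The flow above can be condensed into a toy loop. This is a deliberately simplified sketch, not LangGraph's Pregel implementation: `run`, the plain dictionary `store`, and the "one node per step" scheduling are all assumptions made for illustration. What it keeps is the shape of the cycle: restore the latest snapshot for the thread, run a step, apply its writes, save a new snapshot.

```python
from typing import Any, Callable, Dict, List

State = Dict[str, Any]
Node = Callable[[State], State]


def run(nodes: List[Node], store: Dict[str, List[State]], thread_id: str, user_input: State) -> State:
    """Toy superstep loop: one node per step, one saved snapshot per step."""
    history = store.setdefault(thread_id, [])
    # Step 3: restore the latest checkpoint, or start from an empty one.
    state: State = dict(history[-1]) if history else {}
    state.update(user_input)
    # Step 4: run the task, apply its writes, then persist (after_tick).
    for node in nodes:
        writes = node(state)         # b. run the node and collect its writes
        state = {**state, **writes}  # d. apply the writes to the state
        history.append(dict(state))  # d. save a snapshot of the new state
    return state


def count_turn(state: State) -> State:
    return {"turns": state.get("turns", 0) + 1}


def echo(state: State) -> State:
    return {"answer": "echo: " + state["question"]}


store: Dict[str, List[State]] = {}
first = run([count_turn, echo], store, "t1", {"question": "hi"})
second = run([count_turn, echo], store, "t1", {"question": "again"})
other = run([count_turn, echo], store, "t2", {"question": "hi"})
```

The second call on "t1" starts from the snapshot left by the first call, so its turn counter continues from 1, while "t2" starts from scratch.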
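- With the checkpointer described above, the framework achieves the following effect: a later invocation with the same thread_id picks up the messages saved by earlier invocations on that thread_id. In other words, the checkpointer keeps snapshots of earlier state, as in the example below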
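from langgraph.graph import StateGraph
from langgraph.checkpoint.memory import InMemorySaver
# Create the graph. State is the state schema (for example a TypedDict), defined elsewhere
graph = StateGraph(State)
# ... add nodes and edges ...
# Create the checkpointer
checkpointer = InMemorySaver()
app = graph.compile(checkpointer=checkpointer)
# First turn
result1 = app.invoke(
    {"messages": ["Hello"]},
    config={"configurable": {"thread_id": "sync_thread_1"}},
)
# Returns, for example: {"answer": "Hello! As Tongyi Qianwen (Qwen)..."}
# Second turn: same thread_id, so the history is restored automatically
result2 = app.invoke(
    {"messages": ["What can you do?"]},
    config={"configurable": {"thread_id": "sync_thread_1"}},
)
# The framework automatically:
# 1. Looks up the saved checkpoint for thread_id = "sync_thread_1"
# 2. Loads the saved state (including the earlier answer, intent, and so on)
# 3. Processes the new message on top of that state
# 4. Returns the new reply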
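What happens when the context limit is exceeded
- The framework ships with the SummarizationMiddleware middleware to handle context overflow. When the working context exceeds a threshold, summarization is triggered: the older messages are replaced by a summary, and the model receives the summary plus the most recent messages, in that call and in later ones. The earlier long messages are no longer sent
The threshold code is as follows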
def compute_summarization_defaults(model: BaseChatModel) -> SummarizationDefaults:
    """Compute default summarization settings based on model profile.

    Args:
        model: A resolved chat model instance.

    Returns:
        Default settings for trigger, keep, and truncate_args_settings.
        If the model has a profile with `max_input_tokens`, uses
        fraction-based settings. Otherwise, uses fixed token/message counts.
    """
    has_profile = (  # check whether the model exposes max_input_tokens in its profile
        model.profile is not None
        and isinstance(model.profile, dict)
        and "max_input_tokens" in model.profile
        and isinstance(model.profile["max_input_tokens"], int)
    )
    if has_profile:
        return {
            "trigger": ("fraction", 0.85),  # trigger summarization at 85% of max_input_tokens
            "keep": ("fraction", 0.10),  # keep 10%
            "truncate_args_settings": {
                "trigger": ("fraction", 0.85),
                "keep": ("fraction", 0.10),
            },
        }
    # Defaults for models without profile info are more conservative to avoid
    # overshooting context limits.
    return {
        "trigger": ("tokens", 170000),  # trigger summarization once this many tokens are reached
        "keep": ("messages", 6),
        "truncate_args_settings": {
            "trigger": ("messages", 20),
            "keep": ("messages", 20),
        },
    }
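To make the three trigger kinds concrete, the sketch below evaluates a `(kind, value)` pair against the current token and message counts. `should_trigger` is a hypothetical helper written for this article, not the middleware's own `_should_summarize`; in particular, treating the boundary as "greater than or equal" is an assumption.

```python
from typing import Optional, Tuple, Union

Trigger = Tuple[str, Union[int, float]]


def should_trigger(
    trigger: Trigger,
    total_tokens: int,
    message_count: int,
    max_input_tokens: Optional[int] = None,
) -> bool:
    """Return True when the (kind, value) threshold has been reached."""
    kind, value = trigger
    if kind == "fraction":
        if max_input_tokens is None:
            raise ValueError("fraction triggers need max_input_tokens")
        return total_tokens >= value * max_input_tokens
    if kind == "tokens":
        return total_tokens >= value
    if kind == "messages":
        return message_count >= value
    raise ValueError("unknown trigger kind: " + kind)


# A model with a 128,000-token input window and the ("fraction", 0.85) default:
# the threshold is 0.85 * 128,000 = 108,800 tokens.
over = should_trigger(("fraction", 0.85), 110_000, 10, max_input_tokens=128_000)
under = should_trigger(("fraction", 0.85), 100_000, 10, max_input_tokens=128_000)

# A model without profile info falls back to the absolute token count.
fallback = should_trigger(("tokens", 170000), 169_999, 10)
```

The fraction form scales with the model's window, which is why it is preferred when the profile is available; the fixed 170000-token fallback has to be chosen without knowing the real limit.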
Below is the code that triggers the summarization event
def wrap_model_call(
    self,
    request: ModelRequest,
    handler: Callable[[ModelRequest], ModelResponse],
) -> ModelResponse | ExtendedModelResponse:
    """Process messages before model invocation, with history offloading and arg truncation.

    First applies any previous summarization events to reconstruct the effective message list.
    Then truncates large tool arguments in old messages if configured.
    Finally offloads messages to backend before summarization if thresholds are met.

    Control flow details:
    - If thresholds say "do not summarize", we still attempt one normal
      model call with the current effective/truncated messages.
    - If that call raises `ContextOverflowError`, we immediately fall back to
      the summarization path and retry the model call with
      `summary_message + preserved_recent_messages`.

    Unlike the legacy `before_model` approach, this does NOT modify the LangGraph state.
    Instead, it tracks summarization events in middleware state and modifies the model
    request directly.

    Args:
        request: The model request to process.
        handler: The handler to call with the (possibly modified) request.

    Returns:
        A plain `ModelResponse` when no summarization event is created, or
        an `ExtendedModelResponse` that updates `_summarization_event`
        with `cutoff_index`, `summary_message`, and `file_path`.
        If `cutoff_index <= 0`, no compaction occurs and no
        `_summarization_event` update is emitted.
    """
    # Get effective messages based on previous summarization events
    effective_messages = self._get_effective_messages(request)
    # Step 1: Truncate args if configured
    truncated_messages, _ = self._truncate_args(
        effective_messages,
        request.system_message,
        request.tools,
    )
    # Step 2: Check if summarization should happen
    counted_messages = [request.system_message, *truncated_messages] if request.system_message is not None else truncated_messages
    try:
        total_tokens = self.token_counter(counted_messages, tools=request.tools)  # ty: ignore[unknown-argument]
    except TypeError:
        total_tokens = self.token_counter(counted_messages)
    should_summarize = self._should_summarize(truncated_messages, total_tokens)
    # If no summarization needed, return with truncated messages
    if not should_summarize:
        try:
            return handler(request.override(messages=truncated_messages))
        except ContextOverflowError:
            pass
        # Fallback to summarization on context overflow
    # Step 3: Perform summarization
    cutoff_index = self._determine_cutoff_index(truncated_messages)
    if cutoff_index <= 0:
        # Can't summarize, return truncated messages
        return handler(request.override(messages=truncated_messages))
    messages_to_summarize, preserved_messages = self._partition_messages(truncated_messages, cutoff_index)
    # Offload to backend first so history is preserved before summarization.
    # If offload fails, summarization still proceeds (with file_path=None).
    backend = self._get_backend(request.state, request.runtime)
    file_path = self._offload_to_backend(backend, messages_to_summarize)
    if file_path is None:
        msg = "Offloading conversation history to backend failed during summarization. Older messages will not be recoverable."
        logger.error(msg)
        warnings.warn(msg, stacklevel=2)
    # Generate summary
    summary = self._create_summary(messages_to_summarize)
    # Build summary message with file path reference
    new_messages = self._build_new_messages_with_path(summary, file_path)
    previous_event = request.state.get("_summarization_event")
    state_cutoff_index = self._compute_state_cutoff(previous_event, cutoff_index)
    # Create new summarization event
    new_event: SummarizationEvent = {
        "cutoff_index": state_cutoff_index,
        "summary_message": new_messages[0],  # The HumanMessage with summary  # ty: ignore[invalid-argument-type]
        "file_path": file_path,
    }
    # Modify request to use summarized messages
    modified_messages = [*new_messages, *preserved_messages]
    response = handler(request.override(messages=modified_messages))
    # Return ExtendedModelResponse with state update
    return ExtendedModelResponse(
        model_response=response,
        command=Command(update={"_summarization_event": new_event}),
    )
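Two ideas in wrap_model_call are easy to lose in the detail: the stored event (a cutoff index plus a summary message) is enough to rebuild the effective message list without touching the saved history, and a context overflow on the normal call falls back to the summarization path. The sketch below models both with plain strings. Every name in it (`ContextOverflow`, `effective_messages`, `call_with_fallback`, the `keep` parameter) is hypothetical, and it skips argument truncation, backend offloading, and the cutoff adjustment for earlier events.

```python
from typing import Callable, Dict, List, Optional, Tuple


class ContextOverflow(Exception):
    """Hypothetical stand-in for the model's context overflow error."""


def effective_messages(messages: List[str], event: Optional[Dict]) -> List[str]:
    """Rebuild what the model sees: summary first, then messages after the cutoff."""
    if event is None:
        return list(messages)
    return [event["summary_message"], *messages[event["cutoff_index"]:]]


def call_with_fallback(
    messages: List[str],
    handler: Callable[[List[str]], str],
    summarize: Callable[[List[str]], str],
    keep: int = 2,
) -> Tuple[str, Optional[Dict]]:
    """Try a normal call; on overflow, summarize the older messages and retry."""
    try:
        return handler(messages), None  # no summarization event created
    except ContextOverflow:
        pass
    cutoff = max(len(messages) - keep, 0)
    if cutoff <= 0:
        return handler(messages), None  # nothing to compact, call as-is
    event = {"cutoff_index": cutoff, "summary_message": summarize(messages[:cutoff])}
    return handler(effective_messages(messages, event)), event


def handler(msgs: List[str]) -> str:
    if len(msgs) > 3:  # pretend the context window holds three messages
        raise ContextOverflow
    return "reply to %d messages" % len(msgs)


history = ["m1", "m2", "m3", "m4", "m5"]
reply, event = call_with_fallback(
    history, handler, lambda old: "summary of %d messages" % len(old)
)
next_turn = effective_messages(history + ["m6"], event)
```

Because the event only records where to cut, the full history stays intact in state; `next_turn` shows how a later call reuses the same event to present the summary followed by the messages from the cutoff onward.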