0%

agentscope 源码分析(6):Agent 的元类与基类

从本篇文章开始,我们将正式深入剖析 AgentScope 框架中 Agent 模块的核心实现源码,首先聚焦于 Agent 底层的基础设计,逐一解析 StateModule_AgentMeta 元类与 AgentBase 基类等关键组件。

Agent 模块

AgentScope 的 agent 模块是整个框架的核心,定义了 Agent 的基础架构、行为模式和扩展能力。模块采用分层设计,从基类到具体实现,构建了一套完整的 Agent 开发体系。Agent 模块的代码位于 src/agentscope/agent/ 目录下,文件列表如下:

1
2
3
4
5
6
7
8
9
10
11
src/agentscope/agent/
├── __init__.py # 模块导出
├── _utils.py # 工具函数
├── _agent_meta.py # 元类(Hook 机制)
├── _agent_base.py # Agent 基类
├── _react_agent_base.py # ReAct Agent 基类
├── _react_agent.py # ReAct Agent 完整实现
├── _user_agent.py # 用户 Agent
├── _user_input.py # 用户输入处理
├── _a2a_agent.py # A2A 协议 Agent
└── _realtime_agent.py # 实时 Agent

AgentBase 是所有 Agent 的基类,定义了所有 Agent 的公共接口和属性:

1
2
class AgentBase(StateModule, metaclass=_AgentMeta):
......

可以看到 AgentBase 又继承自 StateModule 类,其实不仅 AgentBase 继承自 StateModule,之前介绍的 ToolkitPlanNotebookPlanStorageBaseLongTermMemoryBase 等类型都继承自 StateModule,那这个 StateModule 类到底起什么作用呢?我们详细介绍一下。

StateModule

在构建 AI Agent 系统时,存在一个核心需求:状态持久化。典型场景:

  • Agent 执行长任务时崩溃,需要恢复现场
  • Agent 需要在多次对话间保持记忆
  • Agent 需要保存/加载历史会话

StateModule 提供了一套声明式状态管理机制,提供统一的序列化/反序列化机制,支持:

  • 自动状态追踪:自动追踪继承自 StateModule 的子属性
  • 手动状态注册:支持注册普通属性并提供自定义序列化方法
  • 嵌套状态管理:支持多层嵌套对象的状态序列化和恢复
  • Session 管理:与 Session 模块配合实现应用级状态持久化

StateModule 模块的代码位于 src/agentscope/module/_state_module.py。AgentScope 中,以下类型都继承自 StateModule:

1
2
3
4
5
6
7
8
9
StateModule (基类)

├── AgentBase # 所有 Agent 的基类
├── MemoryBase # 内存记忆基类
├── LongTermMemoryBase # 长期记忆基类
├── Toolkit # 工具集
├── PlanNotebook # 计划笔记本
├── PlanStorageBase # 计划存储基类
└── RealtimeAgent # 实时 Agent

StateModule 用法示例

首先我们将通过一些简单例子来展示 StateModule 的用法,以对 StateModule 有个直观的认识。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from collections import OrderedDict
from agentscope.module import StateModule


# ========== 1. 必须注册才能追踪 ==========
class Counter(StateModule):
def __init__(self):
super().__init__()
self.count = 0
self.temp = "not tracked" # 未注册,不会被追踪
self.register_state("count") # 只有 count 被注册


c = Counter()
c.count = 100
c.temp = "new value"
state = c.state_dict()

new_c = Counter()
new_c.load_state_dict(state)
print(f"【1】state={state}")
print(f" count={new_c.count} (registered, restored)")
print(f" temp='{new_c.temp}' (not registered, NOT restored)")


# ========== 2. 子模块自动追踪 ==========
class Memory(StateModule):
def __init__(self):
super().__init__()
self.msgs = []
self.register_state("msgs")


class Agent(StateModule):
def __init__(self):
super().__init__()
self.memory = Memory() # 自动追踪,无需 register_state


agent = Agent()
agent.memory.msgs.append("hello")
state = agent.state_dict()

new_agent = Agent()
new_agent.load_state_dict(state)
print(f"【2】state={state}")
print(f" restored: msgs={new_agent.memory.msgs}")


# ========== 3. 复杂类型自定义序列化 ==========
class User(StateModule):
def __init__(self):
super().__init__()
self.prefs = OrderedDict()
self.register_state(
"prefs",
custom_to_json=lambda x: dict(x),
custom_from_json=lambda x: OrderedDict(x),
)


user = User()
user.prefs["lang"] = "zh"
state = user.state_dict()

new_user = User()
new_user.load_state_dict(state)
print(f"【3】state={state}")
print(
f" restored: prefs={dict(new_user.prefs)}, type={type(new_user.prefs).__name__}"
)
1
2
3
4
5
6
7
8
# python demo_statemodule_basics.py
【1】state={'count': 100}
count=100 (registered, restored)
temp='not tracked' (not registered, NOT restored)
【2】state={'memory': {'msgs': ['hello']}}
restored: msgs=['hello']
【3】state={'prefs': {'lang': 'zh'}}
restored: prefs={'lang': 'zh'}, type=OrderedDict
  • 对于继承自 StateModule 的类,通过 state_dict() 获取该对象的所有 state 信息(序列化)、而通过 load_state_dict 从 state 信息中恢复该对象的状态(反序列化)
  • 在第一个例子中,对于 Counter 类,想要将 counter 属性作为一个可持久化的状态,必须通过 register_state("counter"),而 temp 属性没有被注册为状态,因此不会被状态跟踪
  • 对于第二个例子,对于 Agent 类型,其继承自 StateModule 类,因此可以对 Agent 类型进行状态跟踪。而其内部的 memory 成员,由于其类型 Memory 本身继承自 StateModule,因此 memory 自动就会被状态跟踪,无需再调用 register_state
  • 第三个例子则展示了,我们可以在 register_state() 中自定义序列化和反序列化方法

这个例子展示了我们所说的 StateModule 的两个关键能力:即

  • 自动状态追踪:自动追踪继承自 StateModule 的子属性
  • 手动状态注册:手动通过 register_state() 注册普通属性,并可以提供自定义序列化方法

接下来我们这个例子,则展示了 StateModule嵌套状态管理支持多层嵌套对象的状态序列化和恢复

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
from agentscope.module import StateModule


class ToolHistory(StateModule):
def __init__(self):
super().__init__()
self.calls = []
self.register_state("calls")


class ToolKit(StateModule):
def __init__(self):
super().__init__()
self.history = ToolHistory() # auto tracked


class Agent(StateModule):
def __init__(self, name: str):
super().__init__()
self.name = name
self.register_state("name")
self.toolkit = ToolKit() # auto tracked

agent = Agent("Assistant")
agent.toolkit.history.calls.append({"tool": "search", "args": {"q": "test"}})

new_agent = Agent("temp")
new_agent.load_state_dict(agent.state_dict())

print(f"【Nested】state={state}")

print(f" name={new_agent.name}")
print(f" calls={new_agent.toolkit.history.calls}")
1
2
3
4
5
# python demo_statemodule_nested.py
【Nested】state={'toolkit': {'history': {'calls': [{'tool': 'search', 'args': {'q': 'test'}}]}}, 'name': 'Assistant'}
name=Assistant
calls=[{'tool': 'search', 'args': {'q': 'test'}}]

StateModule 的实现

接下来我们再来看 StateModule 是如何实现的。StateModule 的类型定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
class StateModule:
def __init__(self):
self._module_dict = OrderedDict() # 存储子 StateModule 属性
self._attribute_dict = OrderedDict() # 存储注册的普通属性

def __setattr__(self, key: str, value: Any) -> None:
if isinstance(value, StateModule):
# 自动将 StateModule 类型的属性注册到 _module_dict
self._module_dict[key] = value
super().__setattr__(key, value)


  • 通过 _module_dict 来保存所有本身是 StateModule 的属性,而通过 _attribute_dict 来保存所有手动注册的普通属性
  • 上文说过,对于本身就是 StateModule 类型的属性会自动进行状态跟踪,就是依靠 __setattr__ 方法,在设置属性时,如果发现是 StateModule 类型,则将其添加到 _module_dict

register_state() 方法用于手动状态注册:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
def register_state(
self,
attr_name: str,
custom_to_json: Callable[[Any], JSONSerializableObject] | None = None,
custom_from_json: Callable[[JSONSerializableObject], Any]
| None = None,
) -> None:
attr = getattr(self, attr_name)

# 判断是否原生可以 json 序列化
if custom_to_json is None:
# Make sure the attribute is JSON serializable natively
try:
json.dumps(attr)
except Exception as e:
raise TypeError(
f"Attribute '{attr_name}' is not JSON serializable. "
"Please provide a custom function to convert the "
"attribute to a JSON-serializable format.",
) from e

if attr_name in self._module_dict:
raise ValueError(
f"Attribute `{attr_name}` is already registered as a module. ",
)

# 添加到注册的属性字典
# 记录下所提供的 json 序列化/反序列化函数,如果没有提供,则为 None
self._attribute_dict[attr_name] = _JSONSerializeFunction(
to_json=custom_to_json,
load_json=custom_from_json,
)

state_dict() 方法则以字典的形式,返回该类型的状态信息,它的核心逻辑是:对于 _module_dict 中的子模块,递归调用 state_dict() 方法获取其状态信息;而对于 _attribute_dict 中的普通属性,根据是否有自定义的序列化方法,选择使用自定义方法或直接返回:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def state_dict(self) -> dict:
state = {}

# 1. 递归收集子模块状态
for key in self._module_dict:
attr = getattr(self, key)
state[key] = attr.state_dict() # 递归调用

# 2. 收集注册属性状态
for key in self._attribute_dict:
attr = getattr(self, key)
to_json = self._attribute_dict[key].to_json
# 如果没有提供 json 序列化函数,直接记录属性值
state[key] = to_json(attr) if to_json else attr

return state

load_state_dict() 方法则从状态字典中恢复出原始的类型对象:

1
2
3
4
5
6
7
8
9
10
def load_state_dict(self, state_dict: dict, strict: bool = True):
# 1. 递归加载子模块状态
for key in self._module_dict:
self._module_dict[key].load_state_dict(state_dict[key])

# 2. 加载注册属性状态
for key in self._attribute_dict:
from_json = self._attribute_dict[key].load_json
value = from_json(state_dict[key]) if from_json else state_dict[key]
setattr(self, key, value)

StateModule 为自定义类型提供了状态的控制能力(即控制哪些属性需要作为状态进行跟踪),但它只是提供了序列化(state_dict)和反序列化(load_state_dict)的基本功能,为了将状态持久化,还需要依赖 AgentScope 所提供的 Session 模块。关于 Session 模块,我们在后续文章继续介绍。

_AgentMeta 元类

分析完 StateModule 实现之后,我们再来分析 AgentBase 类申明中另一个值得注意点,即 AgentBase_AgentMeta 作为元类:

1
2
class AgentBase(StateModule, metaclass=_AgentMeta):
......

Python 中的元类是 构建类的类,负责控制 类的创建和行为。普通类用来指导构建实例对象,而元类则是用来指导如何构建类本身(Python 中类本身也是一种对象)。关于 Python 元类的基础知识,可以参考 流畅的 Python 第 2 版(24):类元编程

_AgentMeta 元类的定义如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
class _AgentMeta(type):
# 如何创建一个 Agent 类
# mcs 是元类,name 是类名,bases 是父类,attrs 是类的属性
def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
"""Wrap the agent's functions with hooks."""

# 对这些函数,添加 hook 逻辑
for func_name in [
"reply",
"print",
"observe",
]:
if func_name in attrs:
attrs[func_name] = _wrap_with_hooks(attrs[func_name])

return super().__new__(mcs, name, bases, attrs)
  • _AgentMeta 是一个元类,因为它集成自 type,因此具有创建类对象的能力,因此它是一个元类
  • __new__ 方法定义了如何创建类对象,即这里的 AgentBase
  • __new__ 方法中检查了 AgentBase 类中的三个属性:replyprintobserve,为这些属性添加 hooks 逻辑

简单来说,_AgentMeta 就是为 AgentBase 类及其子类自动添加了一些钩子逻辑。因为 AgentBase 类在构建时会执行 _AgentMeta.__new__,而 __new__ 方法则检查这些创建的类中是否有 replyprintobserve 这三个方法,如果有,则将这些方法包装成带有钩子逻辑的新方法。

_wrap_with_hooks 的代码虽然有些多,但是其核心逻辑就是根据当前 hook 的函数名,判断 AgentBase 类中是否定义了 _pre_{method.__name__}_hooks_post_{method.__name__}_hooks 字典。如果存在,则会从这些字典中取出相关的 hooks 方法,并在调用原始方法的前后,分别执行这些 hooks 函数。hooks 方法支持实例方法和类方法:

1
2
3
4
5
6
assert (
hasattr(self, f"_instance_pre_{func_name}_hooks")
and hasattr(self, f"_instance_post_{func_name}_hooks")
and hasattr(self.__class__, f"_class_pre_{func_name}_hooks")
and hasattr(self.__class__, f"_class_post_{func_name}_hooks")
), f"Hooks for {func_name} not found in {self.__class__.__name__}"

_warp_with_hooks 方法的简化逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def _wrap_with_hooks(original_func):
"""装饰器:在函数前后执行 hooks"""
func_name = original_func.__name__.replace("_", "")

@wraps(original_func)
async def async_wrapper(self, *args, **kwargs):
# 1. 参数归一化
normalized_kwargs = _normalize_to_kwargs(original_func, self, *args, **kwargs)

# 2. 执行 pre-hooks(可修改参数)
pre_hooks = list(self._instance_pre_hooks.values()) + \
list(self.__class__._class_pre_hooks.values())
for pre_hook in pre_hooks:
modified = await pre_hook(self, deepcopy(normalized_kwargs))
if modified:
normalized_kwargs = modified

# 3. 执行原函数
output = await original_func(self, **normalized_kwargs)

# 4. 执行 post-hooks(可修改输出)
post_hooks = list(self._instance_post_hooks.values()) + \
list(self.__class__._class_post_hooks.values())
for post_hook in post_hooks:
modified = await post_hook(self, deepcopy(normalized_kwargs), output)
if modified:
output = modified

return output
return async_wrapper

所以总结一下,_AgentMeta 实现了 hook 约定:

  • 如果你的类中定义了相关的字典(例如 _instance_pre_{func_name}_hooks 等)来保存 hooks 函数,则 _AgentMeta 可以帮你自动地在执行相关方法时,完成相关 hooks 的调用
  • _AgentMeta 支持为 replyprintobserve 这些方法添加 hooks 逻辑,hook 执行时机可以是 pre 或者 post
  • hook 函数是实例方法或者类方法

AgentBase

分析完 AgentBase 定义中所指定的基类 StateModule 以及其元类 metaclass,接下来我们再来看 AgentBase类的主体实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
class AgentBase(StateModule, metaclass=_AgentMeta):
"""Base class for asynchronous agents."""

id: str

supported_hook_types: list[str] = [
"pre_reply",
"post_reply",
"pre_print",
"post_print",
"pre_observe",
"post_observe",
]

# 类级别的 hooks 函数
_class_pre_reply_hooks: dict = OrderedDict()
_class_post_reply_hooks: dict = OrderedDict()
_class_pre_print_hooks: dict = OrderedDict()
_class_post_print_hooks: dict = OrderedDict()
_class_pre_observe_hooks: dict = OrderedDict()
_class_post_observe_hooks: dict = OrderedDict()

# 初始化
def __init__(self) -> None:
"""Initialize the agent."""
super().__init__()

self.id = shortuuid.uuid()

# The replying task and identify of the current replying
self._reply_task: Task | None = None
self._reply_id: str | None = None

# 实例级别的 hooks 函数
# Initialize the instance-level hooks
self._instance_pre_print_hooks = OrderedDict()
self._instance_post_print_hooks = OrderedDict()

self._instance_pre_reply_hooks = OrderedDict()
self._instance_post_reply_hooks = OrderedDict()

self._instance_pre_observe_hooks = OrderedDict()
self._instance_post_observe_hooks = OrderedDict()

# 用于保存流式打印的累计数据
self._stream_prefix = {}

# 定义该 agent 的 subscribers,subscriber 可以通过 observe 方法接收到该 agent 的消息
# key 是 MsgHub 的 id
# value 是一系列是的 subagent
self._subscribers: dict[str, list[AgentBase]] = {}

# disable 控制台输出
self._disable_console_output: bool = (
os.getenv(
"AGENTSCOPE_DISABLE_CONSOLE_OUTPUT",
"false",
).lower()
== "true"
)

self._disable_msg_queue: bool = True
self.msg_queue = None
  • 每个 Agent 都会有一个 id
  • 使用字典保存类级别和实例级别的 hooks 函数,这块和我们对 _AgentMeta 的分析是一致的

AgentBase 定义了如下方法:

方法 作用 实现状态
reply() 生成回复 抽象方法,子类实现
observe() 观察消息 抽象方法,子类实现
print() 打印消息 已实现,支持流式输出
__call__() 调用入口 已实现,处理中断和广播
interrupt() 中断执行 已实现
handle_interrupt() 中断处理 抽象方法

AgentBase 实现了 __call__() 方法,因此可以直接调用 AgentBase 的实例,以处理用户的输入消息:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
async def __call__(self, *args, **kwargs) -> Msg:
self._reply_id = shortuuid.uuid()
reply_msg = None

try:
self._reply_task = asyncio.current_task()
reply_msg = await self.reply(*args, **kwargs) # 执行回复

except asyncio.CancelledError:
reply_msg = await self.handle_interrupt(*args, **kwargs) # 处理中断

finally:
if reply_msg:
await self._broadcast_to_subscribers(reply_msg) # 广播给订阅者
self._reply_task = None

return reply_msg
  • __call__ 和核心逻辑是定义了处理输入消息的总体流程,包括调用 reply() 方法来处理消息,调用 handle_interrupt() 来处理中断,并将回复的消息广播给所有订阅者
  • reply()handle_interrupt() 都是抽象方法,需要子类实现,这些才是 Agent 的核心业务逻辑
  • AgentBase 支持 智能体回复消息 的广播机制,用于将回复消息广播给其他智能体,每个智能体都需要实现 observe() 方法来接收并处理消息
1
2
3
4
5
6
7
8
9
async def _broadcast_to_subscribers(
self,
msg: Msg | list[Msg] | None,
) -> None:
# 调用所有订阅者的 observe 方法
"""Broadcast the message to all subscribers."""
for subscribers in self._subscribers.values():
for subscriber in subscribers:
await subscriber.observe(msg)

reset_subscribers()reset_subscribers() 方法用于管理当前智能体的订阅者:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def reset_subscribers(
self,
msghub_name: str,
subscribers: list["AgentBase"],
) -> None:
# 重新设置订阅者,排除当前代理自身
self._subscribers[msghub_name] = [_ for _ in subscribers if _ != self]

def remove_subscribers(self, msghub_name: str) -> None:
if msghub_name not in self._subscribers:
logger.warning(
"MsgHub named '%s' not found",
msghub_name,
)
else:
self._subscribers.pop(msghub_name)

AgentBase 提供了 print() 方法,用来流式输出 Agent 执行过程中的消息,其核心逻辑如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
async def print(self, msg: Msg, last: bool = True, speech: AudioBlock = None):
"""支持流式打印消息"""

# 1. 消息队列支持(用于外部消费)
if not self._disable_msg_queue:
await self.msg_queue.put((deepcopy(msg), last, speech))
await asyncio.sleep(0) # 让出控制权

# 2. 控制台输出
if not self._disable_console_output:
for block in msg.get_content_blocks():
if block["type"] == "text":
self._print_text_block(msg.id, block["text"])
elif block["type"] == "thinking":
self._print_text_block(msg.id, block["thinking"])
# ... 处理其他 block 类型

# 3. 音频播放
if speech:
self._process_audio_block(msg.id, speech)

if last and msg.id in self._stream_prefix:
# 1. 关闭音频播放器
if "audio" in self._stream_prefix[msg.id]:
player, _ = self._stream_prefix[msg.id]["audio"]
player.close()

# 2. 移除缓存
stream_prefix = self._stream_prefix.pop(msg.id)

# 3. 补换行符
if "text" in stream_prefix and not stream_prefix["text"].endswith("\n"):
print()

print() 的函数入口处,支持将 Msg 额外保存到 self.msg_queue,这样允许其他组件消费这些消息。AgentBase.set_msg_queue_enabled() 方法允许设置所使用的 msg_queue,默认为空。

print() 的一个特点是支持流式增量输出,其通过 _stream_prefix 记录每个消息已经输出内容:

1
2
3
4
5
6
self._stream_prefix = {
msg_id: {
"text": "已打印的文本内容",
"audio": (player_object, "已播放的base64数据"),
}
}

之后只会输出增量部分,我们可以通过 _print_text_block() 查看其如何实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
def _print_text_block(
self,
msg_id: str,
name_prefix: str,
text_content: str,
thinking_and_text_to_print: list[str],
) -> None:
thinking_and_text_to_print.append(
f"{name_prefix}: {text_content}",
)
# The accumulated text and thinking blocks to print
to_print = "\n".join(thinking_and_text_to_print)

# 添加 msg_id 到 _stream_prefix 中
# The text prefix that has been printed
if msg_id not in self._stream_prefix:
self._stream_prefix[msg_id] = {}

# 获取前缀
text_prefix = self._stream_prefix[msg_id].get("text", "")

# 答应新的内容
# Only print when there is new text content
if len(to_print) > len(text_prefix):
print(to_print[len(text_prefix) :], end="")

# Save the printed text prefix
self._stream_prefix[msg_id]["text"] = to_print

_process_audio_block() 方法则用于处理音频消息的输出,可以处理 URL 类型和 Base64 字符串类型的音频数据。这里就不再展开其代码。

AgentBase 还提供了几个工具函数,用来简化 Hook 的注册:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# 注册实例级别的钩子
def register_instance_hook(
self,
hook_type: AgentHookTypes,
hook_name: str,
hook: Callable,
) -> None:
if not isinstance(self, AgentBase):
raise TypeError(
"The register_instance_hook method should be called on an "
f"instance of AsyncAgentBase, but got {self} of "
f"type {type(self)}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
hooks[hook_name] = hook

# 移除实例级别的钩子
def remove_instance_hook(
self,
hook_type: AgentHookTypes,
hook_name: str,
) -> None:
if not isinstance(self, AgentBase):
raise TypeError(
"The remove_instance_hook method should be called on an "
f"instance of AsyncAgentBase, but got {self} of "
f"type {type(self)}.",
)
hooks = getattr(self, f"_instance_{hook_type}_hooks")
if hook_name in hooks:
del hooks[hook_name]
else:
raise ValueError(
f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
f"{self.__class__.__name__} instance.",
)

# 注册类级别的钩子
@classmethod
def register_class_hook(
cls,
hook_type: AgentHookTypes,
hook_name: str,
hook: Callable,
) -> None:
assert (
hook_type in cls.supported_hook_types
), f"Invalid hook type: {hook_type}"

hooks = getattr(cls, f"_class_{hook_type}_hooks")
hooks[hook_name] = hook

# 移除类级别的钩子
@classmethod
def remove_class_hook(
cls,
hook_type: AgentHookTypes,
hook_name: str,
) -> None:
......

# 清除类级别的钩子
@classmethod
def clear_class_hooks(
cls,
hook_type: AgentHookTypes | None = None,
) -> None:
......

# 清除实例级别的钩子
def clear_instance_hooks(
self,
hook_type: AgentHookTypes | None = None,
) -> None:
......

以上我们就完成了对 AgentBase 类的主体介绍,主要是其 __call__() 定义了 Agent 的最基本的流程,包括 reply()handle_interrupt()observe()。另外提供了 print() 方法来流式输出消息。

小结

这篇文章我们学习了 AgentBase 类,它是所有 Agent 实现的基类,它本身继承自 StateModule 以支持状态信息管理,并使用 _AgentMeta 作为元类以支持 Hooks 的执行,另外详细介绍了 AgentBase 本身的实现,包括各种 hooks 的管理、__call__ 方法和 print 方法等。