从本篇文章开始,我们将正式深入剖析 AgentScope 框架中 Agent 模块的核心实现源码,首先聚焦于 Agent 底层的基础设计,逐一解析 StateModule、_AgentMeta 元类与 AgentBase 基类等关键组件。
Agent 模块
AgentScope 的 agent 模块是整个框架的核心,定义了 Agent 的基础架构、行为模式和扩展能力。模块采用分层设计,从基类到具体实现,构建了一套完整的 Agent 开发体系。Agent 模块的代码位于 src/agentscope/agent/ 目录下,文件列表如下:
1 2 3 4 5 6 7 8 9 10 11 src/agentscope/agent/ ├── __init__.py ├── _utils.py ├── _agent_meta.py ├── _agent_base.py ├── _react_agent_base.py ├── _react_agent.py ├── _user_agent.py ├── _user_input.py ├── _a2a_agent.py └── _realtime_agent.py
AgentBase 是所有 Agent 的基类,定义了所有 Agent 的公共接口和属性:
1 2 class AgentBase (StateModule, metaclass=_AgentMeta): ......
可以看到 AgentBase 又继承自 StateModule 类,其实不仅 AgentBase 继承自 StateModule,之前介绍的 Toolkit、PlanNotebook、PlanStorageBase、LongTermMemoryBase 等类型都继承自 StateModule,那这个 StateModule 类到底起什么作用呢?我们详细介绍一下。
StateModule
在构建 AI Agent 系统时,存在一个核心需求:状态持久化 。典型场景:
Agent 执行长任务时崩溃,需要恢复现场
Agent 需要在多次对话间保持记忆
Agent 需要保存/加载历史会话
StateModule 提供了一套声明式状态管理机制 ,提供统一的序列化/反序列化机制,支持:
自动状态追踪 :自动追踪继承自 StateModule 的子属性
手动状态注册 :支持注册普通属性并提供自定义序列化方法
嵌套状态管理 :支持多层嵌套对象的状态序列化和恢复
Session 管理 :与 Session 模块配合实现应用级状态持久化
StateModule 模块的代码位于 src/agentscope/module/_state_module.py。AgentScope 中,以下类型都继承自 StateModule:
1 2 3 4 5 6 7 8 9 StateModule (基类) │ ├── AgentBase ├── MemoryBase ├── LongTermMemoryBase ├── Toolkit ├── PlanNotebook ├── PlanStorageBase └── RealtimeAgent
StateModule 用法示例
首先我们将通过一些简单例子来展示 StateModule 的用法,以对 StateModule 有个直观的认识。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 from collections import OrderedDictfrom agentscope.module import StateModuleclass Counter (StateModule ): def __init__ (self ): super ().__init__() self.count = 0 self.temp = "not tracked" self.register_state("count" ) c = Counter() c.count = 100 c.temp = "new value" state = c.state_dict() new_c = Counter() new_c.load_state_dict(state) print (f"【1】state={state} " )print (f" count={new_c.count} (registered, restored)" )print (f" temp='{new_c.temp} ' (not registered, NOT restored)" )class Memory (StateModule ): def __init__ (self ): super ().__init__() self.msgs = [] self.register_state("msgs" ) class Agent (StateModule ): def __init__ (self ): super ().__init__() self.memory = Memory() agent = Agent() agent.memory.msgs.append("hello" ) state = agent.state_dict() new_agent = Agent() new_agent.load_state_dict(state) print (f"【2】state={state} " )print (f" restored: msgs={new_agent.memory.msgs} " )class User (StateModule ): def __init__ (self ): super ().__init__() self.prefs = OrderedDict() self.register_state( "prefs" , custom_to_json=lambda x: dict (x), custom_from_json=lambda x: OrderedDict(x), ) user = User() user.prefs["lang" ] = "zh" state = user.state_dict() new_user = User() new_user.load_state_dict(state) print (f"【3】state={state} " )print ( f" restored: prefs={dict (new_user.prefs)} , type={type (new_user.prefs).__name__} " )
1 2 3 4 5 6 7 8 【1】state={'count' : 100} count=100 (registered, restored) temp='not tracked' (not registered, NOT restored) 【2】state={'memory' : {'msgs' : ['hello' ]}} restored: msgs=['hello' ] 【3】state={'prefs' : {'lang' : 'zh' }} restored: prefs={'lang' : 'zh' }, type =OrderedDict
对于继承自 StateModule 的类,通过 state_dict() 获取该对象的所有 state 信息(序列化)、而通过 load_state_dict 从 state 信息中恢复该对象的状态(反序列化)
在第一个例子中,对于 Counter 类,想要将 counter 属性作为一个可持久化的状态,必须通过 register_state("counter"),而 temp 属性没有被注册为状态,因此不会被状态跟踪
对于第二个例子,对于 Agent 类型,其继承自 StateModule 类,因此可以对 Agent 类型进行状态跟踪。而其内部的 memory 成员,由于其类型 Memory 本身继承自 StateModule,因此 memory 自动就会被状态跟踪,无需再调用 register_state
第三个例子则展示了,我们可以在 register_state() 中自定义序列化和反序列化方法
这个例子展示了我们所说的 StateModule 的两个关键能力:即
自动状态追踪:自动追踪继承自 StateModule 的子属性
手动状态注册:手动通过 register_state() 注册普通属性,并可以提供自定义序列化方法
接下来我们这个例子,则展示了 StateModule 的 嵌套状态管理,支持多层嵌套对象的状态序列化和恢复 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 from agentscope.module import StateModuleclass ToolHistory (StateModule ): def __init__ (self ): super ().__init__() self.calls = [] self.register_state("calls" ) class ToolKit (StateModule ): def __init__ (self ): super ().__init__() self.history = ToolHistory() class Agent (StateModule ): def __init__ (self, name: str ): super ().__init__() self.name = name self.register_state("name" ) self.toolkit = ToolKit() agent = Agent("Assistant" ) agent.toolkit.history.calls.append({"tool" : "search" , "args" : {"q" : "test" }}) new_agent = Agent("temp" ) new_agent.load_state_dict(agent.state_dict()) print (f"【Nested】state={state} " )print (f" name={new_agent.name} " )print (f" calls={new_agent.toolkit.history.calls} " )
1 2 3 4 5 【Nested】state={'toolkit' : {'history' : {'calls' : [{'tool' : 'search' , 'args' : {'q' : 'test' }}]}}, 'name' : 'Assistant' } name=Assistant calls=[{'tool' : 'search' , 'args' : {'q' : 'test' }}]
StateModule 的实现
接下来我们再来看 StateModule 是如何实现的。StateModule 的类型定义如下:
1 2 3 4 5 6 7 8 9 10 11 12 class StateModule : def __init__ (self ): self._module_dict = OrderedDict() self._attribute_dict = OrderedDict() def __setattr__ (self, key: str , value: Any ) -> None : if isinstance (value, StateModule): self._module_dict[key] = value super ().__setattr__(key, value)
通过 _module_dict 来保存所有本身是 StateModule 的属性,而通过 _attribute_dict 来保存所有手动注册的普通属性
上文说过,对于本身就是 StateModule 类型的属性会自动进行状态跟踪,就是依靠 __setattr__ 方法,在设置属性时,如果发现是 StateModule 类型,则将其添加到 _module_dict 中
register_state() 方法用于手动状态注册:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 def register_state ( self, attr_name: str , custom_to_json: Callable [[Any ], JSONSerializableObject] | None = None , custom_from_json: Callable [[JSONSerializableObject], Any ] | None = None , ) -> None : attr = getattr (self, attr_name) if custom_to_json is None : try : json.dumps(attr) except Exception as e: raise TypeError( f"Attribute '{attr_name} ' is not JSON serializable. " "Please provide a custom function to convert the " "attribute to a JSON-serializable format." , ) from e if attr_name in self._module_dict: raise ValueError( f"Attribute `{attr_name} ` is already registered as a module. " , ) self._attribute_dict[attr_name] = _JSONSerializeFunction( to_json=custom_to_json, load_json=custom_from_json, )
state_dict() 方法则以字典的形式,返回该类型的状态信息,它的核心逻辑是:对于 _module_dict 中的子模块,递归调用 state_dict() 方法获取其状态信息;而对于 _attribute_dict 中的普通属性,根据是否有自定义的序列化方法,选择使用自定义方法或直接返回:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 def state_dict (self ) -> dict : state = {} for key in self._module_dict: attr = getattr (self, key) state[key] = attr.state_dict() for key in self._attribute_dict: attr = getattr (self, key) to_json = self._attribute_dict[key].to_json state[key] = to_json(attr) if to_json else attr return state
load_state_dict() 方法则从状态字典中恢复出原始的类型对象:
1 2 3 4 5 6 7 8 9 10 def load_state_dict (self, state_dict: dict , strict: bool = True ): for key in self._module_dict: self._module_dict[key].load_state_dict(state_dict[key]) for key in self._attribute_dict: from_json = self._attribute_dict[key].load_json value = from_json(state_dict[key]) if from_json else state_dict[key] setattr (self, key, value)
StateModule 为自定义类型提供了状态的控制能力(即控制哪些属性需要作为状态进行跟踪),但它只是提供了序列化(state_dict)和反序列化(load_state_dict)的基本功能,为了将状态持久化,还需要依赖 AgentScope 所提供的 Session 模块。关于 Session 模块,我们在后续文章继续介绍。
分析完 StateModule 实现之后,我们再来分析 AgentBase 类申明中另一个值得注意点,即 AgentBase 将 _AgentMeta 作为元类:
1 2 class AgentBase (StateModule, metaclass=_AgentMeta): ......
Python 中的元类是 构建类的类,负责控制 类的创建和行为。普通类用来指导构建实例对象,而元类则是用来指导如何构建类本身(Python 中类本身也是一种对象)。关于 Python 元类的基础知识,可以参考 流畅的 Python 第 2 版(24):类元编程 。
_AgentMeta 元类的定义如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 class _AgentMeta (type ): def __new__ (mcs, name: Any , bases: Any , attrs: Dict ) -> Any : """Wrap the agent's functions with hooks.""" for func_name in [ "reply" , "print" , "observe" , ]: if func_name in attrs: attrs[func_name] = _wrap_with_hooks(attrs[func_name]) return super ().__new__(mcs, name, bases, attrs)
_AgentMeta 是一个元类,因为它集成自 type,因此具有创建类对象的能力,因此它是一个元类
__new__ 方法定义了如何创建类对象,即这里的 AgentBase 类
__new__ 方法中检查了 AgentBase 类中的三个属性:reply、print、observe,为这些属性添加 hooks 逻辑
简单来说,_AgentMeta 就是为 AgentBase 类及其子类自动添加了一些钩子逻辑。因为 AgentBase 类在构建时会执行 _AgentMeta.__new__,而 __new__ 方法则检查这些创建的类中是否有 reply、print、observe 这三个方法,如果有,则将这些方法包装成带有钩子逻辑的新方法。
_wrap_with_hooks 的代码虽然有些多,但是其核心逻辑就是根据当前 hook 的函数名,判断 AgentBase 类中是否定义了 _pre_{method.__name__}_hooks 和 _post_{method.__name__}_hooks 字典。如果存在,则会从这些字典中取出相关的 hooks 方法,并在调用原始方法的前后,分别执行这些 hooks 函数。hooks 方法支持实例方法和类方法:
1 2 3 4 5 6 assert ( hasattr (self, f"_instance_pre_{func_name} _hooks" ) and hasattr (self, f"_instance_post_{func_name} _hooks" ) and hasattr (self.__class__, f"_class_pre_{func_name} _hooks" ) and hasattr (self.__class__, f"_class_post_{func_name} _hooks" ) ), f"Hooks for {func_name} not found in {self.__class__.__name__} "
_warp_with_hooks 方法的简化逻辑如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 def _wrap_with_hooks (original_func ): """装饰器:在函数前后执行 hooks""" func_name = original_func.__name__.replace("_" , "" ) @wraps(original_func ) async def async_wrapper (self, *args, **kwargs ): normalized_kwargs = _normalize_to_kwargs(original_func, self, *args, **kwargs) pre_hooks = list (self._instance_pre_hooks.values()) + \ list (self.__class__._class_pre_hooks.values()) for pre_hook in pre_hooks: modified = await pre_hook(self, deepcopy(normalized_kwargs)) if modified: normalized_kwargs = modified output = await original_func(self, **normalized_kwargs) post_hooks = list (self._instance_post_hooks.values()) + \ list (self.__class__._class_post_hooks.values()) for post_hook in post_hooks: modified = await post_hook(self, deepcopy(normalized_kwargs), output) if modified: output = modified return output return async_wrapper
所以总结一下,_AgentMeta 实现了 hook 约定:
如果你的类中定义了相关的字典(例如 _instance_pre_{func_name}_hooks 等)来保存 hooks 函数,则 _AgentMeta 可以帮你自动地在执行相关方法时,完成相关 hooks 的调用
_AgentMeta 支持为 reply、print、observe 这些方法添加 hooks 逻辑,hook 执行时机可以是 pre 或者 post
hook 函数是实例方法或者类方法
AgentBase
分析完 AgentBase 定义中所指定的基类 StateModule 以及其元类 metaclass,接下来我们再来看 AgentBase类的主体实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 class AgentBase (StateModule, metaclass=_AgentMeta): """Base class for asynchronous agents.""" id : str supported_hook_types: list [str ] = [ "pre_reply" , "post_reply" , "pre_print" , "post_print" , "pre_observe" , "post_observe" , ] _class_pre_reply_hooks: dict = OrderedDict() _class_post_reply_hooks: dict = OrderedDict() _class_pre_print_hooks: dict = OrderedDict() _class_post_print_hooks: dict = OrderedDict() _class_pre_observe_hooks: dict = OrderedDict() _class_post_observe_hooks: dict = OrderedDict() def __init__ (self ) -> None : """Initialize the agent.""" super ().__init__() self.id = shortuuid.uuid() self._reply_task: Task | None = None self._reply_id: str | None = None self._instance_pre_print_hooks = OrderedDict() self._instance_post_print_hooks = OrderedDict() self._instance_pre_reply_hooks = OrderedDict() self._instance_post_reply_hooks = OrderedDict() self._instance_pre_observe_hooks = OrderedDict() self._instance_post_observe_hooks = OrderedDict() self._stream_prefix = {} self._subscribers: dict [str , list [AgentBase]] = {} self._disable_console_output: bool = ( os.getenv( "AGENTSCOPE_DISABLE_CONSOLE_OUTPUT" , "false" , ).lower() == "true" ) self._disable_msg_queue: bool = True self.msg_queue = None
每个 Agent 都会有一个 id
使用字典保存类级别和实例级别的 hooks 函数,这块和我们对 _AgentMeta 的分析是一致的
AgentBase 定义了如下方法:
方法
作用
实现状态
reply()
生成回复
抽象方法,子类实现
observe()
观察消息
抽象方法,子类实现
print()
打印消息
已实现,支持流式输出
__call__()
调用入口
已实现,处理中断和广播
interrupt()
中断执行
已实现
handle_interrupt()
中断处理
抽象方法
AgentBase 实现了 __call__() 方法,因此可以直接调用 AgentBase 的实例,以处理用户的输入消息:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 async def __call__ (self, *args, **kwargs ) -> Msg: self._reply_id = shortuuid.uuid() reply_msg = None try : self._reply_task = asyncio.current_task() reply_msg = await self.reply(*args, **kwargs) except asyncio.CancelledError: reply_msg = await self.handle_interrupt(*args, **kwargs) finally : if reply_msg: await self._broadcast_to_subscribers(reply_msg) self._reply_task = None return reply_msg
__call__ 和核心逻辑是定义了处理输入消息的总体流程,包括调用 reply() 方法来处理消息,调用 handle_interrupt() 来处理中断,并将回复的消息广播给所有订阅者
reply()、handle_interrupt() 都是抽象方法,需要子类实现,这些才是 Agent 的核心业务逻辑
AgentBase 支持 智能体回复消息 的广播机制,用于将回复消息广播给其他智能体,每个智能体都需要实现 observe() 方法来接收并处理消息
1 2 3 4 5 6 7 8 9 async def _broadcast_to_subscribers ( self, msg: Msg | list [Msg] | None , ) -> None : """Broadcast the message to all subscribers.""" for subscribers in self._subscribers.values(): for subscriber in subscribers: await subscriber.observe(msg)
reset_subscribers() 和 reset_subscribers() 方法用于管理当前智能体的订阅者:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 def reset_subscribers ( self, msghub_name: str , subscribers: list ["AgentBase" ], ) -> None : self._subscribers[msghub_name] = [_ for _ in subscribers if _ != self] def remove_subscribers (self, msghub_name: str ) -> None : if msghub_name not in self._subscribers: logger.warning( "MsgHub named '%s' not found" , msghub_name, ) else : self._subscribers.pop(msghub_name)
AgentBase 提供了 print() 方法,用来流式输出 Agent 执行过程中的消息,其核心逻辑如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 async def print (self, msg: Msg, last: bool = True , speech: AudioBlock = None ): """支持流式打印消息""" if not self._disable_msg_queue: await self.msg_queue.put((deepcopy(msg), last, speech)) await asyncio.sleep(0 ) if not self._disable_console_output: for block in msg.get_content_blocks(): if block["type" ] == "text" : self._print_text_block(msg.id , block["text" ]) elif block["type" ] == "thinking" : self._print_text_block(msg.id , block["thinking" ]) if speech: self._process_audio_block(msg.id , speech) if last and msg.id in self._stream_prefix: if "audio" in self._stream_prefix[msg.id ]: player, _ = self._stream_prefix[msg.id ]["audio" ] player.close() stream_prefix = self._stream_prefix.pop(msg.id ) if "text" in stream_prefix and not stream_prefix["text" ].endswith("\n" ): print ()
print() 的函数入口处,支持将 Msg 额外保存到 self.msg_queue,这样允许其他组件消费这些消息。AgentBase.set_msg_queue_enabled() 方法允许设置所使用的 msg_queue,默认为空。
print() 的一个特点是支持流式增量输出,其通过 _stream_prefix 记录每个消息已经输出内容:
1 2 3 4 5 6 self._stream_prefix = { msg_id: { "text" : "已打印的文本内容" , "audio" : (player_object, "已播放的base64数据" ), } }
之后只会输出增量部分,我们可以通过 _print_text_block() 查看其如何实现:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 def _print_text_block ( self, msg_id: str , name_prefix: str , text_content: str , thinking_and_text_to_print: list [str ], ) -> None : thinking_and_text_to_print.append( f"{name_prefix} : {text_content} " , ) to_print = "\n" .join(thinking_and_text_to_print) if msg_id not in self._stream_prefix: self._stream_prefix[msg_id] = {} text_prefix = self._stream_prefix[msg_id].get("text" , "" ) if len (to_print) > len (text_prefix): print (to_print[len (text_prefix) :], end="" ) self._stream_prefix[msg_id]["text" ] = to_print
_process_audio_block() 方法则用于处理音频消息的输出,可以处理 URL 类型和 Base64 字符串类型的音频数据。这里就不再展开其代码。
AgentBase 还提供了几个工具函数,用来简化 Hook 的注册:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 def register_instance_hook ( self, hook_type: AgentHookTypes, hook_name: str , hook: Callable , ) -> None : if not isinstance (self, AgentBase): raise TypeError( "The register_instance_hook method should be called on an " f"instance of AsyncAgentBase, but got {self} of " f"type {type (self)} ." , ) hooks = getattr (self, f"_instance_{hook_type} _hooks" ) hooks[hook_name] = hook def remove_instance_hook ( self, hook_type: AgentHookTypes, hook_name: str , ) -> None : if not isinstance (self, AgentBase): raise TypeError( "The remove_instance_hook method should be called on an " f"instance of AsyncAgentBase, but got {self} of " f"type {type (self)} ." , ) hooks = getattr (self, f"_instance_{hook_type} _hooks" ) if hook_name in hooks: del hooks[hook_name] else : raise ValueError( f"Hook '{hook_name} ' not found in '{hook_type} ' hooks of " f"{self.__class__.__name__} instance." , ) @classmethod def register_class_hook ( cls, hook_type: AgentHookTypes, hook_name: str , hook: Callable , ) -> None : assert ( hook_type in cls.supported_hook_types ), f"Invalid hook type: {hook_type} " hooks = getattr (cls, f"_class_{hook_type} _hooks" ) hooks[hook_name] = hook @classmethod def remove_class_hook ( cls, hook_type: AgentHookTypes, hook_name: str , ) -> None : ...... @classmethod def clear_class_hooks ( cls, hook_type: AgentHookTypes | None = None , ) -> None : ...... def clear_instance_hooks ( self, hook_type: AgentHookTypes | None = None , ) -> None : ......
以上我们就完成了对 AgentBase 类的主体介绍,主要是其 __call__() 定义了 Agent 的最基本的流程,包括 reply()、handle_interrupt()、observe()。另外提供了 print() 方法来流式输出消息。
小结
这篇文章我们学习了 AgentBase 类,它是所有 Agent 实现的基类,它本身继承自 StateModule 以支持状态信息管理,并使用 _AgentMeta 作为元类以支持 Hooks 的执行,另外详细介绍了 AgentBase 本身的实现,包括各种 hooks 的管理、__call__ 方法和 print 方法等。