wd666/orchestrator/debate_manager.py

169 lines
5.3 KiB
Python
Raw Normal View History

2026-01-07 11:02:05 +08:00
"""
辩论管理器 - 编排多 Agent 辩论流程
"""
from dataclasses import dataclass
from typing import Callable, Dict, Generator, List, Optional

import config
from agents.agent_profiles import get_agent_profile
from agents.base_agent import BaseAgent
from utils.llm_client import LLMClient


@dataclass
class DebateConfig:
    """Settings for a single debate run by ``DebateManager``.

    Attributes:
        topic: The question or motion being debated.
        context: Optional background text passed to every agent.
        agent_ids: Ids of the participating agents, in speaking order.
            Defaults to None; ``DebateManager.setup_debate`` needs a list.
        max_rounds: Number of rounds; every agent speaks once per round.
        agent_clients: Optional per-agent LLM client overrides keyed by agent
            id; agents without an entry use the manager's default client.
        language: Language name forwarded to each ``BaseAgent``.
    """

    topic: str
    context: str = ""
    # Optional[...] because the default is None: a bare ``List[str] = None``
    # is an implicit Optional, which PEP 484 type checkers reject.
    agent_ids: Optional[List[str]] = None
    max_rounds: int = 2
    # Map of agent_id -> LLMClient (string forward reference keeps the
    # annotation valid without evaluating the client type).
    agent_clients: Optional[Dict[str, "LLMClient"]] = None
    language: str = "Chinese"
@dataclass
class SpeechRecord:
    """One finished speech as kept in the debate history.

    Positional field order: agent_id, agent_name, emoji, content, round_num.
    """

    # Id of the agent that gave the speech.
    agent_id: str
    # Display name of that agent.
    agent_name: str
    # Emoji shown alongside the agent's name.
    emoji: str
    # Complete speech text (all streamed chunks joined together).
    content: str
    # Round in which the speech was given (rounds are numbered from 1).
    round_num: int
class DebateManager:
    """Orchestrates a multi-agent debate and streams its progress as events.

    Usage: construct the manager, call ``setup_debate`` with a
    ``DebateConfig``, then iterate ``run_debate_stream()`` to receive event
    dicts (``round_start``, ``speech_start``, ``speech_chunk``,
    ``speech_end``, ``round_end``, ``debate_end``).
    """

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """
        Initialize the debate manager.

        Args:
            llm_client: Default LLM client for agents without a dedicated
                client in ``DebateConfig.agent_clients``. A new
                ``LLMClient()`` is created when omitted.
        """
        self.llm_client = llm_client if llm_client is not None else LLMClient()
        # Set by setup_debate(); None means no debate has been configured yet.
        self.config: Optional[DebateConfig] = None
        self.agents: List[BaseAgent] = []
        self.speech_records: List[SpeechRecord] = []
        self.current_round = 0

    def setup_debate(self, debate_config: DebateConfig) -> None:
        """
        Prepare a new debate, discarding any state from a previous one.

        One ``BaseAgent`` is created per id in ``debate_config.agent_ids``.
        Each agent uses its entry in ``debate_config.agent_clients`` when
        present, otherwise the manager's default ``llm_client``.

        Args:
            debate_config: Debate settings (topic, participants, rounds, ...).

        Raises:
            TypeError: If ``debate_config.agent_ids`` is None.
        """
        if debate_config.agent_ids is None:
            # DebateConfig defaults agent_ids to None; fail with a clear
            # message instead of "'NoneType' object is not iterable".
            raise TypeError(
                "debate_config.agent_ids must be a list of agent ids, got None"
            )

        # getattr keeps duck-typed configs without agent_clients working.
        agent_clients = getattr(debate_config, "agent_clients", None) or {}

        # Build the roster before touching any state, so a failing agent
        # constructor leaves the previous debate intact.
        agents = [
            BaseAgent(
                agent_id,
                agent_clients.get(agent_id, self.llm_client),
                language=debate_config.language,
            )
            for agent_id in debate_config.agent_ids
        ]

        self.config = debate_config
        self.agents = agents
        self.speech_records = []
        self.current_round = 0

    def run_debate_stream(
        self,
        on_speech_start: Optional[Callable] = None,
        on_speech_chunk: Optional[Callable] = None,
        on_speech_end: Optional[Callable] = None,
        on_round_end: Optional[Callable] = None,
    ) -> Generator[dict, None, None]:
        """
        Run the configured debate, streaming progress as event dicts.

        Every agent speaks once per round, in ``self.agents`` order, and is
        shown all earlier speeches by *other* agents (including ones given
        earlier in the current round).

        Args:
            on_speech_start: Optional callback for ``speech_start`` events.
            on_speech_chunk: Optional callback for ``speech_chunk`` events.
            on_speech_end: Optional callback for ``speech_end`` events.
            on_round_end: Optional callback for ``round_end`` events.
            Each callback, when given, is called with the same event dict
            immediately before that dict is yielded.

        Yields:
            dict: Events keyed by ``"type"``. Per round: ``round_start``, then
            for each agent ``speech_start``, zero or more ``speech_chunk`` and
            ``speech_end``, then ``round_end``. A final ``debate_end`` closes
            the stream.

        Raises:
            RuntimeError: On first iteration if ``setup_debate`` was not called.
        """
        debate_config = getattr(self, "config", None)
        if debate_config is None:
            raise RuntimeError(
                "setup_debate() must be called before run_debate_stream()"
            )

        total_rounds = debate_config.max_rounds
        for round_num in range(1, total_rounds + 1):
            self.current_round = round_num
            yield {
                "type": "round_start",
                "round": round_num,
                "total_rounds": total_rounds,
            }

            for agent in self.agents:
                # Earlier speeches by everyone except this agent.
                previous_speeches = self._previous_speeches_for(agent.agent_id)

                start_event = {
                    "type": "speech_start",
                    "agent_id": agent.agent_id,
                    "agent_name": agent.name,
                    "emoji": agent.emoji,
                    "model_name": agent.model_name,
                    "round": round_num,
                }
                self._notify(on_speech_start, start_event)
                yield start_event

                # Stream the speech; collect parts and join once at the end
                # instead of repeated string concatenation.
                parts: List[str] = []
                for chunk in agent.generate_response(
                    topic=debate_config.topic,
                    context=debate_config.context,
                    previous_speeches=previous_speeches,
                    round_num=round_num,
                ):
                    parts.append(chunk)
                    chunk_event = {
                        "type": "speech_chunk",
                        "agent_id": agent.agent_id,
                        "chunk": chunk,
                    }
                    self._notify(on_speech_chunk, chunk_event)
                    yield chunk_event
                full_content = "".join(parts)

                # Record the finished speech so later speakers can see it.
                self.speech_records.append(
                    SpeechRecord(
                        agent_id=agent.agent_id,
                        agent_name=agent.name,
                        emoji=agent.emoji,
                        content=full_content,
                        round_num=round_num,
                    )
                )

                end_event = {
                    "type": "speech_end",
                    "agent_id": agent.agent_id,
                    "content": full_content,
                }
                self._notify(on_speech_end, end_event)
                yield end_event

            round_end_event = {"type": "round_end", "round": round_num}
            self._notify(on_round_end, round_end_event)
            yield round_end_event

        yield {"type": "debate_end"}

    def _previous_speeches_for(self, agent_id: str) -> List[dict]:
        """Return recorded speeches by agents other than ``agent_id``, oldest first."""
        return [
            {"name": r.agent_name, "emoji": r.emoji, "content": r.content}
            for r in self.speech_records
            if r.agent_id != agent_id
        ]

    @staticmethod
    def _notify(callback: Optional[Callable], event: dict) -> None:
        """Call ``callback(event)`` if a callback was supplied."""
        if callback is not None:
            callback(event)

    def get_all_speeches(self) -> List[SpeechRecord]:
        """Return a copy of every recorded speech, in speaking order."""
        # A copy, so callers cannot corrupt the manager's history.
        return list(self.speech_records)

    def get_speeches_by_round(self, round_num: int) -> List[SpeechRecord]:
        """Return the speeches given in round ``round_num`` (rounds start at 1)."""
        return [r for r in self.speech_records if r.round_num == round_num]

    def get_speeches_by_agent(self, agent_id: str) -> List[SpeechRecord]:
        """Return every speech given by the agent with id ``agent_id``."""
        return [r for r in self.speech_records if r.agent_id == agent_id]