2026-01-07 11:02:05 +08:00
|
|
|
|
"""
|
|
|
|
|
|
辩论管理器 - 编排多 Agent 辩论流程
|
|
|
|
|
|
"""
|
|
|
|
|
|
from dataclasses import dataclass
from typing import Callable, Generator, List, Optional

from agents.base_agent import BaseAgent
from agents.agent_profiles import get_agent_profile
from utils.llm_client import LLMClient
import config
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DebateConfig:
    """Configuration for a single debate session.

    Consumed by ``DebateManager.setup_debate`` (agent creation) and
    ``DebateManager.run_debate_stream`` (topic, context, number of rounds).
    """

    # The question / motion that every agent is asked to debate.
    topic: str
    # Extra background passed to every agent together with the topic.
    context: str = ""
    # Ids of the participating agents; agents are created (and speak) in this
    # order. Defaults to None, so supply a list before setting up a debate.
    agent_ids: Optional[List[str]] = None
    # Number of rounds; each agent speaks once per round.
    max_rounds: int = 2
    # Optional per-agent client overrides: agent_id -> LLMClient. Agents
    # without an entry use the DebateManager's default client.
    agent_clients: Optional[dict] = None
    # Language value forwarded to every BaseAgent.
    language: str = "Chinese"
|
2026-01-07 11:02:05 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass()
class SpeechRecord:
    """One finished speech delivered by an agent during a debate."""

    agent_id: str  # id of the agent that spoke
    agent_name: str  # display name of that agent
    emoji: str  # emoji associated with that agent
    content: str  # complete text of the speech
    round_num: int  # round in which the speech was delivered (starts at 1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DebateManager:
    """Orchestrates a multi-agent debate.

    Call ``setup_debate`` with a ``DebateConfig`` to create the participating
    agents, then iterate ``run_debate_stream`` to drive the debate and
    receive progress events.
    """

    def __init__(self, llm_client: Optional[LLMClient] = None):
        """Initialize the debate manager.

        Args:
            llm_client: Default LLM client for agents that have no per-agent
                override in ``DebateConfig.agent_clients``. A fresh
                ``LLMClient()`` is created when omitted.
        """
        self.llm_client = llm_client or LLMClient()
        self.agents: List[BaseAgent] = []
        self.speech_records: List[SpeechRecord] = []
        self.current_round = 0
        # Set by setup_debate(); run_debate_stream() refuses to start without it.
        self.config = None

    def setup_debate(self, debate_config: DebateConfig) -> None:
        """Prepare a new debate, discarding agents and speeches from any previous one.

        One ``BaseAgent`` is created per id in ``debate_config.agent_ids``, in
        that order. Each agent uses the client registered for its id in
        ``debate_config.agent_clients`` if there is one, otherwise the
        manager's default ``llm_client``.

        Args:
            debate_config: Debate configuration. ``agent_ids=None`` (the
                dataclass default) is treated as an empty list.
        """
        self.config = debate_config
        self.agents = []
        self.speech_records = []
        self.current_round = 0

        # getattr keeps duck-typed config objects without this attribute working.
        client_overrides = getattr(debate_config, "agent_clients", None) or {}
        for agent_id in debate_config.agent_ids or []:
            client = client_overrides.get(agent_id, self.llm_client)
            self.agents.append(
                BaseAgent(agent_id, client, language=debate_config.language)
            )

    def _previous_speeches_for(self, agent_id: str) -> List[dict]:
        """Return earlier speeches by *other* agents, oldest first, as prompt dicts."""
        # The speaker's own earlier speeches are deliberately excluded.
        return [
            {"name": r.agent_name, "emoji": r.emoji, "content": r.content}
            for r in self.speech_records
            if r.agent_id != agent_id
        ]

    def run_debate_stream(
        self,
        on_speech_start: Optional[Callable] = None,
        on_speech_chunk: Optional[Callable] = None,
        on_speech_end: Optional[Callable] = None,
        on_round_end: Optional[Callable] = None,
    ) -> Generator[dict, None, None]:
        """Run the debate, yielding progress events as they happen.

        Each agent speaks once per round, in the order created by
        ``setup_debate``, for ``config.max_rounds`` rounds. Every event is a
        dict with a ``"type"`` key:

        * ``round_start``  -- ``round``, ``total_rounds``
        * ``speech_start`` -- ``agent_id``, ``agent_name``, ``emoji``,
          ``model_name``, ``round``
        * ``speech_chunk`` -- ``agent_id``, ``chunk`` (one streamed piece)
        * ``speech_end``   -- ``agent_id``, ``content`` (the full speech)
        * ``round_end``    -- ``round``
        * ``debate_end``   -- emitted once, after the last round

        Each finished speech is also appended to ``self.speech_records``.

        Args:
            on_speech_start: Speech-start callback.
            on_speech_chunk: Speech-chunk callback.
            on_speech_end: Speech-end callback.
            on_round_end: Round-end callback.

        NOTE(review): the callback parameters are accepted for API
        compatibility but are currently never invoked; consume the yielded
        events instead.

        Yields:
            dict: Event information, as described above.

        Raises:
            RuntimeError: If ``setup_debate`` has not been called. Because this
                is a generator, it is raised on the first ``next()``.
        """
        if self.config is None:
            raise RuntimeError(
                "setup_debate() must be called before run_debate_stream()"
            )

        for round_num in range(1, self.config.max_rounds + 1):
            self.current_round = round_num
            yield {
                "type": "round_start",
                "round": round_num,
                "total_rounds": self.config.max_rounds,
            }

            for agent in self.agents:
                previous_speeches = self._previous_speeches_for(agent.agent_id)

                yield {
                    "type": "speech_start",
                    "agent_id": agent.agent_id,
                    "agent_name": agent.name,
                    "emoji": agent.emoji,
                    "model_name": agent.model_name,
                    "round": round_num,
                }

                # Stream the speech; collect chunks and join once at the end.
                chunks: List[str] = []
                for chunk in agent.generate_response(
                    topic=self.config.topic,
                    context=self.config.context,
                    previous_speeches=previous_speeches,
                    round_num=round_num,
                ):
                    chunks.append(chunk)
                    yield {
                        "type": "speech_chunk",
                        "agent_id": agent.agent_id,
                        "chunk": chunk,
                    }
                full_content = "".join(chunks)

                self.speech_records.append(
                    SpeechRecord(
                        agent_id=agent.agent_id,
                        agent_name=agent.name,
                        emoji=agent.emoji,
                        content=full_content,
                        round_num=round_num,
                    )
                )

                yield {
                    "type": "speech_end",
                    "agent_id": agent.agent_id,
                    "content": full_content,
                }

            yield {
                "type": "round_end",
                "round": round_num,
            }

        yield {"type": "debate_end"}

    def get_all_speeches(self) -> List[SpeechRecord]:
        """Return all speech records so far (the internal list, not a copy)."""
        return self.speech_records

    def get_speeches_by_round(self, round_num: int) -> List[SpeechRecord]:
        """Return the speeches delivered in round ``round_num`` (rounds start at 1)."""
        return [r for r in self.speech_records if r.round_num == round_num]

    def get_speeches_by_agent(self, agent_id: str) -> List[SpeechRecord]:
        """Return every speech delivered by the agent with id ``agent_id``."""
        return [r for r in self.speech_records if r.agent_id == agent_id]
|