2311061111-lyt/autograde/llm_grade.py
liyitian 83cd133001
All checks were successful
autograde-final-vibevault / check-trigger (push) Successful in 13s
autograde-final-vibevault / grade (push) Has been skipped
add autograde tests
2025-12-14 18:19:41 +08:00

211 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
llm_grade.py - 使用LLM对报告进行评分
"""
import json
import argparse
import os
import requests
import time
from typing import Dict, Any
def parse_args():
    """Parse command-line arguments; every option is mandatory."""
    parser = argparse.ArgumentParser(description='LLM Report Grading Script')
    # Flag / help-text pairs, registered in a single pass.
    required_options = (
        ('--question', '评分问题描述'),
        ('--answer', '待评分的答案文件'),
        ('--rubric', '评分标准文件'),
        ('--out', '输出评分结果文件'),
        ('--summary', '输出评分摘要文件'),
    )
    for flag, help_text in required_options:
        parser.add_argument(flag, required=True, help=help_text)
    return parser.parse_args()
def load_file_content(file_path: str) -> str:
    """Return the entire UTF-8 text content of *file_path*."""
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()
def load_rubric(rubric_path: str) -> Dict[str, Any]:
    """Parse the grading-rubric JSON file and return it as a dict."""
    with open(rubric_path, encoding='utf-8') as handle:
        return json.load(handle)
def call_llm_api(prompt: str, max_retries: int = 3, timeout: int = 30) -> str:
    """Send *prompt* to the configured LLM endpoint and return its text reply.

    Configuration is read from environment variables:
      LLM_API_KEY - optional bearer token (header only added when non-empty)
      LLM_API_URL - endpoint URL (defaults to a local Ollama 'generate' API)
      LLM_MODEL   - model name (defaults to 'llama3')

    Retries up to *max_retries* times with exponential backoff (1s, 2s, 4s, ...);
    the final failure is re-raised to the caller.
    """
    api_key = os.environ.get('LLM_API_KEY', '')
    api_url = os.environ.get('LLM_API_URL', 'http://localhost:11434/api/generate')
    model = os.environ.get('LLM_MODEL', 'llama3')

    headers = {'Content-Type': 'application/json'}
    if api_key:
        headers['Authorization'] = f'Bearer {api_key}'

    payload = {
        'model': model,
        'prompt': prompt,
        'stream': False,
        'temperature': 0.3,
    }

    last_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            response = requests.post(api_url, json=payload, headers=headers, timeout=timeout)
            response.raise_for_status()
            # Ollama-style responses carry the generated text under 'response'.
            return response.json().get('response', '')
        except requests.exceptions.RequestException as e:
            print(f"⚠️ LLM API调用失败 (尝试 {attempt + 1}/{max_retries}): {e}")
            if attempt == last_attempt:
                raise
            time.sleep(2 ** attempt)  # exponential backoff
def generate_grading_prompt(question: str, answer: str, rubric: Dict[str, Any]) -> str:
    """Build the grading prompt sent to the LLM.

    Args:
        question: the grading question / task description.
        answer: the student's answer text.
        rubric: rubric dict, serialized into the prompt as pretty JSON.

    Returns:
        The full prompt string, including an output-format JSON template.
    """
    # BUGFIX: the literal JSON template below must use doubled braces
    # ({{ and }}) inside this f-string; single braces were parsed as
    # replacement fields and made the f-string invalid.
    prompt = f"""你是一位专业的课程作业评分专家。请根据以下评分标准,对学生的作业进行客观、公正的评分。
## 评分问题
{question}
## 学生答案
{answer}
## 评分标准
{json.dumps(rubric, ensure_ascii=False, indent=2)}
## 评分要求
1. 严格按照评分标准进行评分,每个评分项给出具体得分
2. 详细说明每个评分项的得分理由
3. 给出总体评价和建议
4. 最终输出必须包含JSON格式的评分结果格式如下
```json
{{
"total": 总分,
"scores": {{
"评分项1": 得分,
"评分项2": 得分,
...
}},
"feedback": "详细的评分反馈和建议"
}}
```
请确保输出格式正确只包含上述JSON格式内容不要添加任何其他说明。"""
    return prompt
def parse_llm_response(response: str) -> Dict[str, Any]:
    """Extract the grading JSON from an LLM reply.

    Tries, in order:
      1. a ```json ... ``` fenced block — the regex now tolerates CRLF line
         endings and extra whitespace around the fence (the original required
         an exact '\\n' on both sides and silently missed such replies);
      2. the whole response parsed as raw JSON;
      3. a zero-score fallback dict so callers never crash on bad output.
    """
    import re  # kept function-local, as in the original layout

    json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
    if json_match:
        try:
            return json.loads(json_match.group(1))
        except json.JSONDecodeError:
            print("⚠️ LLM响应中的JSON格式错误")
    # Fenced block missing or malformed: try the raw response.
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        print("⚠️ LLM响应不是有效的JSON格式")
    # Nothing parsed: return a safe default.
    return {
        'total': 0.0,
        'scores': {},
        'feedback': '评分失败无法解析LLM响应'
    }
def generate_summary(grade_result: Dict[str, Any], rubric: Dict[str, Any]) -> str:
    """Render the grading result as a Markdown report.

    Args:
        grade_result: dict with optional 'total', 'scores', 'feedback' keys.
            ROBUSTNESS FIX: keys are read with .get() defaults — LLM-produced
            JSON may parse successfully yet lack any of them, and the original
            direct indexing raised KeyError in that case.
        rubric: rubric dict; 'criteria' maps criterion -> full score and
            'descriptions' maps criterion -> description text.

    Returns:
        The complete Markdown summary string.
    """
    total = grade_result.get('total', 0.0)
    scores = grade_result.get('scores', {})
    criteria = rubric.get('criteria', {})
    descriptions = rubric.get('descriptions', {})

    summary = "# LLM评分报告\n\n"
    summary += f"## 总体评价\n"
    summary += f"- 最终得分: {total:.2f}/{sum(criteria.values())}\n"
    summary += f"- 评分时间: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())}\n\n"
    summary += f"## 评分详情\n"
    summary += "| 评分项 | 得分 | 满分 | 评分标准 |\n"
    summary += "|-------|------|------|---------|\n"
    # One table row per rubric criterion; missing scores default to 0.0.
    for criterion, full_score in criteria.items():
        score = scores.get(criterion, 0.0)
        summary += f"| {criterion} | {score:.2f} | {full_score} | {descriptions.get(criterion, '')} |\n"
    summary += "\n"
    summary += f"## 详细反馈\n"
    summary += grade_result.get('feedback', '') + "\n"
    return summary
def main():
    """Script entry point: load inputs, grade via LLM, write result + summary."""
    args = parse_args()

    # Load the answer to be graded.
    print(f"📁 加载待评分文件: {args.answer}")
    answer_content = load_file_content(args.answer)

    # Load the grading rubric.
    print(f"📋 加载评分标准: {args.rubric}")
    rubric = load_rubric(args.rubric)

    # Build the grading prompt.
    print("📝 生成评分提示词...")
    prompt = generate_grading_prompt(args.question, answer_content, rubric)

    # Call the LLM; on failure fall back to an all-zero result so the
    # pipeline still produces both output files.
    print("🤖 调用LLM进行评分...")
    try:
        llm_response = call_llm_api(prompt)
        print("✅ LLM API调用成功")
    except Exception as e:
        print(f"❌ LLM API调用失败: {e}")
        grade_result = {
            'total': 0.0,
            'scores': {criterion: 0.0 for criterion in rubric.get('criteria', {})},
            'feedback': f'评分失败LLM API调用错误 - {str(e)}'
        }
    else:
        # Parse the LLM reply into a grading dict.
        print("📊 解析LLM评分结果...")
        grade_result = parse_llm_response(llm_response)

    # Persist the raw grading result.
    print(f"💾 保存评分结果: {args.out}")
    with open(args.out, 'w', encoding='utf-8') as f:
        json.dump(grade_result, f, ensure_ascii=False, indent=2)

    # Render and persist the human-readable summary.
    print(f"📝 生成评分摘要: {args.summary}")
    summary = generate_summary(grade_result, rubric)
    with open(args.summary, 'w', encoding='utf-8') as f:
        f.write(summary)

    # ROBUSTNESS FIX: .get() guards against LLM JSON that parses but lacks
    # a 'total' key — direct indexing here raised KeyError in that case.
    print(f"✅ 评分完成! 最终得分: {grade_result.get('total', 0.0):.2f}")


if __name__ == '__main__':
    main()