# Knowledge-base Q&A web service: Flask + SQLite + the DeepSeek chat API.
import json
import os
import sqlite3
import uuid
from contextlib import closing
from datetime import datetime

from dotenv import load_dotenv
from flask import Flask, render_template, request, jsonify
from openai import OpenAI
from werkzeug.utils import secure_filename


# Pull settings (e.g. DEEPSEEK_API_KEY) from a local .env file into os.environ.
load_dotenv()

app = Flask(__name__)
app.config.update(
    UPLOAD_FOLDER='uploads',              # relative to the process working directory
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # upper bound on request body size (16 MiB)
    DATABASE='knowledge_base.db',         # SQLite file, also relative to the CWD
)

# The OpenAI SDK client is pointed at DeepSeek's endpoint (overridable via env).
DEEPSEEK_API_KEY = os.getenv('DEEPSEEK_API_KEY')
DEEPSEEK_BASE_URL = os.getenv('DEEPSEEK_BASE_URL', 'https://api.deepseek.com')
client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url=DEEPSEEK_BASE_URL)

# File extensions (lower-case, without the dot) accepted for upload.
ALLOWED_EXTENSIONS = {'txt', 'pdf', 'docx'}

os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)

# In-memory cache of the documents table: document id -> row as a dict.
# Rebuilt from the database by load_documents_from_db().
documents = {}


def init_db(db_path=None):
    """Create the application's SQLite tables if they do not already exist.

    Safe to call repeatedly: both statements use ``CREATE TABLE IF NOT EXISTS``.

    Args:
        db_path: Path of the SQLite database file. Defaults to
            ``app.config['DATABASE']``; pass an explicit path to initialise a
            different database (e.g. in tests).
    """
    if db_path is None:
        db_path = app.config['DATABASE']

    # closing() guarantees the connection is released even if a statement fails.
    with closing(sqlite3.connect(db_path)) as conn:
        # Q&A history; `sources` holds a JSON-encoded list written by ask_question.
        conn.execute('''
            CREATE TABLE IF NOT EXISTS conversations (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                question TEXT NOT NULL,
                answer TEXT NOT NULL,
                sources TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        # One row per uploaded document; `id` is the UUID prefix of the stored file.
        conn.execute('''
            CREATE TABLE IF NOT EXISTS documents (
                id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                status TEXT NOT NULL,
                chunks INTEGER DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        ''')
        conn.commit()


def get_db_connection(db_path=None):
    """Open a new connection to the application's SQLite database.

    Rows are returned as ``sqlite3.Row``, so callers can read columns by name
    (``row['id']``) and convert a row with ``dict(row)``. The caller owns the
    connection and is responsible for closing it.

    Args:
        db_path: Database file to open; defaults to ``app.config['DATABASE']``.
    """
    conn = sqlite3.connect(app.config['DATABASE'] if db_path is None else db_path)
    conn.row_factory = sqlite3.Row
    return conn


def load_documents_from_db():
    """Rebuild the module-level ``documents`` cache from the ``documents`` table.

    ``documents`` is rebound to a new dict mapping each document id to its row
    (as a plain dict), in newest-first order. Rebinding rather than mutating in
    place means code still iterating the previous dict is unaffected.
    """
    global documents
    with closing(get_db_connection()) as conn:
        # created_at has one-second resolution; rowid breaks ties so that
        # documents uploaded within the same second keep a stable order.
        rows = conn.execute(
            'SELECT * FROM documents ORDER BY created_at DESC, rowid DESC'
        ).fetchall()
    documents = {row['id']: dict(row) for row in rows}


def allowed_file(filename):
    """Tell whether *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, case-insensitively; a name
    without any dot is rejected.
    """
    _, separator, extension = filename.rpartition('.')
    return bool(separator) and extension.lower() in ALLOWED_EXTENSIONS


def _extract_txt_text(filepath):
    """Return the contents of a UTF-8 text file."""
    # 'utf-8-sig' decodes plain UTF-8 identically and also drops the byte-order
    # mark that some Windows editors prepend.
    with open(filepath, 'r', encoding='utf-8-sig') as f:
        return f.read()


def _extract_pdf_text(filepath):
    """Return the text of every PDF page, each followed by a newline."""
    import pypdf

    with open(filepath, 'rb') as f:
        reader = pypdf.PdfReader(f)
        # `or ''` keeps a page without extractable text from aborting the read.
        return ''.join((page.extract_text() or '') + '\n' for page in reader.pages)


def _extract_docx_text(filepath):
    """Return the text of every paragraph of a .docx file, one per line."""
    from docx import Document

    doc = Document(filepath)
    return ''.join(paragraph.text + '\n' for paragraph in doc.paragraphs)


# Extractor to use for each recognised (lower-case) file extension.
_EXTRACTORS_BY_EXTENSION = {
    '.txt': _extract_txt_text,
    '.pdf': _extract_pdf_text,
    '.docx': _extract_docx_text,
}

# Formats probed, in this order, for files with no or an unknown extension.
_FALLBACK_EXTRACTORS = (_extract_docx_text, _extract_txt_text, _extract_pdf_text)


def _find_uploaded_file(doc_id, upload_folder):
    """Return the path of the first file in *upload_folder* named ``<doc_id>...``, or None."""
    for name in os.listdir(upload_folder):
        if name.startswith(doc_id):
            return os.path.join(upload_folder, name)
    return None


def read_document_content(doc_id, upload_folder=None):
    """Return the extracted text of the uploaded file belonging to *doc_id*.

    The file is located by name prefix in the upload folder (uploads are saved
    as ``<doc_id>_<name>``). ``.txt`` files are read as UTF-8, ``.pdf`` files
    with ``pypdf`` and ``.docx`` files with the ``docx`` package. A file with
    no or another extension is tried as docx, then txt, then pdf, and the
    first non-blank result is returned.

    Args:
        doc_id: Document id, i.e. the stored file-name prefix.
        upload_folder: Directory to search; defaults to
            ``app.config['UPLOAD_FOLDER']``.

    Returns:
        The extracted text (possibly empty for a recognised extension), or
        None when nothing matches, no fallback format yields text, or reading
        fails (the failure is logged).
    """
    if not doc_id:
        # An empty prefix would match every file in the folder.
        return None
    if upload_folder is None:
        upload_folder = app.config['UPLOAD_FOLDER']

    try:
        filepath = _find_uploaded_file(doc_id, upload_folder)
        if filepath is None:
            return None

        extension = os.path.splitext(filepath)[1].lower()
        extractor = _EXTRACTORS_BY_EXTENSION.get(extension)
        if extractor is not None:
            return extractor(filepath)

        for extractor in _FALLBACK_EXTRACTORS:
            try:
                text = extractor(filepath)
            except Exception:
                # Expected whenever the file is not in this format (bad zip,
                # undecodable bytes, missing parser library): try the next one.
                continue
            if text.strip():
                return text
        return None
    except Exception:
        app.logger.exception("Error reading document %s", doc_id)
        return None


@app.route('/')
def index():
    """Serve the front-end page.

    The in-memory document cache is refreshed first; the template itself is
    rendered without any context variables.
    """
    load_documents_from_db()
    page = render_template('index.html')
    return page


def _safe_upload_name(original_name):
    """Return a filesystem-safe name for an uploaded file that keeps its extension.

    ``secure_filename`` drops non-ASCII characters, so a name whose stem is
    entirely non-ASCII (e.g. Chinese characters followed by ".pdf") collapses
    to just "pdf". That loses the extension, and read_document_content then
    has to guess the format. The stem and the extension are therefore
    sanitised separately. *original_name* must already have passed
    ``allowed_file``, so its extension is one of ALLOWED_EXTENSIONS.
    """
    stem, _, extension = original_name.rpartition('.')
    # NOTE(review): fully non-ASCII stems fall back to 'document'; store the
    # original name separately if it should be shown to users.
    safe_stem = secure_filename(stem) or 'document'
    return f"{safe_stem}.{extension}"


@app.route('/api/upload', methods=['POST'])
def upload_document():
    """Store an uploaded file (multipart field ``file``) and register it.

    The file is saved as ``<uuid>_<safe name>`` in UPLOAD_FOLDER, and a row
    with status 'completed' is inserted into the ``documents`` table.

    Returns:
        JSON ``{'id', 'name', 'status'}`` on success; ``{'error': ...}`` with
        400 for a missing or unsupported file, or 500 on any other failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': '没有文件'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': '没有选择文件'}), 400

        if not (file and allowed_file(file.filename)):
            return jsonify({'error': '不支持的文件格式'}), 400

        doc_id = str(uuid.uuid4())
        filename = _safe_upload_name(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], f"{doc_id}_{filename}")
        file.save(filepath)

        try:
            with closing(get_db_connection()) as conn:
                conn.execute(
                    'INSERT INTO documents (id, name, status, chunks) VALUES (?, ?, ?, ?)',
                    (doc_id, filename, 'completed', 1)
                )
                conn.commit()
        except Exception:
            # Without a DB row nothing references the saved file, so remove it;
            # this is best-effort and the original error is re-raised either way.
            try:
                os.remove(filepath)
            except OSError:
                pass
            raise

        load_documents_from_db()
        return jsonify({
            'id': doc_id,
            'name': filename,
            'status': 'completed'
        })

    except Exception as e:
        return jsonify({'error': f'上传失败:{str(e)}'}), 500


# System prompt sent with every question (the text is part of the API request).
_ASK_SYSTEM_PROMPT = (
    "你是一个智能知识库问答助手。请基于提供的文档内容回答用户的问题。\n"
    "要求:\n"
    "1. 只使用文档中的信息回答问题\n"
    "2. 如果文档中没有相关信息,请明确说明\n"
    "3. 回答要准确、简洁、有条理\n"
    "4. 使用中文回答"
)

# Number of leading characters of each document included in the prompt.
_MAX_CHARS_PER_DOCUMENT = 3000


def _collect_document_context():
    """Read every 'completed' document in the ``documents`` cache.

    Returns:
        ``(context_parts, sources)``: one prompt snippet per document with
        non-empty content, and a matching source entry for each. ``page`` is
        always 1; page numbers are not tracked here.
    """
    context_parts = []
    sources = []
    for doc_id, doc_info in documents.items():
        if doc_info['status'] != 'completed':
            continue
        content = read_document_content(doc_id)
        if not content:
            continue
        context_parts.append(
            f"文档:{doc_info['name']}\n内容:{content[:_MAX_CHARS_PER_DOCUMENT]}"
        )
        sources.append({
            'doc_id': doc_id,
            'name': doc_info['name'],
            'page': 1
        })
    return context_parts, sources


def _save_conversation(question, answer, sources):
    """Insert one Q&A exchange into ``conversations``; *sources* is stored as JSON."""
    with closing(get_db_connection()) as conn:
        conn.execute(
            'INSERT INTO conversations (question, answer, sources) VALUES (?, ?, ?)',
            (question, answer, json.dumps(sources))
        )
        conn.commit()


@app.route('/api/ask', methods=['POST'])
def ask_question():
    """Answer a question using the uploaded documents as context.

    Expects a JSON body ``{"question": "..."}``. The first 3000 characters of
    every readable 'completed' document are sent with the question to the
    ``deepseek-chat`` model, and the exchange is saved to ``conversations``.

    Returns:
        JSON ``{'question', 'answer', 'sources'}`` on success. Errors are
        ``{'error': ...}``: 400 for a missing, empty or too long question, or
        when no document content is available; 500 when the model call or
        anything else fails.
    """
    try:
        # silent=True: a missing or malformed JSON body yields None instead of
        # raising, so it is reported as a 400 below rather than a 500.
        data = request.get_json(silent=True)
        question = data.get('question', '') if isinstance(data, dict) else ''

        if not isinstance(question, str) or not question.strip():
            return jsonify({'error': '请输入问题'}), 400

        if len(question) > 1000:
            return jsonify({'error': '问题长度不能超过1000字'}), 400

        load_documents_from_db()
        if not documents:
            return jsonify({'error': '请先上传文档'}), 400

        context_parts, sources = _collect_document_context()
        if not context_parts:
            return jsonify({'error': '没有可用的文档内容'}), 400

        context = '\n\n'.join(context_parts)
        user_prompt = (
            f"文档内容:\n{context}\n\n"
            f"用户问题:{question}\n\n"
            "请基于以上文档内容回答用户的问题。"
        )

        # Guard only the model call and response parsing, so that database
        # failures are not misreported as the AI service being unavailable.
        try:
            response = client.chat.completions.create(
                model="deepseek-chat",
                messages=[
                    {"role": "system", "content": _ASK_SYSTEM_PROMPT},
                    {"role": "user", "content": user_prompt}
                ],
                temperature=0.7,
                max_tokens=2000
            )
            # content may be None; conversations.answer is NOT NULL.
            answer = response.choices[0].message.content or ''
        except Exception as api_error:
            app.logger.error("DeepSeek API Error: %s", api_error)
            return jsonify({'error': f'AI服务暂时不可用:{str(api_error)}'}), 500

        result = {
            'question': question,
            'answer': answer,
            'sources': sources
        }
        _save_conversation(question, answer, sources)
        return jsonify(result)

    except Exception as e:
        return jsonify({'error': f'回答问题时出错:{str(e)}'}), 500


@app.route('/api/documents', methods=['GET'])
def get_documents():
    """List all registered documents as JSON, after refreshing the cache.

    On failure, responds 500 with an ``error`` message.
    """
    try:
        load_documents_from_db()
        rows = list(documents.values())
        return jsonify(rows)
    except Exception as exc:
        return jsonify({'error': f'获取文档列表失败:{str(exc)}'}), 500


def _remove_uploaded_files(doc_id):
    """Best-effort removal of the stored files for *doc_id* (named ``<doc_id>_<name>``)."""
    upload_folder = app.config['UPLOAD_FOLDER']
    try:
        names = os.listdir(upload_folder)
    except OSError:
        app.logger.exception("Could not list upload folder %s", upload_folder)
        return
    prefix = f"{doc_id}_"
    for name in names:
        if name.startswith(prefix):
            try:
                os.remove(os.path.join(upload_folder, name))
            except OSError:
                # The DB row is already gone; a leftover file is only wasted
                # disk space, so log it rather than failing the request.
                app.logger.exception("Could not remove uploaded file %s", name)


@app.route('/api/documents/<doc_id>', methods=['DELETE'])
def delete_document(doc_id):
    """Delete a document's database row and its stored upload file.

    Returns:
        ``{'success': True}``; 404 with an ``error`` if no document has that
        id; 500 with an ``error`` on any other failure.
    """
    try:
        with closing(get_db_connection()) as conn:
            cursor = conn.execute('DELETE FROM documents WHERE id = ?', (doc_id,))
            if cursor.rowcount == 0:
                return jsonify({'error': '文档不存在'}), 404
            conn.commit()

        # Reached only when doc_id matched an existing row, so the file-name
        # prefix comes from a stored id rather than arbitrary URL input.
        _remove_uploaded_files(doc_id)

        load_documents_from_db()
        return jsonify({'success': True})

    except Exception as e:
        return jsonify({'error': f'删除文档失败:{str(e)}'}), 500


@app.route('/api/conversations', methods=['GET'])
def get_conversations():
    """Return the 50 most recent Q&A exchanges, newest first.

    Each item is a ``conversations`` row whose ``sources`` column is decoded
    from JSON (an empty list when it is NULL or empty). On failure, responds
    500 with an ``error`` message.
    """
    try:
        with closing(get_db_connection()) as conn:
            # created_at has one-second resolution; id breaks ties so that
            # exchanges saved within the same second are still ordered.
            rows = conn.execute(
                'SELECT * FROM conversations ORDER BY created_at DESC, id DESC LIMIT 50'
            ).fetchall()

        result = []
        for row in rows:
            conv_dict = dict(row)
            conv_dict['sources'] = json.loads(conv_dict['sources']) if conv_dict['sources'] else []
            result.append(conv_dict)

        return jsonify(result)

    except Exception as e:
        return jsonify({'error': f'获取对话历史失败:{str(e)}'}), 500


@app.route('/api/conversations', methods=['DELETE'])
def clear_conversations():
    """Delete the entire Q&A history (uploaded documents are untouched).

    Returns ``{'success': True}``, or 500 with an ``error`` message on failure.
    """
    try:
        # closing() releases the connection even if the DELETE raises.
        with closing(get_db_connection()) as conn:
            conn.execute('DELETE FROM conversations')
            conn.commit()

        return jsonify({'success': True})

    except Exception as e:
        return jsonify({'error': f'清除对话历史失败:{str(e)}'}), 500


# Create the schema whenever this module is imported, so the app also works
# when started with `flask run` or a WSGI server instead of only via
# `python <this file>`. init_db() is idempotent (CREATE TABLE IF NOT EXISTS).
init_db()

if __name__ == '__main__':
    load_documents_from_db()
    # NOTE: debug=True enables Flask's debug mode (reloader and interactive
    # debugger); do not expose this development server to untrusted networks.
    app.run(debug=True, port=5000)