删除 check_environment.py
This commit is contained in:
parent
b8fca412cb
commit
0802754a05
@ -1,193 +0,0 @@
|
||||
"""
|
||||
环境检查脚本
|
||||
用于验证项目依赖和环境配置是否正确
|
||||
"""
|
||||
|
||||
import sys
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
|
||||
def check_python_version():
|
||||
"""检查Python版本"""
|
||||
print("=" * 60)
|
||||
print("检查 Python 版本...")
|
||||
print("=" * 60)
|
||||
|
||||
version = sys.version_info
|
||||
print(f"当前 Python 版本: {version.major}.{version.minor}.{version.micro}")
|
||||
|
||||
if version.major == 3 and version.minor >= 10:
|
||||
print("✓ Python 版本符合要求 (>= 3.10)")
|
||||
return True
|
||||
else:
|
||||
print("✗ Python 版本不符合要求,需要 3.10 或更高版本")
|
||||
return False
|
||||
|
||||
def check_dependencies():
|
||||
"""检查必要的依赖包"""
|
||||
print("\n" + "=" * 60)
|
||||
print("检查依赖包...")
|
||||
print("=" * 60)
|
||||
|
||||
required_packages = {
|
||||
'numpy': 'numpy',
|
||||
'polars': 'polars',
|
||||
'sklearn': 'scikit-learn',
|
||||
'imblearn': 'imbalanced-learn',
|
||||
'matplotlib': 'matplotlib',
|
||||
'seaborn': 'seaborn',
|
||||
'joblib': 'joblib',
|
||||
'pydantic': 'pydantic',
|
||||
'streamlit': 'streamlit',
|
||||
}
|
||||
|
||||
missing_packages = []
|
||||
|
||||
for module_name, package_name in required_packages.items():
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
version = getattr(module, '__version__', 'unknown')
|
||||
print(f"✓ {package_name:20s} - 版本: {version}")
|
||||
except ImportError:
|
||||
print(f"✗ {package_name:20s} - 未安装")
|
||||
missing_packages.append(package_name)
|
||||
|
||||
if missing_packages:
|
||||
print(f"\n缺少 {len(missing_packages)} 个依赖包")
|
||||
print(f"请运行: pip install {' '.join(missing_packages)}")
|
||||
return False
|
||||
else:
|
||||
print("\n✓ 所有依赖包已正确安装")
|
||||
return True
|
||||
|
||||
def check_data_files():
|
||||
"""检查数据文件"""
|
||||
print("\n" + "=" * 60)
|
||||
print("检查数据文件...")
|
||||
print("=" * 60)
|
||||
|
||||
data_dir = Path("data")
|
||||
creditcard_csv = data_dir / "creditcard.csv"
|
||||
|
||||
if creditcard_csv.exists():
|
||||
file_size = creditcard_csv.stat().st_size / (1024 * 1024) # MB
|
||||
print(f"✓ data/creditcard.csv 存在 (大小: {file_size:.2f} MB)")
|
||||
return True
|
||||
else:
|
||||
print("✗ data/creditcard.csv 不存在")
|
||||
print("请从以下地址下载数据集:")
|
||||
print("https://www.kaggle.com/mlg-ulb/creditcardfraud")
|
||||
print("并将 creditcard.csv 文件放入 data/ 目录")
|
||||
return False
|
||||
|
||||
def check_model_files():
|
||||
"""检查模型文件"""
|
||||
print("\n" + "=" * 60)
|
||||
print("检查模型文件...")
|
||||
print("=" * 60)
|
||||
|
||||
models_dir = Path("models")
|
||||
required_models = [
|
||||
"random_forest_model.joblib",
|
||||
"logistic_regression_model.joblib",
|
||||
"scaler.joblib"
|
||||
]
|
||||
|
||||
missing_models = []
|
||||
|
||||
for model_file in required_models:
|
||||
model_path = models_dir / model_file
|
||||
if model_path.exists():
|
||||
file_size = model_path.stat().st_size / 1024 # KB
|
||||
print(f"✓ {model_file:35s} (大小: {file_size:.2f} KB)")
|
||||
else:
|
||||
print(f"✗ {model_file:35s} - 不存在")
|
||||
missing_models.append(model_file)
|
||||
|
||||
if missing_models:
|
||||
print(f"\n缺少 {len(missing_models)} 个模型文件")
|
||||
print("请运行: python src/train.py 来训练模型")
|
||||
return False
|
||||
else:
|
||||
print("\n✓ 所有模型文件已存在")
|
||||
return True
|
||||
|
||||
def check_source_files():
|
||||
"""检查源代码文件"""
|
||||
print("\n" + "=" * 60)
|
||||
print("检查源代码文件...")
|
||||
print("=" * 60)
|
||||
|
||||
src_dir = Path("src")
|
||||
required_files = [
|
||||
"__init__.py",
|
||||
"data.py",
|
||||
"features.py",
|
||||
"train.py",
|
||||
"infer.py",
|
||||
"agent_app.py",
|
||||
"streamlit_app.py"
|
||||
]
|
||||
|
||||
missing_files = []
|
||||
|
||||
for file_name in required_files:
|
||||
file_path = src_dir / file_name
|
||||
if file_path.exists():
|
||||
print(f"✓ src/{file_name}")
|
||||
else:
|
||||
print(f"✗ src/{file_name} - 不存在")
|
||||
missing_files.append(file_name)
|
||||
|
||||
if missing_files:
|
||||
print(f"\n缺少 {len(missing_files)} 个源代码文件")
|
||||
return False
|
||||
else:
|
||||
print("\n✓ 所有源代码文件完整")
|
||||
return True
|
||||
|
||||
def run_all_checks():
|
||||
"""运行所有检查"""
|
||||
print("\n" + "=" * 60)
|
||||
print("信用卡欺诈检测系统 - 环境检查")
|
||||
print("=" * 60)
|
||||
|
||||
results = {
|
||||
"Python 版本": check_python_version(),
|
||||
"依赖包": check_dependencies(),
|
||||
"数据文件": check_data_files(),
|
||||
"模型文件": check_model_files(),
|
||||
"源代码文件": check_source_files(),
|
||||
}
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("检查结果汇总")
|
||||
print("=" * 60)
|
||||
|
||||
for check_name, result in results.items():
|
||||
status = "✓ 通过" if result else "✗ 失败"
|
||||
print(f"{check_name:15s}: {status}")
|
||||
|
||||
all_passed = all(results.values())
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
if all_passed:
|
||||
print("✓ 所有检查通过!您可以运行系统了")
|
||||
print("\n运行命令:")
|
||||
print(" python src/agent_app.py")
|
||||
else:
|
||||
print("✗ 部分检查未通过,请根据上述提示解决问题")
|
||||
print("\n快速修复:")
|
||||
if not results["依赖包"]:
|
||||
print(" 1. 安装依赖: pip install -r requirements.txt")
|
||||
if not results["数据文件"]:
|
||||
print(" 2. 下载数据: 从 Kaggle 下载 creditcard.csv 到 data/ 目录")
|
||||
if not results["模型文件"]:
|
||||
print(" 3. 训练模型: python src/train.py")
|
||||
print("=" * 60)
|
||||
|
||||
return all_passed
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = run_all_checks()
|
||||
sys.exit(0 if success else 1)
|
||||
Loading…
Reference in New Issue
Block a user