631 lines
21 KiB
Python
631 lines
21 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Data quality check script
=========================
Runs completeness, statistics, class-balance and outlier checks on a data
directory under Static/, prints a detailed report to the terminal, and can
optionally save the report as a text file.

Usage:
    python Scripts/check_data.py --folder 20260319Numbers --labels 0 1 2 3 4 5 6 7 8 9
    python Scripts/check_data.py --folder "20260408 grap" --labels 1 2 3 4 5 6 7 8 9 --output report.txt
    python Scripts/check_data.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9

Requirements:
    Run from the project root (Deeplearning/), or pass --root to point at the project root.
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import unicodedata
|
||
from pathlib import Path
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
|
||
|
||
# ============================================================
|
||
# 工具函数(与 loadData.py 中逻辑保持一致)
|
||
# ============================================================
|
||
|
||
# Supported data-file extensions, lower-case and without the leading dot.
# The tuple order is also the order in which extensions are tried when
# looking up the file that belongs to a label.
DEFAULT_FILE_CLASSES = ("xlsx", "xls", "csv")
|
||
|
||
|
||
def _has_supported_extension(filename: str, file_classes=DEFAULT_FILE_CLASSES) -> bool:
    """Return True when the extension of *filename* is listed in *file_classes*.

    The extension is lower-cased and stripped of its leading dot before the
    membership test, so *file_classes* is expected to hold bare suffixes
    such as ``"csv"``.
    """
    _, suffix = os.path.splitext(filename)
    return suffix.lower().lstrip(".") in file_classes
|
||
|
||
|
||
def _read_data_file(file_path: str) -> pd.DataFrame:
|
||
ext = os.path.splitext(file_path)[1].lower()
|
||
if ext == ".csv":
|
||
return pd.read_csv(file_path)
|
||
if ext in (".xls", ".xlsx"):
|
||
return pd.read_excel(file_path)
|
||
raise ValueError(
|
||
f"Unsupported file format: {ext}. Only .xls, .xlsx, and .csv are supported. "
|
||
f"File: {file_path}"
|
||
)
|
||
|
||
|
||
def _strip_zero_width(s: str) -> str:
|
||
if not isinstance(s, str):
|
||
return s
|
||
return s.translate(
|
||
{0x200B: None, 0x200C: None, 0x200D: None, 0xFEFF: None}
|
||
)
|
||
|
||
|
||
def _canonicalize_name(name: str) -> str:
    """Return *name* NFKC-normalized and with zero-width characters removed."""
    return _strip_zero_width(unicodedata.normalize("NFKC", name))
|
||
|
||
|
||
def _normalize_for_compare(name: str) -> str:
    """Build a relaxed comparison key for a file name.

    On top of the canonical form, underscores are treated as spaces, runs of
    whitespace collapse to a single space, leading/trailing whitespace is
    dropped, and the result is lower-cased.
    """
    spaced = _canonicalize_name(name).replace("_", " ")
    return " ".join(spaced.split()).lower()
|
||
|
||
|
||
def _find_matching_file(folder: str, expected_name: str):
    """Find the directory entry in *folder* that corresponds to *expected_name*.

    Matching is attempted in three passes of decreasing strictness, and the
    first hit of the earliest successful pass wins:

    1. equal canonical names,
    2. equal canonical names ignoring case,
    3. equal relaxed keys (see ``_normalize_for_compare``).

    Returns:
        The entry name exactly as listed by ``os.listdir``, or ``None`` when
        nothing matches or *folder* does not exist.
    """
    def _caseless(entry_name):
        return _canonicalize_name(entry_name).lower()

    # Ordered from strictest to most relaxed comparison key.
    key_functions = (_canonicalize_name, _caseless, _normalize_for_compare)
    targets = [key(expected_name) for key in key_functions]

    try:
        candidates = os.listdir(folder)
    except FileNotFoundError:
        return None

    for key, target in zip(key_functions, targets):
        for candidate in candidates:
            if key(candidate) == target:
                return candidate
    return None
|
||
|
||
|
||
def _find_matching_file_by_label(folder: str, label_name, file_classes):
    """Locate the file ``<label_name>.<ext>`` in *folder*.

    Extensions are tried in the order given by *file_classes*; the first one
    for which ``_find_matching_file`` finds an entry decides the result.

    Returns:
        The matching entry name, or ``None`` when no extension matches.
    """
    lookups = (
        _find_matching_file(folder, f"{label_name}.{ext}") for ext in file_classes
    )
    return next((found for found in lookups if found is not None), None)
|
||
|
||
|
||
# ============================================================
|
||
# 报告生成工具
|
||
# ============================================================
|
||
|
||
class ReportBuffer:
    """Collect report lines while echoing each of them to stdout.

    When an output path was supplied, ``save`` writes the collected lines to
    that file as UTF-8 text.
    """

    def __init__(self, output_path=None):
        # Destination file for save(); a falsy value disables saving.
        self.output_path = output_path
        # Every line passed to add(), in call order.
        self.lines: list[str] = []

    def add(self, text: str = ""):
        """Print *text* and remember it; call without arguments for a blank line."""
        print(text)
        self.lines.append(text)

    def save(self):
        """Write all collected lines to ``output_path`` (no-op when it is unset)."""
        if not self.output_path:
            return
        content = "\n".join(self.lines) + "\n"
        with open(self.output_path, "w", encoding="utf-8") as report_file:
            report_file.write(content)
        print(f"\n报告已保存到: {self.output_path}")
|
||
|
||
|
||
# ============================================================
|
||
# 核心检查逻辑
|
||
# ============================================================
|
||
|
||
def _extract_features(df: pd.DataFrame, source: str) -> pd.DataFrame:
|
||
"""
|
||
按项目约定提取偶数列作为特征(保持 int 列名以对齐)。
|
||
返回特征 DataFrame(列名 0, 2, 4, ...)。
|
||
"""
|
||
# 偶数列索引: 1, 3, 5, ...
|
||
even_cols = [c for i, c in enumerate(df.columns) if i % 2 == 1]
|
||
if not even_cols:
|
||
raise ValueError(f"没有找到偶数列(特征列)。请检查文件: {source}")
|
||
features = df[even_cols].copy()
|
||
|
||
# 尝试转为数值
|
||
for c in features.columns:
|
||
features[c] = pd.to_numeric(features[c], errors="coerce")
|
||
|
||
return features
|
||
|
||
|
||
def check_tabular_project(root: str, folder: str, labels: list[str], rp: ReportBuffer):
    """Run the full check on ``<root>/Static/<folder>`` and write the report to *rp*.

    The data layout is detected first:

    * "single file" mode - every label has a matching data file directly in
      the data directory;
    * "multi folder" mode - every label has a sub-directory containing at
      least one supported data file.

    Exactly one of the two must hold. Otherwise (neither or both) an error is
    reported and the function returns without running any check. A missing
    data directory is reported the same way.
    """
    data_dir = os.path.join(root, "Static", folder)
    if not os.path.isdir(data_dir):
        rp.add(f"[ERROR] 目录不存在: {data_dir}")
        rp.add("请确认 --folder 参数正确。")
        return

    rule = "=" * 64
    header_lines = (
        rule,
        " Deeplearning 数据质量检查报告",
        rule,
        f" 数据目录 : {data_dir}",
        f" 标签数量 : {len(labels)}",
        f" 标签列表 : {labels}",
        "",
    )
    for line in header_lines:
        rp.add(line)

    # ---- Step 1: detect the data layout ----
    def _label_has_data_folder(label) -> bool:
        sub_dir = os.path.join(data_dir, str(label))
        return os.path.isdir(sub_dir) and any(
            _has_supported_extension(entry) for entry in os.listdir(sub_dir)
        )

    def _label_has_file(label) -> bool:
        found = _find_matching_file_by_label(data_dir, label, DEFAULT_FILE_CLASSES)
        return found is not None

    # all() stops at the first failing label, like a loop with break would.
    has_all_subfolders = all(_label_has_data_folder(lbl) for lbl in labels)
    has_all_files = all(_label_has_file(lbl) for lbl in labels)

    # Both flags equal means either no layout or both layouts were detected.
    if has_all_files == has_all_subfolders:
        rp.add("[ERROR] 无法自动检测数据模式,或两种模式同时存在。")
        rp.add(f" has_all_files : {has_all_files}")
        rp.add(f" has_all_subfolders: {has_all_subfolders}")
        rp.add("请确保每个 label 对应唯一的文件或唯一的子目录。")
        return

    run_checks = _check_single_file_mode if has_all_files else _check_multi_folder_mode
    run_checks(data_dir, labels, rp)

    rp.add()
    rp.add(rule)
    rp.add(" 检查完成。")
    rp.add(rule)
|
||
|
||
|
||
def _check_single_file_mode(data_dir: str, labels: list[str], rp: ReportBuffer):
    """Check a data directory in which every label is stored as one file.

    Steps, all reported through *rp*:

    1. resolve the actual file name of each label (labels without a file are
       reported and skipped; nothing else runs if no label has a file);
    2. read each file, extract its feature columns and drop rows with NaN;
    3. report whether all labels have the same number of feature columns;
    4. report per-class sample counts, then run the balance, statistics and
       outlier analyses.

    Read failures and feature-extraction ``ValueError``s are recorded per
    label and reported; they do not abort the remaining labels.
    """
    rp.add()
    rp.add("── 数据模式: 单文件模式 ──")
    rp.add()

    # 1. Resolve the actual file name of every label.
    file_map: dict[str, str] = {}
    missing: list[str] = []
    for lbl in labels:
        match = _find_matching_file_by_label(data_dir, lbl, DEFAULT_FILE_CLASSES)
        if match:
            file_map[lbl] = match
        else:
            missing.append(lbl)

    if missing:
        rp.add(f"[WARN] 以下标签找不到对应文件: {missing}")
        rp.add(f"当前目录内容: {sorted(os.listdir(data_dir))}")
        if not file_map:
            return
        labels = [lbl for lbl in labels if lbl in file_map]

    # 2. Read class by class.
    all_features = []  # list of (label, cleaned feature DataFrame)
    per_class_info: dict[str, dict] = {}
    col_counts: dict[str, int] = {}

    for lbl in labels:
        fname = file_map[lbl]
        file_path = os.path.join(data_dir, fname)
        info: dict[str, object] = {"label": lbl, "file": fname, "warnings": []}

        try:
            raw = _read_data_file(file_path)
        except Exception as e:
            # Any reader failure is reported for this label only.
            info["error"] = str(e)
            per_class_info[lbl] = info
            rp.add(f"[ERROR] 读取文件失败: {file_path} — {e}")
            continue

        info["raw_rows"] = raw.shape[0]
        info["raw_cols"] = raw.shape[1]

        # NaN cells present in the raw file (all columns, not only features).
        total_nan = raw.isna().sum().sum()
        if total_nan > 0:
            info["warnings"].append(f"原始文件含 {total_nan} 个 NaN 单元格")

        try:
            features = _extract_features(raw, fname)
        except ValueError as e:
            info["error"] = str(e)
            per_class_info[lbl] = info
            rp.add(f"[ERROR] 特征提取失败: {file_path} — {e}")
            continue

        # Statistics are computed after dropping every row that contains NaN.
        clean = features.dropna()
        info["feature_cols"] = features.shape[1]
        info["samples_after_dropna"] = clean.shape[0]
        info["dropped_nan_rows"] = features.shape[0] - clean.shape[0]
        info["values"] = clean.values

        col_counts[lbl] = features.shape[1]

        if clean.shape[0] == 0:
            info["warnings"].append("去除 NaN 后无有效样本")

        per_class_info[lbl] = info
        if clean.shape[0] > 0:
            all_features.append((lbl, clean))

    # 3. Column-count consistency across labels.
    if len(set(col_counts.values())) > 1:
        rp.add()
        rp.add("[WARN] 各标签的特征列数不一致!")
        for lbl, cc in col_counts.items():
            rp.add(f" {lbl}: {cc} 列")
        rp.add("这会导致 load_data 时补零逻辑产生差异。")
    elif col_counts:
        # Only claim consistency when at least one label was actually loaded;
        # with nothing loaded an "[OK] ... 0" line would be misleading.
        rp.add(f"[OK] 所有标签特征列数一致: {next(iter(col_counts.values()), 0)} 列")

    # 4. Per-class sample counts.
    rp.add()
    rp.add("── 各类别样本数 ──")
    sample_counts: dict[str, int] = {}
    for lbl in labels:
        info = per_class_info.get(lbl, {})
        if "error" in info:
            rp.add(f" [{lbl}] 加载失败: {info['error']}")
            continue
        n = info.get("samples_after_dropna", 0)
        sample_counts[lbl] = n
        warnings = info.get("warnings", [])
        wflag = f" ⚠ {'; '.join(warnings)}" if warnings else ""
        rp.add(f" [{lbl}] {n} 行 (文件: {info.get('file','?')}, "
               f"原始 {info.get('raw_rows','?')} 行, "
               f"丢弃 NaN 行 {info.get('dropped_nan_rows',0)}){wflag}")

    # Class balance, then statistics and outliers.
    _analyze_balance(sample_counts, rp)
    _analyze_statistics(all_features, rp)
    _analyze_outliers(all_features, rp)
|
||
|
||
|
||
def _check_multi_folder_mode(data_dir: str, labels: list[str], rp: ReportBuffer):
    """Check a data directory in which every label is a sub-directory of files.

    For each label all supported files of its sub-directory are read (sorted
    by name), their feature columns are extracted, rows containing NaN are
    dropped, and the remaining rows are concatenated into one frame per
    label. Files that fail to load are reported as warnings for that label.
    Afterwards the column-count consistency and per-class sample counts are
    reported and the balance, statistics and outlier analyses are run.

    NOTE(review): the per-file frames are combined with ``pd.concat``, which
    aligns columns by label - files of one class with different header names
    would therefore not line up positionally. Confirm the data files share
    their headers.
    """
    rp.add()
    rp.add("── 数据模式: 多子特征模式 ──")
    rp.add()

    all_features = []
    per_class_info: dict[str, dict] = {}
    col_counts: dict[str, int] = {}

    for lbl in labels:
        sub = os.path.join(data_dir, str(lbl))
        if not os.path.isdir(sub):
            per_class_info[lbl] = {"error": f"子目录不存在: {sub}"}
            rp.add(f"[ERROR] {lbl}: 子目录不存在")
            continue

        files = sorted(name for name in os.listdir(sub) if _has_supported_extension(name))
        if not files:
            per_class_info[lbl] = {"error": f"子目录下无支持的文件: {sub}"}
            rp.add(f"[ERROR] {lbl}: 子目录下无 .xlsx/.xls/.csv 文件")
            continue

        frames = []          # cleaned feature frames of this class
        widths = set()       # distinct feature-column counts seen across files
        raw_row_total = 0
        dropped_row_total = 0
        load_failures = []   # preformatted "<file>: <error>" messages

        for fname in files:
            file_path = os.path.join(sub, fname)
            try:
                raw = _read_data_file(file_path)
            except Exception as e:
                load_failures.append(f" {fname}: {e}")
                continue

            raw_row_total += raw.shape[0]
            try:
                features = _extract_features(raw, f"{lbl}/{fname}")
            except ValueError as e:
                load_failures.append(f" {fname}: {e}")
                continue

            widths.add(features.shape[1])
            clean = features.dropna()
            dropped_row_total += features.shape[0] - clean.shape[0]
            if clean.shape[0] > 0:
                frames.append(clean)

        info: dict[str, object] = {
            "label": lbl,
            "num_files": len(files),
            "raw_rows_total": raw_row_total,
            "dropped_nan_rows": dropped_row_total,
            "warnings": [],
        }

        if load_failures:
            info["warnings"].append(f"{len(load_failures)} 个文件加载失败")
            for failure in load_failures:
                rp.add(f" [WARN] {failure}")

        if len(widths) > 1:
            info["warnings"].append(f"子文件间列数不一致: {sorted(widths)}")
        # Widest file wins; 0 when no file of this class could be processed.
        col_counts[lbl] = max(widths, default=0)

        if frames:
            combined = pd.concat(frames, ignore_index=True)
            info["samples_after_dropna"] = combined.shape[0]
            info["feature_cols"] = combined.shape[1]
            info["values"] = combined.values
            all_features.append((lbl, combined))
        else:
            info["samples_after_dropna"] = 0
            info["warnings"].append("无有效样本")

        per_class_info[lbl] = info

    # Column-count consistency, ignoring classes without any usable file.
    non_zero = {label: count for label, count in col_counts.items() if count > 0}
    if non_zero:
        if len(set(non_zero.values())) > 1:
            rp.add()
            rp.add("[WARN] 各标签的特征列数不一致(将使用零填充对齐):")
            for lbl, cc in col_counts.items():
                rp.add(f" {lbl}: {cc} 列")
        else:
            rp.add(f"[OK] 所有标签特征列数一致: {next(iter(non_zero.values()))} 列")

    # Per-class sample counts.
    rp.add()
    rp.add("── 各类别样本数 ──")
    sample_counts: dict[str, int] = {}
    for lbl in labels:
        info = per_class_info.get(lbl, {})
        if "error" in info:
            rp.add(f" [{lbl}] 加载失败: {info['error']}")
            continue
        n = info.get("samples_after_dropna", 0)
        sample_counts[lbl] = n
        wflag = f" ⚠ {'; '.join(info['warnings'])}" if info.get("warnings") else ""
        rp.add(f" [{lbl}] {n} 行 "
               f"(来自 {info.get('num_files','?')} 个文件, "
               f"原始 {info.get('raw_rows_total','?')} 行, "
               f"丢弃 NaN 行 {info.get('dropped_nan_rows',0)}){wflag}")

    _analyze_balance(sample_counts, rp)
    _analyze_statistics(all_features, rp)
    _analyze_outliers(all_features, rp)
|
||
|
||
|
||
# ============================================================
|
||
# 分析子模块
|
||
# ============================================================
|
||
|
||
def _analyze_balance(counts: dict[str, int], rp: ReportBuffer):
    """Report class-balance figures for the per-label sample *counts*.

    Prints total/average/min/max sample counts, the max/min imbalance ratio
    and the coefficient of variation, rates the balance (> 5:1 severe,
    > 3:1 moderate), and finally prints a per-class train/test estimate for
    a 20 % test share. The analysis stops early when *counts* is empty or
    when any class has zero samples.
    """
    rp.add()
    rp.add("── 类别平衡性分析 ──")
    if not counts:
        rp.add("无有效样本,跳过。")
        return

    sizes = list(counts.values())
    total = sum(sizes)
    n_classes = len(sizes)
    avg = total / n_classes if n_classes else 0
    min_count, max_count = min(sizes), max(sizes)
    smallest_label = min(counts, key=counts.get)
    largest_label = max(counts, key=counts.get)

    rp.add(f" 总样本数 : {total}")
    rp.add(f" 类别数 : {n_classes}")
    rp.add(f" 平均每类 : {avg:.1f}")
    rp.add(f" 最少样本类 : {smallest_label} ({min_count})")
    rp.add(f" 最多样本类 : {largest_label} ({max_count})")

    if min_count == 0:
        rp.add(" [ERROR] 存在样本数为 0 的类别,训练将无法进行!")
        return

    if min_count > 0:
        ratio = max_count / min_count
    else:
        ratio = float("inf")
    rp.add(f" 不平衡比例 : {ratio:.2f}:1 (max/min)")

    spread = float(np.std(sizes))
    cv = spread / avg if avg > 0 else 0
    rp.add(f" 变异系数(CV): {cv:.4f}")

    if ratio > 5:
        verdict = " [WARN] 类别严重不平衡 (>5:1),建议进行数据增强或使用类别权重。"
    elif ratio > 3:
        verdict = " [WARN] 类别较不平衡 (>3:1),可考虑采样策略。"
    else:
        verdict = " [OK] 类别基本平衡。"
    rp.add(verdict)

    # Estimated train/test split sizes per class.
    rp.add()
    rp.add(" ── 训练/测试划分预估 (test_size=0.2, stratify) ──")
    for lbl in sorted(counts):
        cnt = counts[lbl]
        # At least one sample is always attributed to the test side.
        test_n = max(1, int(cnt * 0.2))
        rp.add(f" [{lbl}] 训练 {cnt - test_n} / 测试 {test_n} (总计 {cnt})")
|
||
|
||
|
||
def _analyze_statistics(
    all_features: list[tuple[str, pd.DataFrame]],
    rp: ReportBuffer,
):
    """Report descriptive statistics of the feature values.

    Three sections are written to *rp*: global statistics over all classes
    (narrower classes are right-padded with zero columns so they can be
    stacked), per-column statistics for at most the first 12 feature
    columns, and one summary line per class computed on its unpadded values.
    Nothing but a skip notice is printed when *all_features* is empty.
    """
    rp.add()
    rp.add("── 特征统计信息 ──")
    if not all_features:
        rp.add("无有效数据,跳过。")
        return

    # Global statistics: zero-pad every class array to the widest column count.
    max_cols = max(frame.shape[1] for _, frame in all_features)

    def _zero_padded(block):
        shortfall = max_cols - block.shape[1]
        if shortfall <= 0:
            return block
        filler = np.zeros((block.shape[0], shortfall), dtype=block.dtype)
        return np.hstack([block, filler])

    all_values = np.vstack([_zero_padded(frame.values) for _, frame in all_features])
    rp.add(f" 全局特征维度: {all_values.shape} (样本数 × 特征数, 零填充对齐到 {max_cols} 列)")
    global_rows = (
        (" 全局均值 : ", np.mean),
        (" 全局标准差 : ", np.std),
        (" 全局最小值 : ", np.min),
        (" 全局最大值 : ", np.max),
        (" 全局中位数 : ", np.median),
    )
    for caption, reducer in global_rows:
        rp.add(f"{caption}{reducer(all_values):.6f}")

    # Statistics of the individual feature columns (first 12 at most).
    rp.add()
    n_cols = min(all_values.shape[1], 12)
    rp.add(f" ── 前 {n_cols} 个特征维度的分布 (均值 ± 标准差) ──")
    for j, col in enumerate(all_values.T[:n_cols]):
        mean_part = f"μ={np.mean(col):.4f} σ={np.std(col):.4f} "
        range_part = f"[{np.min(col):.4f}, {np.max(col):.4f}] "
        rp.add(f" feature{j+1}: " + mean_part + range_part + f"med={np.median(col):.4f}")

    # Short summary per class, on the class's own (unpadded) values.
    rp.add()
    rp.add(" ── 各类别特征统计 ──")
    for lbl, frame in all_features:
        data = frame.values
        rp.add(
            f" [{lbl}] "
            f"μ={np.mean(data):.4f} σ={np.std(data):.4f} "
            f"范围 [{np.min(data):.4f}, {np.max(data):.4f}]"
        )
|
||
|
||
|
||
def _analyze_outliers(
    all_features: list[tuple[str, pd.DataFrame]],
    rp: ReportBuffer,
):
    """Report IQR-based outlier rates per class and overall.

    Within each class, the quartiles are computed per feature column; a
    sample counts as an outlier when any of its features lies outside
    ``[Q1 - 1.5*IQR, Q3 + 1.5*IQR]``. Per-class rates are tagged [OK] below
    10 %, [WARN] below 25 % and [ERROR] otherwise; the overall rate gets a
    final verdict (> 20 % warning, > 10 % info).
    """
    rp.add()
    rp.add("── 离群值检测 (基于 IQR) ──")
    if not all_features:
        rp.add("无有效数据,跳过。")
        return

    total_outlier_samples = 0
    total_samples = 0

    for lbl, frame in all_features:
        data = frame.values
        n_rows = data.shape[0]
        total_samples += n_rows

        # Column-wise quartiles in one call: row 0 is Q1, row 1 is Q3.
        q1, q3 = np.percentile(data, [25, 75], axis=0)
        fence = 1.5 * (q3 - q1)
        too_low = data < (q1 - fence)
        too_high = data > (q3 + fence)

        n_outliers = int(np.sum(np.any(too_low | too_high, axis=1)))
        total_outlier_samples += n_outliers

        pct = n_outliers / n_rows * 100 if n_rows > 0 else 0
        if pct < 10:
            status = "[OK]"
        elif pct < 25:
            status = "[WARN]"
        else:
            status = "[ERROR]"
        rp.add(f" [{lbl}] 离群样本: {n_outliers}/{n_rows} ({pct:.1f}%) {status}")

    if total_samples > 0:
        overall_pct = total_outlier_samples / total_samples * 100
    else:
        overall_pct = 0
    rp.add()
    rp.add(f" 整体离群比例: {total_outlier_samples}/{total_samples} ({overall_pct:.1f}%)")

    if overall_pct > 20:
        verdict = " [WARN] 超过 20% 数据为离群值,请确认数据清洗是否正确。"
    elif overall_pct > 10:
        verdict = " [INFO] 离群值比例偏高,训练时可能影响收敛。"
    else:
        verdict = " [OK] 离群值比例正常。"
    rp.add(verdict)
|
||
|
||
|
||
# ============================================================
|
||
# 命令行入口
|
||
# ============================================================
|
||
|
||
def main():
    """Command-line entry point: parse arguments, run the check, save the report.

    The project root is taken from ``--root`` when given; otherwise it is the
    parent of the directory that contains this script. The process exits
    with status 1 when the root directory does not exist.
    """
    usage_examples = """
示例:
python Scripts/check_data.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
python Scripts/check_data.py -f "20260408 grap" -l 1 2 3 4 5 6 7 8 9 -o report.txt
python Scripts/check_data.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9 -r /path/to/project
"""
    parser = argparse.ArgumentParser(
        description="Deeplearning 数据质量检查脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )
    parser.add_argument(
        "-f", "--folder",
        required=True,
        help="Static/ 下的数据目录名(例如 20260319Numbers)",
    )
    parser.add_argument(
        "-l", "--labels",
        nargs="+",
        required=True,
        help="类别标签列表,空格分隔(例如 0 1 2 3 4 或 A B C)",
    )
    parser.add_argument(
        "-o", "--output",
        default=None,
        help="将报告保存到指定文件路径",
    )
    parser.add_argument(
        "-r", "--root",
        default=None,
        help="项目根目录(默认为脚本的上级目录,即 Deeplearning/)",
    )
    args = parser.parse_args()

    # Resolve the project root: explicit --root wins, else the script's parent dir.
    chosen_root = args.root or str(Path(__file__).resolve().parent.parent)
    root = os.path.abspath(chosen_root)
    if not os.path.isdir(root):
        print(f"[ERROR] 项目根目录不存在: {root}")
        sys.exit(1)

    rp = ReportBuffer(output_path=args.output)
    check_tabular_project(root=root, folder=args.folder, labels=args.labels, rp=rp)
    rp.save()


if __name__ == "__main__":
    main()