657 lines
24 KiB
Python
657 lines
24 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
数据可视化脚本
|
||
===============
|
||
加载 Static/ 下的数据目录,生成多种可视化图表:
|
||
- 类别分布柱状图
|
||
- 特征分布直方图(各类叠加)
|
||
- 特征箱线图(前 N 个特征)
|
||
- PCA 降维散点图 + 置信椭圆
|
||
- t-SNE 降维散点图
|
||
- 各类别均值/标准差对比热力图
|
||
- 全局特征相关性热力图
|
||
- 全局特征分布概览
|
||
|
||
用法:
|
||
python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
|
||
python Scripts/visualize.py -f "20260408 grap" -l 1 2 3 4 5 6 7 8 9 --max-features 20
|
||
python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9 --no-tsne
|
||
|
||
输出目录: Visualizations/<folder>/
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
import unicodedata
|
||
from pathlib import Path
|
||
import warnings
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import matplotlib
|
||
matplotlib.use("Agg")
|
||
import matplotlib.pyplot as plt
|
||
import seaborn as sns
|
||
|
||
warnings.filterwarnings("ignore", category=FutureWarning)
|
||
warnings.filterwarnings("ignore", category=UserWarning)
|
||
|
||
# ============================================================
|
||
# 数据加载工具(与 check_data.py 保持一致)
|
||
# ============================================================
|
||
|
||
# File extensions (lowercase, without the leading dot) accepted as data files.
# The tuple order is also the order in which "<label>.<ext>" candidates are tried.
DEFAULT_FILE_CLASSES = ("xlsx", "xls", "csv")
|
||
|
||
|
||
def _has_supported_extension(filename: str, file_classes=DEFAULT_FILE_CLASSES) -> bool:
    """Report whether *filename* ends in one of the accepted extensions.

    The extension is compared case-insensitively and without its leading dot.

    Args:
        filename: File name or path to inspect.
        file_classes: Container of accepted extensions (no leading dot).

    Returns:
        True when the normalized extension is contained in *file_classes*.
    """
    _, suffix = os.path.splitext(filename)
    normalized = suffix.lower().lstrip(".")
    return normalized in file_classes
|
||
|
||
|
||
def _read_data_file(file_path: str) -> pd.DataFrame:
    """Load a tabular data file, choosing the pandas reader from its extension.

    Args:
        file_path: Path to a ``.csv``, ``.xls`` or ``.xlsx`` file (the
            extension is matched case-insensitively).

    Returns:
        The DataFrame produced by ``pd.read_csv`` or ``pd.read_excel``.

    Raises:
        ValueError: If the extension is none of the supported ones.
    """
    ext = os.path.splitext(file_path)[1].lower()
    readers = {
        ".csv": pd.read_csv,
        ".xls": pd.read_excel,
        ".xlsx": pd.read_excel,
    }
    reader = readers.get(ext)
    if reader is None:
        raise ValueError(f"Unsupported file format: {ext}: {file_path}")
    return reader(file_path)
|
||
|
||
|
||
def _strip_zero_width(s: str) -> str:
    """Remove zero-width characters from *s*.

    The characters dropped are U+200B, U+200C, U+200D and U+FEFF.
    Non-string inputs are returned unchanged.
    """
    if isinstance(s, str):
        invisible = "\u200b\u200c\u200d\ufeff"
        return "".join(ch for ch in s if ch not in invisible)
    return s
|
||
|
||
|
||
def _canonicalize_name(name: str) -> str:
    """Return *name* in NFKC form with zero-width characters removed."""
    return _strip_zero_width(unicodedata.normalize("NFKC", name))
|
||
|
||
|
||
def _normalize_for_compare(name: str) -> str:
    """Build a relaxed comparison key for a file name.

    On top of canonicalization, underscores are treated as spaces, runs of
    whitespace collapse to a single space, and the result is lowercased.
    """
    canonical = _canonicalize_name(name)
    words = canonical.replace("_", " ").split()
    return " ".join(words).lower()
|
||
|
||
|
||
def _find_matching_file(folder: str, expected_name: str):
    """Find the directory entry in *folder* that corresponds to *expected_name*.

    Three passes are made over the listing, each one looser than the last,
    and the first hit of the earliest successful pass wins:

    1. canonical names are equal;
    2. canonical names are equal ignoring case;
    3. relaxed comparison keys (underscores/whitespace/case folded) are equal.

    Args:
        folder: Directory to search.
        expected_name: File name (with extension) being looked for.

    Returns:
        The entry name exactly as listed by the OS, or None when nothing
        matches or *folder* does not exist.
    """
    expected = _canonicalize_name(expected_name)
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        return None

    exact = next((e for e in entries if _canonicalize_name(e) == expected), None)
    if exact is not None:
        return exact

    folded_target = expected.lower()
    folded = next(
        (e for e in entries if _canonicalize_name(e).lower() == folded_target),
        None,
    )
    if folded is not None:
        return folded

    relaxed_target = _normalize_for_compare(expected_name)
    return next(
        (e for e in entries if _normalize_for_compare(e) == relaxed_target),
        None,
    )
|
||
|
||
|
||
def _find_matching_file_by_label(folder: str, label_name, file_classes):
    """Locate the data file for a label, trying each extension in order.

    Args:
        folder: Directory to search.
        label_name: Label whose file ``<label_name>.<ext>`` is wanted.
        file_classes: Extensions (no leading dot) to try, in priority order.

    Returns:
        The first matching entry name, or None if no extension matches.
    """
    # Generator keeps the lookup lazy: later extensions are only tried
    # when the earlier ones produced no match.
    candidates = (
        _find_matching_file(folder, f"{label_name}.{ext}") for ext in file_classes
    )
    return next((hit for hit in candidates if hit is not None), None)
|
||
|
||
|
||
def _extract_features(df: pd.DataFrame, source: str) -> pd.DataFrame:
    """Extract the feature columns of a raw table as numeric data.

    Feature columns are the ones at odd 0-based positions, i.e. the 2nd,
    4th, 6th, ... column of *df*. Each is converted with ``pd.to_numeric``;
    values that cannot be parsed become NaN.

    Columns are selected and converted strictly by position, so tables with
    repeated header names are handled correctly (label-based selection would
    pick every column sharing the name).

    Args:
        df: Raw table as read from disk.
        source: Name of the originating file, used only in the error message.

    Returns:
        A new DataFrame with the feature columns, original index and
        original column names.

    Raises:
        ValueError: If *df* has fewer than two columns.
    """
    feature_positions = range(1, df.shape[1], 2)
    if not feature_positions:
        raise ValueError(f"没有找到偶数列(特征列)。请检查文件: {source}")
    converted = [
        pd.to_numeric(df.iloc[:, pos], errors="coerce")
        for pos in feature_positions
    ]
    return pd.concat(converted, axis=1)
|
||
|
||
|
||
# ============================================================
|
||
# 数据加载
|
||
# ============================================================
|
||
|
||
def load_all_data(root: str, folder: str, labels: list[str]):
    """Load the dataset under ``<root>/Static/<folder>`` for the given labels.

    Two layouts are recognized:

    * single-file: one ``<label>.<ext>`` file per label in the data directory;
    * multi-folder: one ``<label>/`` sub-directory per label holding at
      least one supported data file.

    When exactly one layout is satisfied by every label it is used. When
    both or neither are satisfied a warning is printed and the single-file
    loader is tried.

    Args:
        root: Project root directory.
        folder: Name of the dataset directory below ``Static/``.
        labels: Class labels to load.

    Returns:
        Whatever the selected loader returns.

    Note:
        Terminates the process with exit status 1 if the data directory
        does not exist.
    """
    data_dir = os.path.join(root, "Static", folder)
    if not os.path.isdir(data_dir):
        print(f"[ERROR] 目录不存在: {data_dir}")
        sys.exit(1)

    def _label_has_file(lbl) -> bool:
        match = _find_matching_file_by_label(data_dir, lbl, DEFAULT_FILE_CLASSES)
        return match is not None

    def _label_has_subfolder(lbl) -> bool:
        sub = os.path.join(data_dir, str(lbl))
        if not os.path.isdir(sub):
            return False
        return any(_has_supported_extension(f) for f in os.listdir(sub))

    has_all_files = all(map(_label_has_file, labels))
    has_all_subfolders = all(map(_label_has_subfolder, labels))

    if has_all_subfolders and not has_all_files:
        return _load_multi_folder_mode(data_dir, labels)

    # Both flags equal means the layout is ambiguous (both present) or
    # unrecognized (neither present); fall back to single-file mode.
    if has_all_files == has_all_subfolders:
        print("[WARN] 数据模式不明确,尝试单文件模式...")
    return _load_single_file_mode(data_dir, labels)
|
||
|
||
|
||
def _load_single_file_mode(data_dir: str, labels: list[str]):
    """Load a dataset stored as one ``<label>.<ext>`` file per label.

    Labels whose file is missing, unreadable, has no feature columns, or has
    no rows left after dropping NaN rows are reported and skipped.

    Args:
        data_dir: Directory containing the per-label files.
        labels: Class labels to load, in the desired class order.

    Returns:
        The result of ``_build_arrays`` for the labels that loaded.
    """

    def _load_label(lbl):
        """Return the cleaned feature frame for *lbl*, or None if skipped."""
        fname = _find_matching_file_by_label(data_dir, lbl, DEFAULT_FILE_CLASSES)
        if fname is None:
            print(f"[WARN] 标签 {lbl} 找不到文件,跳过")
            return None

        file_path = os.path.join(data_dir, fname)
        try:
            raw = _read_data_file(file_path)
        except Exception as e:
            print(f"[ERROR] 读取 {file_path} 失败: {e}")
            return None

        try:
            features = _extract_features(raw, fname)
        except ValueError as e:
            print(f"[ERROR] {e}")
            return None

        # Keep only rows where every feature parsed as a number.
        clean = features.dropna()
        if clean.shape[0] == 0:
            print(f"[WARN] 标签 {lbl} 去除 NaN 后无样本,跳过")
            return None
        return clean

    all_features = [
        (lbl, frame)
        for lbl in labels
        if (frame := _load_label(lbl)) is not None
    ]
    label_names = [lbl for lbl, _ in all_features]
    col_counts = {lbl: frame.shape[1] for lbl, frame in all_features}
    return _build_arrays(all_features, label_names, col_counts)
|
||
|
||
|
||
def _stack_frames_by_position(frames):
    """Row-stack feature frames column-by-position.

    Frames narrower than the widest one are zero-padded on the right.
    Column names are ignored for alignment, so files of one class whose
    headers differ are stacked rather than spread over a union of columns.
    The result carries the widest frame's column names.
    """
    widest = max(frames, key=lambda frame: frame.shape[1])
    width = widest.shape[1]
    blocks = []
    for frame in frames:
        values = frame.to_numpy()
        missing = width - values.shape[1]
        if missing:
            values = np.hstack([values, np.zeros((values.shape[0], missing))])
        blocks.append(values)
    return pd.DataFrame(np.vstack(blocks), columns=list(widest.columns))


def _load_multi_folder_mode(data_dir: str, labels: list[str]):
    """Load a dataset stored as one sub-directory of data files per label.

    For every label, all supported files in ``<data_dir>/<label>/`` are read
    in sorted name order, reduced to their numeric feature columns, stripped
    of rows containing NaN, and stacked into one frame per class. Files that
    fail to read or have no feature columns are reported and skipped, as are
    labels that end up with no samples.

    Args:
        data_dir: Directory containing the per-label sub-directories.
        labels: Class labels to load, in the desired class order.

    Returns:
        The result of ``_build_arrays`` for the labels that loaded.
    """
    all_features = []
    col_counts = {}
    label_names = []

    for lbl in labels:
        sub = os.path.join(data_dir, str(lbl))
        if not os.path.isdir(sub):
            print(f"[WARN] 标签 {lbl} 子目录不存在,跳过")
            continue

        files = sorted(f for f in os.listdir(sub) if _has_supported_extension(f))
        if not files:
            print(f"[WARN] 标签 {lbl} 子目录无文件,跳过")
            continue

        frames = []
        for fname in files:
            file_path = os.path.join(sub, fname)
            try:
                raw = _read_data_file(file_path)
            except Exception as e:
                print(f"[WARN] 读取 {file_path} 失败: {e}")
                continue
            try:
                feat = _extract_features(raw, f"{lbl}/{fname}")
            except ValueError as e:
                print(f"[WARN] {e}")
                continue
            clean = feat.dropna()
            if clean.shape[0] > 0:
                frames.append(clean)

        if not frames:
            print(f"[WARN] 标签 {lbl} 无有效样本,跳过")
            continue

        # Stack positionally: pd.concat would align on column names and
        # introduce NaN whenever two files use different headers.
        combined = _stack_frames_by_position(frames)
        col_counts[lbl] = combined.shape[1]
        all_features.append((lbl, combined))
        label_names.append(lbl)

    return _build_arrays(all_features, label_names, col_counts)
|
||
|
||
|
||
def _build_arrays(all_features, label_names, col_counts):
    """Assemble per-class feature frames into one sample matrix and label vector.

    Classes narrower than the widest one are zero-padded on the right so
    that all rows have the same number of columns.

    Args:
        all_features: List of ``(label, DataFrame)`` pairs, one per class.
        label_names: Labels in the same order as *all_features*.
        col_counts: Mapping of label to its feature-column count.

    Returns:
        ``(X, y, label_names, all_features, col_counts)`` where ``X`` stacks
        all samples row-wise and ``y[k]`` is the position in *all_features*
        of the class that row ``k`` belongs to.

    Note:
        Terminates the process with exit status 1 if *all_features* is empty.
    """
    if not all_features:
        print("[ERROR] 没有加载到任何有效数据")
        sys.exit(1)

    max_cols = max(col_counts.values())

    def _padded(frame):
        values = frame.values
        shortfall = max_cols - values.shape[1]
        if shortfall > 0:
            filler = np.zeros((values.shape[0], shortfall), dtype=values.dtype)
            values = np.hstack([values, filler])
        return values

    blocks = [_padded(frame) for _, frame in all_features]
    X = np.vstack(blocks)
    y = np.concatenate(
        [np.full(block.shape[0], class_idx, dtype=int)
         for class_idx, block in enumerate(blocks)]
    )
    return X, y, label_names, all_features, col_counts
|
||
|
||
|
||
# ============================================================
|
||
# 可视化函数
|
||
# ============================================================
|
||
|
||
# Color tuple of Matplotlib's "tab10" colormap; the plots below index it
# modulo 10 to assign one color per class.
TAB10 = plt.cm.tab10.colors
|
||
|
||
|
||
def _ensure_dir(path: str):
    """Create directory *path* (including parents); do nothing if it exists."""
    os.makedirs(path, exist_ok=True)
|
||
|
||
|
||
def plot_class_distribution(y, label_names, out_dir: str):
    """Draw a bar chart of the number of samples per class.

    Args:
        y: Integer class index of every sample; index ``i`` refers to
            ``label_names[i]``.
        label_names: Class labels, used as bar names.
        out_dir: Directory in which ``01_class_distribution.png`` is written.
    """
    class_ids = range(len(label_names))
    counts = [int(np.count_nonzero(y == k)) for k in class_ids]
    palette = [TAB10[k % 10] for k in class_ids]

    fig, ax = plt.subplots(figsize=(max(8, len(label_names) * 0.6), 5))
    bars = ax.bar(label_names, counts, color=palette, edgecolor="white", linewidth=0.8)

    # Lift each count label slightly above its bar (1% of the tallest bar).
    lift = max(counts, default=0) * 0.01
    for bar, cnt in zip(bars, counts):
        x_center = bar.get_x() + bar.get_width() / 2
        ax.text(x_center, bar.get_height() + lift, str(cnt),
                ha="center", va="bottom", fontsize=9)

    ax.set_xlabel("类别")
    ax.set_ylabel("样本数")
    ax.set_title(f"类别分布 (总计 {sum(counts)} 样本, {len(label_names)} 类)")
    fig.tight_layout()

    path = os.path.join(out_dir, "01_class_distribution.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
def plot_feature_histograms(all_features, out_dir: str, max_features: int = 12):
    """Plot per-feature density histograms with all classes overlaid.

    One subplot is drawn per feature (four per row). The number of features
    shown is the smaller of *max_features* and the column count of the first
    class. With at most 8 features every subplot gets its own legend;
    otherwise a single figure-level legend is drawn.

    Args:
        all_features: List of ``(label, DataFrame)`` pairs, one per class.
        out_dir: Directory in which ``02_feature_histograms.png`` is written.
        max_features: Upper bound on the number of features plotted.
    """
    n_features = min(max_features, all_features[0][1].shape[1])
    n_cols = 4
    n_rows = -(-n_features // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]

    colors = [TAB10[i % 10] for i in range(len(all_features))]
    legend_per_axis = n_features <= 8

    for j, ax in enumerate(axes[:n_features]):
        for color, (lbl, df) in zip(colors, all_features):
            if j >= df.shape[1]:
                continue  # this class has no such feature column
            ax.hist(df.iloc[:, j].values, bins=40, density=True, alpha=0.4,
                    color=color, label=f"类 {lbl}")
        ax.set_title(f"Feature {j+1}")
        ax.set_xlabel("值")
        ax.set_ylabel("密度")
        if legend_per_axis:
            ax.legend(fontsize=7, loc="upper right")

    # Hide the unused cells of the last grid row.
    for ax in axes[n_features:]:
        ax.set_visible(False)

    if not legend_per_axis:
        handles = [plt.Rectangle((0, 0), 1, 1, color=c, alpha=0.4) for c in colors]
        fig.legend(handles, [lbl for lbl, _ in all_features],
                   loc="lower center", ncol=min(10, len(all_features)), fontsize=7)

    fig.suptitle(f"特征分布直方图(各类叠加,前 {n_features} 维)", fontsize=13, y=1.01)
    fig.tight_layout()
    path = os.path.join(out_dir, "02_feature_histograms.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
def plot_feature_boxplots(all_features, out_dir: str, max_features: int = 20):
    """Plot per-feature box plots comparing the classes side by side.

    One subplot is drawn per feature (four per row). The number of features
    shown is the smaller of *max_features* and the column count of the first
    class. Each class keeps a fixed x position and a fixed color (its index
    in *all_features*), even in subplots where other classes lack the
    feature.

    Args:
        all_features: List of ``(label, DataFrame)`` pairs, one per class.
        out_dir: Directory in which ``03_feature_boxplots.png`` is written.
        max_features: Upper bound on the number of features plotted.
    """
    n_features = min(max_features, all_features[0][1].shape[1])
    n_cols = 4
    n_rows = (n_features + n_cols - 1) // n_cols

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 3.5 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]

    for j in range(n_features):
        ax = axes[j]
        data_list = []
        positions = []
        tick_labels = []
        box_colors = []
        for idx, (lbl, df) in enumerate(all_features):
            if j < df.shape[1]:
                data_list.append(df.iloc[:, j].values)
                positions.append(idx + 1)
                tick_labels.append(str(lbl))
                # Color by class index so a class keeps its color even when
                # another class is missing from this subplot.
                box_colors.append(TAB10[idx % 10])

        bp = ax.boxplot(data_list, positions=positions,
                        patch_artist=True, widths=0.6, showfliers=True,
                        flierprops={"marker": ".", "markersize": 2, "alpha": 0.3})
        # Tick labels are set explicitly instead of through boxplot's
        # `labels=` keyword, which newer Matplotlib releases deprecate.
        ax.set_xticks(positions)
        ax.set_xticklabels(tick_labels)

        for patch, color in zip(bp["boxes"], box_colors):
            patch.set_facecolor(color)
            patch.set_alpha(0.6)
        ax.set_title(f"Feature {j+1}")
        ax.set_xlabel("类别")
        ax.tick_params(axis="x", rotation=0, labelsize=8)

    # Hide the unused cells of the last grid row.
    for j in range(n_features, len(axes)):
        axes[j].set_visible(False)

    fig.suptitle(f"特征箱线图(各类别对比,前 {n_features} 维)", fontsize=13, y=1.01)
    fig.tight_layout()
    path = os.path.join(out_dir, "03_feature_boxplots.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
def _plot_confidence_ellipse(ax, mean, cov, color, alpha=0.2, n_std=1.0):
    """Add a covariance ellipse to *ax*.

    The ellipse is centered on *mean*, oriented along the eigenvectors of
    *cov*, and its full axis lengths are ``2 * n_std * sqrt(eigenvalue)``.

    Args:
        ax: Matplotlib axes to draw on.
        mean: Ellipse center ``(x, y)``.
        cov: 2x2 covariance matrix.
        color: Face and edge color.
        alpha: Opacity of the patch.
        n_std: Half-axis length in standard deviations.
    """
    from matplotlib.patches import Ellipse

    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]
    # Rounding can yield tiny negative eigenvalues for (near-)degenerate
    # covariances; clip them so the square root below stays finite.
    vals = np.clip(vals[order], 0.0, None)
    vecs = vecs[:, order]
    # Orientation of the major axis (eigenvector of the largest eigenvalue).
    angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    width, height = 2 * n_std * np.sqrt(vals)
    ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle,
                      facecolor=color, alpha=alpha, edgecolor=color, linewidth=0.8)
    ax.add_patch(ellipse)
|
||
|
||
|
||
def plot_pca(X, y, label_names, out_dir: str):
    """Run PCA and save a 2-D projection figure plus a variance-ratio chart.

    ``04_pca.png`` has two panels: a scatter of all samples on the first two
    principal components, and the per-class centroids with covariance
    ellipses. ``04_pca_variance.png`` shows the individual and cumulative
    explained-variance ratio of the leading components.

    PCA is fitted with up to 30 components (bounded by the data shape) so
    the variance chart covers more than the two components used for the
    projection. The plot is skipped with a warning when the data has fewer
    than 2 samples or 2 features.

    Args:
        X: Sample matrix, one row per sample.
        y: Integer class index of every sample.
        label_names: Class labels; ``label_names[i]`` names class ``i``.
        out_dir: Output directory for the two PNG files.
    """
    from sklearn.decomposition import PCA

    n_samples, n_dims = X.shape
    if min(n_samples, n_dims) < 2:
        print(f"[WARN] PCA 至少需要 2 个样本和 2 个特征 (当前 {n_samples} x {n_dims}),跳过")
        return

    pca = PCA(n_components=min(30, n_samples, n_dims), random_state=42)
    X_pca = pca.fit_transform(X)[:, :2]
    ratios = pca.explained_variance_ratio_
    pc1_label = f"PC1 ({ratios[0]*100:.1f}%)"
    pc2_label = f"PC2 ({ratios[1]*100:.1f}%)"

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    colors = [TAB10[i % 10] for i in range(len(label_names))]

    # Left panel: scatter of every sample.
    ax = axes[0]
    for i, lbl in enumerate(label_names):
        mask = y == i
        ax.scatter(X_pca[mask, 0], X_pca[mask, 1], c=[colors[i]], label=f"类 {lbl}",
                   alpha=0.5, s=3, edgecolors="none")
    ax.set_xlabel(pc1_label)
    ax.set_ylabel(pc2_label)
    ax.set_title("PCA 降维散点图")
    ax.legend(fontsize=7, markerscale=3, loc="best")

    # Right panel: class centroids and covariance ellipses.
    ax2 = axes[1]
    for i, lbl in enumerate(label_names):
        mask = y == i
        class_points = X_pca[mask]
        mean = class_points.mean(axis=0)
        ax2.scatter(mean[0], mean[1], c=[colors[i]], s=80, marker="X",
                    edgecolors="black", linewidths=0.8, zorder=5)
        ax2.annotate(str(lbl), (mean[0], mean[1]), fontsize=8, ha="center", va="bottom",
                     fontweight="bold", xytext=(0, 4), textcoords="offset points")
        # A covariance estimate needs more than two points to be meaningful.
        if class_points.shape[0] > 2:
            cov = np.cov(class_points.T)
            _plot_confidence_ellipse(ax2, mean, cov, color=colors[i], alpha=0.25)

    ax2.set_xlabel(pc1_label)
    ax2.set_ylabel(pc2_label)
    ax2.set_title("PCA 质心 + 1σ 置信椭圆")

    fig.suptitle("PCA 降维分析", fontsize=14)
    fig.tight_layout()
    path = os.path.join(out_dir, "04_pca.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")

    # Explained-variance chart over all fitted components.
    fig2, ax = plt.subplots(figsize=(8, 4))
    n = len(ratios)
    component_ids = range(1, n + 1)
    ax.bar(component_ids, ratios, alpha=0.6, color="steelblue", label="个体")
    ax.plot(component_ids, np.cumsum(ratios), "ro-", markersize=4, label="累计")
    ax.set_xlabel("主成分")
    ax.set_ylabel("方差解释率")
    ax.set_title("PCA 方差解释率")
    ax.legend()
    fig2.tight_layout()
    path2 = os.path.join(out_dir, "04_pca_variance.png")
    fig2.savefig(path2, dpi=150)
    plt.close(fig2)
    print(f"[OK] {path2}")
|
||
|
||
|
||
def plot_tsne(X, y, label_names, out_dir: str, max_samples: int = 5000):
    """Compute a 2-D t-SNE embedding and save it as ``05_tsne.png``.

    When there are more than *max_samples* samples, each class is randomly
    down-sampled (without replacement, fixed seed) to at most
    ``max_samples // n_classes`` samples before the embedding is computed.
    The plot is skipped with a warning when fewer than 2 samples remain.

    Args:
        X: Sample matrix, one row per sample.
        y: Integer class index of every sample.
        label_names: Class labels; ``label_names[i]`` names class ``i``.
        out_dir: Output directory for the PNG file.
        max_samples: Sample-count threshold above which down-sampling is applied.
    """
    from sklearn.manifold import TSNE

    if X.shape[0] > max_samples:
        print(f"[INFO] t-SNE: 样本过多 ({X.shape[0]}), 分层抽样至 {max_samples}")
        per_class = max_samples // len(label_names)
        picked = []
        for i in range(len(label_names)):
            idx_i = np.where(y == i)[0]
            if len(idx_i) > per_class:
                # A fresh, identically seeded generator per class keeps each
                # class's draw independent of the other classes.
                rng = np.random.RandomState(42)
                idx_i = rng.choice(idx_i, per_class, replace=False)
            picked.append(idx_i)
        # Explicit integer dtype so an empty selection is still a valid index.
        indices = np.concatenate(picked).astype(int)
        X_sub = X[indices]
        y_sub = y[indices]
    else:
        X_sub = X
        y_sub = y

    n_sub = X_sub.shape[0]
    if n_sub < 2:
        print(f"[WARN] t-SNE 至少需要 2 个样本 (当前 {n_sub}),跳过")
        return

    print("[INFO] 正在计算 t-SNE(可能需要一些时间)...")
    # scikit-learn requires perplexity to be smaller than the sample count.
    perplexity = min(50, max(5, n_sub // 3), n_sub - 1)
    # The iteration count is left at the library default (1000): the keyword
    # was renamed from `n_iter` to `max_iter` across scikit-learn releases,
    # so omitting it keeps the call valid on both old and new versions.
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, verbose=0)
    X_tsne = tsne.fit_transform(X_sub)

    colors = [TAB10[i % 10] for i in range(len(label_names))]

    fig, ax = plt.subplots(figsize=(10, 8))
    for i, lbl in enumerate(label_names):
        mask = y_sub == i
        ax.scatter(X_tsne[mask, 0], X_tsne[mask, 1], c=[colors[i]], label=f"类 {lbl}",
                   alpha=0.5, s=3, edgecolors="none")
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    ax.set_title(f"t-SNE 降维散点图 (n={X_sub.shape[0]}, perplexity={perplexity})")
    ax.legend(fontsize=7, markerscale=3, loc="best")
    fig.tight_layout()
    path = os.path.join(out_dir, "05_tsne.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
def plot_class_mean_std_heatmap(all_features, label_names, out_dir: str, max_features: int = 30):
    """Save side-by-side heatmaps of per-class feature means and standard deviations.

    Rows are classes and columns are features. The number of features shown
    is the smaller of *max_features* and the column count of the first
    class. Cells for features a class does not have stay at 0.0. Cell values
    are annotated only when at most 20 features are shown.

    Args:
        all_features: List of ``(label, DataFrame)`` pairs, one per class.
        label_names: Class labels used as row tick labels.
        out_dir: Directory in which ``06_class_mean_std_heatmap.png`` is written.
        max_features: Upper bound on the number of features shown.
    """
    n_features = min(max_features, all_features[0][1].shape[1])
    n_classes = len(label_names)

    # Zero-initialized, so features missing from a class need no extra work.
    mean_matrix = np.zeros((n_classes, n_features))
    std_matrix = np.zeros((n_classes, n_features))
    for row, (_, frame) in enumerate(all_features):
        for col_idx in range(min(n_features, frame.shape[1])):
            column = frame.iloc[:, col_idx].values
            mean_matrix[row, col_idx] = np.mean(column)
            std_matrix[row, col_idx] = np.std(column)

    fig, axes = plt.subplots(1, 2, figsize=(max(10, n_features * 0.35), max(5, n_classes * 0.5)))

    annotate = n_features <= 20
    shared_kwargs = {
        "xticklabels": [f"F{i+1}" for i in range(n_features)] if n_features <= 30 else False,
        "yticklabels": label_names,
        "annot": annotate,
        "fmt": ".3f" if annotate else "",
        "linewidths": 0.5,
    }

    sns.heatmap(mean_matrix, ax=axes[0], cmap="RdBu_r", center=0,
                cbar_kws={"label": "均值", "shrink": 0.8}, **shared_kwargs)
    axes[0].set_title("各类别特征均值")
    axes[0].set_xlabel("特征维度")

    sns.heatmap(std_matrix, ax=axes[1], cmap="YlOrRd",
                cbar_kws={"label": "标准差", "shrink": 0.8}, **shared_kwargs)
    axes[1].set_title("各类别特征标准差")
    axes[1].set_xlabel("特征维度")

    fig.suptitle("各类别特征统计对比", fontsize=13)
    fig.tight_layout()
    path = os.path.join(out_dir, "06_class_mean_std_heatmap.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
def plot_correlation_heatmap(X, out_dir: str, max_features: int = 30):
    """Save a heatmap of the Pearson correlation between the leading features.

    Correlations are computed over all samples for the first
    ``min(max_features, X.shape[1])`` columns of *X*. Features with zero
    variance (for example zero-padded columns) yield NaN correlations.

    Args:
        X: Sample matrix, one row per sample.
        out_dir: Directory in which ``07_correlation_heatmap.png`` is written.
        max_features: Upper bound on the number of features included.
    """
    n_features = min(max_features, X.shape[1])
    X_sub = X[:, :n_features]

    # Constant columns divide by a zero standard deviation; silence the
    # floating-point warnings and let the NaN results through.
    with np.errstate(divide="ignore", invalid="ignore"):
        # corrcoef collapses to a scalar for a single feature; keep it 2-D
        # so the heatmap call below still receives a matrix.
        corr = np.atleast_2d(np.corrcoef(X_sub.T))

    tick_labels = [f"F{i+1}" for i in range(n_features)] if n_features <= 30 else False

    fig, ax = plt.subplots(figsize=(max(10, n_features * 0.5), max(8, n_features * 0.45)))
    sns.heatmap(corr, ax=ax, cmap="RdBu_r", center=0, vmin=-1, vmax=1,
                xticklabels=tick_labels,
                yticklabels=tick_labels,
                linewidths=0.1, cbar_kws={"label": "Pearson r", "shrink": 0.8})
    ax.set_title(f"特征相关性矩阵 (前 {n_features} 维)")
    fig.tight_layout()
    path = os.path.join(out_dir, "07_correlation_heatmap.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
def plot_global_distribution(X, out_dir: str):
    """Save an overview of the value distribution across the whole dataset.

    The left panel is a histogram of every value in *X* with the global mean
    and median marked. The right panel shows the mean ± standard deviation
    of each of the first (at most 50) feature columns.

    Args:
        X: Sample matrix, one row per sample.
        out_dir: Directory in which ``08_global_distribution.png`` is written.
    """
    fig, (hist_ax, stat_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: histogram of all values pooled together.
    all_vals = X.flatten()
    global_mean = np.mean(all_vals)
    global_median = np.median(all_vals)
    hist_ax.hist(all_vals, bins=100, color="steelblue", alpha=0.7, edgecolor="white",
                 linewidth=0.3)
    hist_ax.axvline(global_mean, color="red", linestyle="--",
                    label=f"均值={global_mean:.4f}")
    hist_ax.axvline(global_median, color="orange", linestyle="--",
                    label=f"中位数={global_median:.4f}")
    hist_ax.set_xlabel("特征值")
    hist_ax.set_ylabel("频数")
    hist_ax.set_title("全局特征值分布")
    hist_ax.legend()

    # Right panel: per-feature mean with standard-deviation error bars.
    n_features = min(X.shape[1], 50)
    means = np.mean(X, axis=0)[:n_features]
    stds = np.std(X, axis=0)[:n_features]
    stat_ax.errorbar(range(1, n_features + 1), means, yerr=stds,
                     fmt="o", markersize=3, capsize=2, color="steelblue", alpha=0.7)
    stat_ax.set_xlabel("特征维度")
    stat_ax.set_ylabel("均值 ± 标准差")
    stat_ax.set_title(f"各维度均值与标准差 (前 {n_features} 维)")

    fig.suptitle("全局特征概览", fontsize=13)
    fig.tight_layout()
    path = os.path.join(out_dir, "08_global_distribution.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
|
||
|
||
|
||
# ============================================================
|
||
# 命令行入口
|
||
# ============================================================
|
||
|
||
def main():
    """Command-line entry point: load a dataset and write all figures.

    Parses the arguments, loads the data below ``<root>/Static/<folder>``,
    prints a short summary, and saves the figures to
    ``<root>/Visualizations/<folder>/``. PCA and t-SNE can be skipped with
    ``--no-pca`` / ``--no-tsne``.
    """
    parser = argparse.ArgumentParser(
        description="Deeplearning 数据可视化脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
  python Scripts/visualize.py -f "20260408 grap" -l 1 2 3 4 5 6 7 8 9
  python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9 --no-tsne --max-features 15
        """,
    )
    parser.add_argument("-f", "--folder", required=True,
                        help="Static/ 下的数据目录名")
    parser.add_argument("-l", "--labels", nargs="+", required=True,
                        help="类别标签列表,空格分隔")
    parser.add_argument("-r", "--root", default=None,
                        help="项目根目录(默认为 Deeplearning/)")
    parser.add_argument("--max-features", type=int, default=20,
                        help="可视化中显示的最大特征维度数 (默认 20)")
    parser.add_argument("--no-tsne", action="store_true",
                        help="跳过 t-SNE 计算")
    parser.add_argument("--no-pca", action="store_true",
                        help="跳过 PCA 计算")
    parser.add_argument("--tsne-max-samples", type=int, default=5000,
                        help="t-SNE 最大抽样数 (默认 5000)")

    args = parser.parse_args()

    # A non-positive feature count would request an empty subplot grid.
    if args.max_features < 1:
        parser.error("--max-features 必须 >= 1")

    if args.root:
        root = args.root
    else:
        # Default root: the parent of the directory containing this script.
        root = str(Path(__file__).resolve().parent.parent)
    root = os.path.abspath(root)

    print(f"加载数据: {root}/Static/{args.folder}")
    X, y, label_names, all_features, col_counts = load_all_data(root, args.folder, args.labels)
    print(f" 样本数: {X.shape[0]}, 特征维度: {X.shape[1]}, 类别数: {len(label_names)}")
    # enumerate() gives each class its own index even if a label is repeated.
    for class_idx, lbl in enumerate(label_names):
        cnt = int(np.sum(y == class_idx))
        print(f" 类 {lbl}: {cnt} 样本, {col_counts.get(lbl, X.shape[1])} 列")

    out_dir = os.path.join(root, "Visualizations", args.folder)
    _ensure_dir(out_dir)
    print(f"\n输出目录: {out_dir}\n")

    # Figure text is Chinese, so prefer fonts with CJK coverage; Matplotlib
    # uses the first installed family and DejaVu Sans remains the fallback.
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = [
        "Noto Sans CJK SC",
        "Source Han Sans SC",
        "Microsoft YaHei",
        "SimHei",
        "PingFang SC",
        "WenQuanYi Micro Hei",
        "Arial Unicode MS",
        "DejaVu Sans",
    ]
    # Several CJK fonts lack the Unicode minus sign; use ASCII hyphen-minus.
    plt.rcParams["axes.unicode_minus"] = False

    shown_features = min(args.max_features, X.shape[1])

    print("生成可视化图表...\n")
    plot_class_distribution(y, label_names, out_dir)
    plot_feature_histograms(all_features, out_dir, max_features=shown_features)
    plot_feature_boxplots(all_features, out_dir, max_features=shown_features)
    if not args.no_pca:
        plot_pca(X, y, label_names, out_dir)
    if not args.no_tsne:
        plot_tsne(X, y, label_names, out_dir, max_samples=args.tsne_max_samples)
    plot_class_mean_std_heatmap(all_features, label_names, out_dir,
                                max_features=shown_features)
    plot_correlation_heatmap(X, out_dir, max_features=min(30, X.shape[1]))
    plot_global_distribution(X, out_dir)

    print(f"\n全部图表已保存到: {out_dir}")
|
||
|
||
|
||
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()