Compare commits

...

No commits in common. "a7e95141d21ad35ac198667fe38ba614da3fe285" and "b2a3bc393e3288a743f91cf7ef13fb94fa66e25f" have entirely different histories.

46 changed files with 1891 additions and 3409 deletions

5
.gitignore vendored
View File

@ -1,4 +1 @@
Static
Result
.vscode
__pycache__
lazy-lock.json

View File

@ -1,5 +0,0 @@
from .divSet import divSet
from .loadData import load_data
from .saveToXlsx import save_to_xlsx
__all__ = ["divSet", "load_data", "save_to_xlsx"]

View File

@ -1,45 +0,0 @@
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
def divSet(data, labels=None, test_size=0.2, random_state=None):
    """Split a labelled table into scaled train/test features and encoded labels.

    The last column of ``data`` is the label column; every other column is a
    feature.

    Args:
        data: pandas DataFrame whose last column holds the class label.
        labels: Optional collection of every class name. When given (and
            non-empty) the encoder is fitted on it, so classes missing from
            ``data`` still get a code; otherwise the encoder is fitted on the
            label column itself.
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split``.

    Returns:
        ``(X_train, X_test, y_train, y_test, encoder)`` where the feature
        arrays are standardised with statistics of the training split only and
        the label arrays are integer codes produced by ``encoder``.

    Note:
        ``LabelEncoder`` assigns codes by sorted class value, not by the order
        in which ``labels`` is written.
    """
    encoder = LabelEncoder()

    # The last column is the label; everything before it is a feature.
    X = data.iloc[:, :-1]
    y = data.iloc[:, -1]

    # A bare `if labels:` raises for NumPy arrays / pandas objects (ambiguous
    # truth value), so test for presence and emptiness explicitly.
    if labels is not None and len(labels) > 0:
        encoder.fit(labels)
    else:
        encoder.fit(y)

    # Prefer a stratified split so every class shows up in both splits.
    stratify_target = y if y.nunique() > 1 else None
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state, stratify=stratify_target
        )
    except ValueError:
        # Stratification fails e.g. when a class has too few samples;
        # fall back to a plain random split.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state
        )

    # Standardise features; the scaler only ever sees the training split.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Encode labels as integer class ids.
    y_train = encoder.transform(y_train.values)
    y_test = encoder.transform(y_test.values)

    return X_train, X_test, y_train, y_test, encoder
__all__ = ["divSet"]

View File

@ -1,260 +0,0 @@
import os
import unicodedata
import pandas as pd
# Directory (relative to the working directory) that holds the dataset folders.
STATIC_PATH = './Static'
# Data file extensions, without the leading dot, that the loader accepts.
DEFAULT_FILE_CLASSES = ('xlsx', 'xls', 'csv')
def load_data(folder, labelNames):
    """Load every labelled data file of one dataset folder under ``STATIC_PATH``.

    The folder layout (one file per label, or one sub-folder per label) is
    detected automatically; xls, xlsx and csv files are supported.

    Args:
        folder: Name of the dataset folder inside ``STATIC_PATH``.
        labelNames: Label names; each one must match a file or a sub-folder.

    Returns:
        The combined table produced by ``load_from_file`` or
        ``load_from_folder``.

    Raises:
        ValueError: If an argument is ``None``, the folder does not exist, or
            the layout cannot be detected.
    """
    if folder is None:
        raise ValueError("The 'folder' parameter is required.")
    if labelNames is None:
        raise ValueError("The 'labelNames' parameter is required if 'folder' does not contain labels.")

    dataset_dir = os.path.join(STATIC_PATH, folder)
    if not os.path.isdir(dataset_dir):
        raise ValueError(f"The folder '{dataset_dir}' does not exist.")

    extensions = DEFAULT_FILE_CLASSES
    multi_folder = _detect_data_mode(folder=dataset_dir, labelNames=labelNames, fileClasses=extensions)
    print(f"Auto detected data mode: {'multi-folder mode' if multi_folder else 'single-file mode'}")

    # Pick the loader that matches the detected layout.
    loader = load_from_folder if multi_folder else load_from_file
    data = loader(folder=dataset_dir, labelNames=labelNames, fileClasses=extensions)
    print(data)
    return data
def load_from_folder(folder, labelNames, fileClasses):
    """Load a multi-folder dataset: ``folder/<label>/*.<ext>``.

    Args:
        folder: Dataset directory.
        labelNames: Label names; each is looked up as a sub-folder of ``folder``.
        fileClasses: Accepted file extensions, without the leading dot.

    Returns:
        One DataFrame holding the rows of every file of every label, in label
        order, with a fresh index.

    Raises:
        ValueError: If no supported data file is found under any label.
    """
    # First pass: collect the files of every label that has any.
    label_sources = []
    for labelName in labelNames:
        subfolder = os.path.join(folder, str(labelName))
        if not os.path.isdir(subfolder):
            continue
        fileNames = [f for f in os.listdir(subfolder) if _has_supported_extension(f, fileClasses)]
        if fileNames:
            label_sources.append((labelName, subfolder, fileNames))

    if not label_sources:
        raise ValueError(f"No supported data files found under '{folder}' for labels {list(labelNames)}.")

    # Pad every sample to ONE common length. A per-label length would give the
    # labels different feature columns, and concatenating those frames
    # misaligns them (the 'label' column would no longer be the last one).
    max_row_length = max(
        get_max_row_len(subfolder, fileNames) for _, subfolder, fileNames in label_sources
    )

    # Second pass: load and concatenate, keeping label order then file order.
    frames = [
        load_xlsx(os.path.join(subfolder, fileName), labelName, max_row_length, 'zero')
        for labelName, subfolder, fileNames in label_sources
        for fileName in fileNames
    ]
    return pd.concat(frames, ignore_index=True)
def load_from_file(folder, labelNames, fileClasses):
    """Load a single-file dataset: ``folder/<label>.<ext>``, one file per label.

    File names are matched tolerantly (see ``_find_matching_file_by_label``).

    Raises:
        FileNotFoundError: If any label has no matching file; the message
            lists the missing names and the directory content.
    """
    # Resolve every label to the real file name found in the directory.
    resolved = [_find_matching_file_by_label(folder, name, fileClasses) for name in labelNames]

    not_found = [
        f"{name}.<{'/'.join(fileClasses)}>"
        for name, hit in zip(labelNames, resolved)
        if hit is None
    ]
    if not_found:
        listing = sorted(os.listdir(folder))
        raise FileNotFoundError(
            "The following files were not found (after normalization): "
            + ", ".join(not_found)
            + f". Available files: {listing}"
        )

    # Common padding length, taken over the files that were actually matched.
    max_row_length = get_max_row_len(folder, resolved)

    frames = [
        load_xlsx(os.path.join(folder, matched), label, max_row_length, 'zero')
        for label, matched in zip(labelNames, resolved)
    ]
    return pd.concat(frames, ignore_index=True)
def load_xlsx(fileName, labelName, max_row_length=1000, fill_rule=None):
    """Read one data file and return one labelled feature row per sample column.

    Only the columns at positions 1, 3, 5, ... are used. Each of them becomes
    one row of the result, padded or truncated to ``max_row_length`` values.

    Args:
        fileName: Path of the xls/xlsx/csv file.
        labelName: Label stored in the ``label`` column of every row.
        max_row_length: Number of feature values per row.
        fill_rule: Padding rule passed on to ``fill_to_len``.

    Returns:
        DataFrame with columns ``feature1..featureN`` followed by ``label``.
    """
    df = _read_data_file(fileName)

    # Take every second column starting at position 1 and drop incomplete
    # rows. Chaining the non-inplace calls yields an independent object, so no
    # slice of `df` is mutated in place (avoids SettingWithCopyWarning).
    features = df.iloc[:, 1::2].dropna().reset_index(drop=True)

    # After transposing, each original column is one sample row.
    features = features.T

    # Bring every sample to the same length.
    features = features.apply(lambda row: fill_to_len(row, max_row_length, fill_rule), axis=1)

    # Name the columns from the width actually produced.
    actual_columns = features.shape[1]
    features['label'] = labelName
    features.columns = [f'feature{i+1}' for i in range(actual_columns)] + ['label']
    return features
def fill_to_len(row, length=1000, rule=None):
    """Return ``row`` truncated or padded to exactly ``length`` values.

    Args:
        row: pandas Series of values.
        length: Target number of values.
        rule: ``'min'`` pads with the row minimum, ``'mean'`` with the row
            mean; ``'zero'``, ``None`` or anything else pads with 0.

    Returns:
        A Series of ``length`` values with a fresh 0..length-1 index.
    """
    current = len(row)
    if current >= length:
        return row.iloc[:length].reset_index(drop=True)

    if rule == 'min':
        pad_value = row.min()
    elif rule == 'mean':
        pad_value = row.mean()
    else:
        # 'zero' and every unrecognised rule pad with zero.
        pad_value = 0

    padding = pd.Series([pad_value] * (length - current))
    return pd.concat([row, padding], ignore_index=True)
def get_max_row_len(folder, filenames):
    """Return the largest row count among the given data files (0 if none).

    Args:
        folder: Directory that contains the files.
        filenames: File names relative to ``folder``.
    """
    row_counts = (
        _read_data_file(os.path.join(folder, name)).shape[0]
        for name in filenames
    )
    return max(row_counts, default=0)
def _read_data_file(file_path: str):
ext = os.path.splitext(file_path)[1].lower()
if ext == '.csv':
return pd.read_csv(file_path)
if ext in ('.xls', '.xlsx'):
return pd.read_excel(file_path)
raise ValueError(
f"Unsupported file format: {ext}. Only .xls, .xlsx, and .csv are supported. "
f"File: {file_path}"
)
def _has_supported_extension(filename: str, fileClasses) -> bool:
ext = os.path.splitext(filename)[1].lower().lstrip('.')
return ext in fileClasses
# ---------- 内部工具函数:处理包含零宽字符或不同 Unicode 形式的文件名匹配 ----------
def _strip_zero_width(s: str) -> str:
# 移除常见零宽字符U+200B, U+200C, U+200D, U+FEFF
if not isinstance(s, str):
return s
return s.translate({
0x200B: None,
0x200C: None,
0x200D: None,
0xFEFF: None,
})
def _canonicalize_name(name: str) -> str:
    """Return ``name`` in NFKC form with zero-width characters removed."""
    return _strip_zero_width(unicodedata.normalize('NFKC', name))
def _normalize_for_compare(name: str) -> str:
    """Return a relaxed comparison key for ``name``.

    On top of ``_canonicalize_name`` this treats underscores as spaces,
    collapses runs of whitespace and lower-cases the result.
    """
    canonical = _canonicalize_name(name)
    words = canonical.replace('_', ' ').split()
    return ' '.join(words).lower()
def _find_matching_file(folder: str, expected_name: str):
    """Find the directory entry of ``folder`` that corresponds to ``expected_name``.

    Three passes are tried in order of strictness; the first hit wins:
    canonical equality, case-insensitive canonical equality, then the relaxed
    key of ``_normalize_for_compare``.

    Returns:
        The entry name as stored on disk, or ``None`` when nothing matches or
        ``folder`` does not exist.
    """
    target = _canonicalize_name(expected_name)
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        return None

    # (wanted key, key function) pairs, strictest first.
    passes = (
        (target, _canonicalize_name),
        (target.lower(), lambda entry: _canonicalize_name(entry).lower()),
        (_normalize_for_compare(expected_name), _normalize_for_compare),
    )
    for wanted, key in passes:
        for entry in entries:
            if key(entry) == wanted:
                return entry
    return None
def _find_matching_file_by_label(folder: str, label_name, fileClasses):
for ext in fileClasses:
expected_name = f"{label_name}.{ext}"
match = _find_matching_file(folder, expected_name)
if match is not None:
return match
return None
def _detect_data_mode(folder: str, labelNames, fileClasses) -> bool:
"""Auto detect data organization mode.
Returns:
True: multi-folder mode (folder/label/*.ext)
False: single-file mode (folder/label.ext)
"""
# 判断是否满足多文件夹模式:每个 label 对应一个子目录,且至少有一个目标后缀文件
has_all_label_subfolders = True
for label in labelNames:
subfolder = os.path.join(folder, str(label))
if not (os.path.isdir(subfolder) and any(_has_supported_extension(f, fileClasses) for f in os.listdir(subfolder))):
has_all_label_subfolders = False
break
# 判断是否满足单文件模式:每个 label 能匹配到对应文件
has_all_label_files = True
for label in labelNames:
if _find_matching_file_by_label(folder, label, fileClasses) is None:
has_all_label_files = False
break
if has_all_label_subfolders and not has_all_label_files:
return True
if has_all_label_files and not has_all_label_subfolders:
return False
if has_all_label_subfolders and has_all_label_files:
raise ValueError(
"Auto detect found both valid layouts under the same folder. "
"Please keep only one layout type (either subfolders or root files) for each label."
)
raise ValueError(
"Auto detect failed: neither single-file nor multi-folder layout matches all labels. "
"Please verify folder structure and labelNames."
)
__all__ = ['load_data']

View File

@ -1,165 +0,0 @@
import os
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
def save_to_xlsx(project_name, file_name, data):
    """Write ``data`` to ``Result/<project_name>/<file_name>.xlsx`` and render its picture.

    Args:
        project_name: Sub-folder of ``Result`` to write into (created if missing).
        file_name: File stem; it also selects the chart type in ``save_to_pic``.
        data: Object with a ``to_excel`` method (a DataFrame).
    """
    folder_path = f'Result/{project_name}'
    os.makedirs(folder_path, exist_ok=True)

    xlsx_path = f'{folder_path}/{file_name}.xlsx'
    data.to_excel(xlsx_path, index=True)
    print('Save successed to ' + xlsx_path)

    # The chart is drawn from the workbook that was just written.
    save_to_pic(project_name=project_name, file_name=file_name)
def save_to_pic(project_name, file_name):
    """Render the chart(s) that belong to ``Result/<project_name>/<file_name>.xlsx``.

    ``file_name`` selects the chart: ``pca_2d``, ``pca_3d``, ``acc_and_loss``,
    ``cm`` or ``cmn``. Any other name only prints a notice.
    """
    result_dir = f'Result/{project_name}'
    os.makedirs(result_dir, exist_ok=True)
    xlsx_path = f'{result_dir}/{file_name}.xlsx'

    if file_name in ('pca_2d', 'pca_3d'):
        draw = draw_pca_2d if file_name == 'pca_2d' else draw_pca_3d
        draw(xlsx_path)
        print('Save successed to ' + f'{result_dir}/{file_name}.png')
    elif file_name == 'acc_and_loss':
        # Training history produces a line chart and a bar chart.
        draw_epoch_data(xlsx_path)
        draw_last_epoch_bar_chart(xlsx_path)
        print('Save successed to line graph and bar graph')
    elif file_name in ('cm', 'cmn'):
        # Raw and normalised confusion matrices share one drawer.
        draw_and_save_cm(xlsx_path)
        print(f'Save successed {file_name}')
    else:
        print('unknow picture type')
def draw_pca_2d(file_path):
    """Draw the 2D PCA scatter stored in ``file_path`` and save it next to it as PNG.

    The workbook must provide the columns ``PC1``, ``PC2`` and ``labels``;
    points are coloured by ``labels``.
    """
    frame = pd.read_excel(file_path)
    png_path = file_path.replace('.xlsx', '.png')

    plt.figure(figsize=(8, 6))
    plt.scatter(
        frame['PC1'],
        frame['PC2'],
        c=frame['labels'],
        cmap='viridis',
        edgecolor='k',
        alpha=0.6,
    )
    plt.title('2D PCA')
    plt.xlabel('PC1')
    plt.ylabel('PC2')
    plt.colorbar(label='Labels')
    plt.savefig(png_path)
    plt.close()
def draw_pca_3d(file_path):
    """Draw the 3D PCA scatter stored in ``file_path`` and save it next to it as PNG.

    The workbook must provide the columns ``PC1``, ``PC2``, ``PC3`` and
    ``labels``; points are coloured by ``labels``.
    """
    df = pd.read_excel(file_path)
    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(df['PC1'], df['PC2'], df['PC3'], c=df['labels'], cmap='viridis', edgecolor='k', alpha=0.6)
    ax.set_xlabel('PC1')
    ax.set_ylabel('PC2')
    ax.set_zlabel('PC3')
    ax.set_title('3D PCA')
    fig.colorbar(scatter, ax=ax, label='Labels')
    plt.savefig(file_path.replace('.xlsx', '.png'))
    # Release the figure like the other drawers do; without this every call
    # leaves one more open figure in pyplot's registry.
    plt.close(fig)
def draw_epoch_data(file_path):
    """Plot the per-epoch training history of ``file_path`` as a 2x3 grid of line charts.

    The workbook must provide the columns ``epoch``, ``train_loss``,
    ``train_accuracy``, ``test_accuracy``, ``f1_score``, ``precision`` and
    ``recall``. The image is saved as ``<stem>_epoch.png``.
    """
    df = pd.read_excel(file_path)
    epochs = df['epoch']

    # One entry per panel:
    # (grid position, y-axis label, title, [(values, line format, legend label), ...])
    # Accuracies are stored as fractions and plotted as percentages.
    panels = [
        ((0, 0), 'Loss', 'Training Loss over Epochs',
         [(df['train_loss'], 'b-', 'Train Loss')]),
        ((0, 1), 'Accuracy (%)', 'Train and Test Accuracy over Epochs',
         [(df['train_accuracy'] * 100, 'g-', 'Train Accuracy'),
          (df['test_accuracy'] * 100, 'r-', 'Test Accuracy')]),
        ((0, 2), 'F1 Score', 'F1 Score over Epochs',
         [(df['f1_score'], 'm-', 'F1 Score')]),
        ((1, 0), 'Precision', 'Precision over Epochs',
         [(df['precision'], 'c-', 'Precision')]),
        ((1, 1), 'Recall', 'Recall over Epochs',
         [(df['recall'], 'y-', 'Recall')]),
    ]

    fig, axs = plt.subplots(2, 3, figsize=(18, 12))
    for position, y_label, title, curves in panels:
        panel = axs[position]
        for values, line_format, legend_label in curves:
            panel.plot(epochs, values, line_format, label=legend_label)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(title)
        panel.legend()

    # Five charts in a six-cell grid: blank out the spare cell.
    axs[1, 2].axis('off')

    plt.tight_layout()
    plt.savefig(file_path.replace('.xlsx', '_epoch.png'))
    plt.close()
def draw_last_epoch_bar_chart(file_path):
    """Draw a bar chart of the metrics in the last row of ``file_path``.

    The two accuracy values are shown as percentages. The image is saved as
    ``<stem>_last_epoch_bar.png``.
    """
    final_row = pd.read_excel(file_path).iloc[-1]

    # (column, bar label, bar colour, show as percentage)
    bars = [
        ('train_loss', 'Train Loss', 'blue', False),
        ('train_accuracy', 'Train Accuracy', 'green', True),
        ('test_accuracy', 'Test Accuracy', 'red', True),
        ('f1_score', 'F1 Score', 'magenta', False),
        ('precision', 'Precision', 'cyan', False),
        ('recall', 'Recall', 'yellow', False),
    ]
    labels = [label for _, label, _, _ in bars]
    colours = [colour for _, _, colour, _ in bars]
    values = [
        final_row[column] * 100 if as_percent else final_row[column]
        for column, _, _, as_percent in bars
    ]

    plt.figure(figsize=(10, 6))
    plt.bar(labels, values, color=colours)
    plt.xlabel('Metrics')
    plt.ylabel('Values')
    plt.title('Last Epoch Metrics')
    plt.ylim(bottom=0)

    # Print each value just above its bar.
    for position, value in enumerate(values):
        plt.text(position, value + 0.01, f'{value:.2f}', ha='center')

    plt.tight_layout()
    plt.savefig(file_path.replace('.xlsx', '_last_epoch_bar.png'))
    plt.close()
def draw_and_save_cm(file_path):
    """Draw the confusion matrix stored in ``file_path`` as a heat map PNG.

    The first workbook column is skipped; the remaining column headers are
    used as class names on both axes.
    """
    table = pd.read_excel(file_path)
    class_names = table.columns[1:].tolist()
    matrix = table.iloc[:, 1:].to_numpy(dtype=float)
    size = len(class_names)
    ticks = np.arange(size)

    # NOTE(review): a 1x2 grid is created but only the left panel is drawn,
    # so the right half of the image stays empty - confirm whether intended.
    fig, axs = plt.subplots(1, 2, figsize=(12, 6))
    heatmap_ax = axs[0]
    heatmap_ax.imshow(matrix, interpolation='nearest', cmap='Blues')
    heatmap_ax.set_title('Confusion Matrix')
    heatmap_ax.set_xlabel('Predicted')
    heatmap_ax.set_ylabel('True')
    heatmap_ax.set_xticks(ticks)
    heatmap_ax.set_yticks(ticks)
    heatmap_ax.set_xticklabels(class_names)
    heatmap_ax.set_yticklabels(class_names)

    # Write every cell value into its cell (x is the column, y is the row).
    for row in range(size):
        for col in range(size):
            heatmap_ax.text(col, row, f'{matrix[row, col]}', ha='center', va='center')

    plt.tight_layout()
    plt.savefig(file_path.replace('.xlsx', '.png'))
    plt.close()

View File

@ -1,96 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
from Qtorch.Models.Qnn import Qnn
class QCNN(Qnn):
    """1D convolutional classifier built on the shared ``Qnn`` training harness.

    The network is a stack of Conv1d -> ReLU -> MaxPool1d stages followed by a
    two-layer fully connected classifier.
    """

    def __init__(
        self,
        data,
        labels=None,
        conv_channels=(16, 32),
        kernel_size=3,
        hidden_size=128,
        dropout_rate=0.3,
        test_size=0.2,
        random_state=None,
        batch_size=64,
        learning_rate=0.00001,
        weight_decay=1e-5,
        lr_scheduler_patience=10,
        early_stop_patience=100,
        early_stop_threshold=0.99,
    ):
        """Prepare the data through ``Qnn`` and build the network.

        Args:
            conv_channels: Output channel count of each convolution stage.
            kernel_size: Kernel size shared by all convolutions.
            hidden_size: Width of the hidden fully connected layer.
            dropout_rate: Dropout probability inside the classifier.

        All remaining arguments are passed to ``Qnn.__init__`` unchanged.
        """
        super().__init__(
            data=data,
            labels=labels,
            test_size=test_size,
            random_state=random_state,
            batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            lr_scheduler_patience=lr_scheduler_patience,
            early_stop_patience=early_stop_patience,
            early_stop_threshold=early_stop_threshold,
        )

        self.conv_channels = tuple(conv_channels)
        self.kernel_size = kernel_size
        self.hidden_size = hidden_size
        self.dropout_rate = dropout_rate

        # Placeholders; build_model replaces both.
        self.feature_extractor = nn.Sequential()
        self.classifier = nn.Sequential()

        # Build the 1D CNN for the feature shape of the training split.
        self.build_model(input_shape=self.X_train.shape[1:], num_classes=self.num_classes)
        self._model_built = True

    def _transform_features(self, features):
        """Convert features to a float tensor shaped [batch, channel=1, length]."""
        as_tensor = torch.tensor(features, dtype=torch.float32)
        return as_tensor.unsqueeze(1)

    def build_model(self, input_shape, num_classes):
        """Create the convolution stack and the classifier head.

        Raises:
            ValueError: If ``conv_channels`` is empty.
        """
        if not self.conv_channels:
            raise ValueError("'conv_channels' must contain at least one channel size.")

        input_length = int(np.prod(input_shape))

        # Stage i maps channel_sizes[i] channels to channel_sizes[i + 1];
        # the input has a single channel.
        channel_sizes = (1, *self.conv_channels)
        stages = []
        for n_in, n_out in zip(channel_sizes, channel_sizes[1:]):
            stages += [
                nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=self.kernel_size),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2),
            ]
        self.feature_extractor = nn.Sequential(*stages)

        # The flattened size is measured by pushing a probe through the stack.
        flat_size = self._get_conv_output_size(input_length)
        self.classifier = nn.Sequential(
            nn.Linear(flat_size, self.hidden_size),
            nn.ReLU(),
            nn.Dropout(self.dropout_rate),
            nn.Linear(self.hidden_size, num_classes),
        )

        self.__init_weights()

    def _get_conv_output_size(self, input_length):
        """Return the number of values the convolution stack emits for one sample."""
        probe = torch.randn(1, 1, input_length)
        return int(self.feature_extractor(probe).numel())

    def forward(self, x):
        """Run the convolution stack, flatten per sample, then classify."""
        features = self.feature_extractor(x)
        flat = features.view(features.size(0), -1)
        return self.classifier(flat)

    def __init_weights(self):
        """Xavier-initialise conv/linear weights and zero their biases."""
        for module in self.modules():
            if not isinstance(module, (nn.Conv1d, nn.Linear)):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

View File

@ -1,76 +0,0 @@
import numpy as np
import torch.nn as nn
from Qtorch.Models.Qnn import Qnn
class Qmlp(Qnn):
    """Multi-layer perceptron classifier built on the shared ``Qnn`` training harness.

    Every hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout; a final
    Linear layer produces the class scores.
    """

    def __init__(
        self,
        data,
        hidden_layers,
        labels=None,
        dropout_rate=0.3,
        test_size=0.2,
        random_state=None,
        batch_size=64,
        learning_rate=0.00001,
        weight_decay=1e-5,
        lr_scheduler_patience=10,
        early_stop_patience=100,
        early_stop_threshold=0.99,
    ):
        """Prepare the data through ``Qnn`` and build the network.

        Args:
            hidden_layers: Width of each hidden layer, in order.
            dropout_rate: Dropout probability after every hidden layer.

        All remaining arguments are passed to ``Qnn.__init__`` unchanged.
        """
        super().__init__(
            data=data,
            labels=labels,
            test_size=test_size,
            random_state=random_state,
            batch_size=batch_size,
            learning_rate=learning_rate,
            weight_decay=weight_decay,
            lr_scheduler_patience=lr_scheduler_patience,
            early_stop_patience=early_stop_patience,
            early_stop_threshold=early_stop_threshold,
        )

        self.hidden_layers = hidden_layers
        self.dropout_rate = dropout_rate
        # Placeholder; build_model replaces it.
        self.layers = nn.ModuleList()

        # Build the MLP for the feature shape of the training split.
        self.build_model(input_shape=self.X_train.shape[1:], num_classes=self.num_classes)
        self._model_built = True

    def build_model(self, input_shape, num_classes):
        """Create the layer list.

        Raises:
            ValueError: If ``hidden_layers`` is empty.
        """
        if not self.hidden_layers:
            raise ValueError("'hidden_layers' must contain at least one layer size.")

        input_size = int(np.prod(input_shape))
        self.layers = nn.ModuleList()

        # Consecutive pairs of `widths` give (in, out) of every hidden block,
        # starting with the input layer.
        widths = [input_size, *self.hidden_layers]
        for n_in, n_out in zip(widths, widths[1:]):
            self.layers.append(nn.Linear(n_in, n_out))
            self.layers.append(nn.BatchNorm1d(n_out))
            self.layers.append(nn.ReLU())
            self.layers.append(nn.Dropout(self.dropout_rate))

        # Output layer.
        self.layers.append(nn.Linear(widths[-1], num_classes))

        self.__init_weights()

    def forward(self, x):
        """Flatten each sample and pass it through every layer in order."""
        out = x.view(x.size(0), -1)
        for layer in self.layers:
            out = layer(out)
        return out

    def __init_weights(self):
        """Xavier-initialise linear weights and zero their biases."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

View File

@ -1,233 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
from torch.utils.data import DataLoader, TensorDataset
from Qfunctions.divSet import divSet as DS
class Qnn(nn.Module):
    """Base class that owns data preparation, the training loop and result export.

    Subclasses supply the network by implementing ``build_model`` and
    ``forward``; they may override ``_transform_features`` to reshape the
    input tensor.
    """

    def __init__(
        self,
        data,
        labels,
        test_size=0.2,
        random_state=None,
        batch_size=64,
        learning_rate=0.00001,
        weight_decay=1e-5,
        lr_scheduler_patience=10,
        early_stop_patience=100,
        early_stop_threshold=0.99,
    ):
        """Split/scale/encode ``data`` and store the training configuration.

        Args:
            data: Table passed to the dataset splitter (``DS``).
            labels: Class names, or ``None`` to derive classes from the data.
            test_size: Passed to the dataset splitter.
            random_state: Passed to the dataset splitter.
            batch_size: Batch size of both data loaders.
            learning_rate: Adam learning rate.
            weight_decay: Adam weight decay.
            lr_scheduler_patience: Patience of ``ReduceLROnPlateau``.
            early_stop_patience: Epochs without test-accuracy improvement
                before early stopping may trigger.
            early_stop_threshold: Best test accuracy required before early
                stopping may trigger.
        """
        super().__init__()

        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Training configuration shared by all subclasses.
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.lr_scheduler_patience = lr_scheduler_patience
        self.early_stop_patience = early_stop_patience
        self.early_stop_threshold = early_stop_threshold

        # Split into training and test sets.
        self.X_train, self.X_test, self.y_train, self.y_test, self.LABEL_ENCODER = DS(
            data=data, labels=labels, test_size=test_size, random_state=random_state
        )
        self.labels = labels
        # Count classes from the fitted encoder. Deriving the count from
        # max(y_train) undercounts when the highest class id only occurs in
        # the test split.
        self.num_classes = len(self.LABEL_ENCODER.classes_)

        # Set to True once the network layers exist.
        self._model_built = False

        # Per-epoch metrics of the current training run.
        self.epoch_data = self._new_epoch_data()

        # Result storage.
        self.pca_2d, self.pca_3d = None, None
        self.cm, self.cmn = None, None

    def _new_epoch_data(self):
        """Return an empty per-epoch metrics record."""
        return {
            'epoch': [],
            'train_loss': [],
            'train_accuracy': [],
            'test_accuracy': [],
            'precision': [],
            'recall': [],
            'f1_score': []
        }

    def build_model(self, input_shape, num_classes):
        """Create the network layers; must be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement build_model(input_shape, num_classes)")

    def _transform_features(self, features):
        """Convert features to a float tensor shaped [batch, feature_dim]."""
        return torch.tensor(features, dtype=torch.float32)

    def _prepare_data(self):
        """Wrap the train/test arrays in data loaders.

        Features go through ``_transform_features`` so subclasses can reshape
        them. Only the training loader shuffles.
        """
        X_train_tensor = self._transform_features(self.X_train)
        y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
        X_test_tensor = self._transform_features(self.X_test)
        y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)

        train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
        test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
        train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        return train_loader, test_loader

    def _has_label_names(self):
        """Tell whether the caller supplied a non-empty list of class names."""
        return self.labels is not None and len(self.labels) > 0

    def _cm_label_ids(self):
        """Return the encoded class ids in the order the confusion matrix displays them.

        With class names the ids follow the order of ``self.labels``; they are
        looked up through the encoder because its codes follow sorted class
        values, not the order of ``self.labels``. Without names the ids are
        ``0..num_classes-1``.
        """
        if self._has_label_names():
            return np.asarray(self.LABEL_ENCODER.transform(list(self.labels)))
        return np.arange(self.num_classes)

    def _cm_label_names(self):
        """Return the row/column names that match ``_cm_label_ids``."""
        if self._has_label_names():
            return list(self.labels)
        return list(range(self.num_classes))

    def _train_model(self, train_loader, test_loader, epochs_times=100):
        """Train for up to ``epochs_times`` epochs and record metrics.

        After every epoch the model is evaluated on the test loader and the
        metrics are appended to ``self.epoch_data``. The confusion matrices of
        the last evaluated epoch are stored in ``self.cm`` / ``self.cmn``.

        Raises:
            ValueError: If ``epochs_times`` is smaller than 1.
        """
        if epochs_times < 1:
            # Without a single epoch there are no predictions to build the
            # confusion matrix from.
            raise ValueError("'epochs_times' must be at least 1.")

        model = self.to(self.DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.1,
            patience=self.lr_scheduler_patience,
        )

        best_test_accuracy = 0
        counter = 0  # epochs since the test accuracy last improved

        for epoch in range(epochs_times):
            # ---- training pass ----
            model.train()
            running_loss = 0.0
            correct_train = 0
            total_train = 0
            for inputs, labels in train_loader:
                inputs, labels = inputs.to(self.DEVICE), labels.to(self.DEVICE)

                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()

            train_accuracy = correct_train / total_train
            train_loss = running_loss / len(train_loader)

            # ---- evaluation pass ----
            model.eval()
            correct_test = 0
            total_test = 0
            all_labels = []
            all_predicted = []
            with torch.no_grad():
                for inputs, labels in test_loader:
                    inputs, labels = inputs.to(self.DEVICE), labels.to(self.DEVICE)
                    outputs = model(inputs)
                    _, predicted = torch.max(outputs.data, 1)
                    total_test += labels.size(0)
                    correct_test += (predicted == labels).sum().item()

                    all_labels.extend(labels.cpu().numpy())
                    all_predicted.extend(predicted.cpu().numpy())

            test_accuracy = correct_test / total_test
            f1 = f1_score(all_labels, all_predicted, average='macro', zero_division=0)
            precision = precision_score(all_labels, all_predicted, average='macro', zero_division=0)
            recall = recall_score(all_labels, all_predicted, average='macro', zero_division=0)

            # Progress report every tenth epoch.
            if (epoch + 1) % 10 == 0:
                print('===============================================')
                print(f'Epoch [{epoch + 1} / {epochs_times}]:')
                print(f'Train Accuracy: {train_accuracy * 100:.2f}%, Test Accuracy: {test_accuracy*100:.2f}%, Loss: {train_loss:.4f}')
                print(f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score:{f1:.4f}, ')
                print('===============================================')

            self.epoch_data['epoch'].append(epoch+1)
            self.epoch_data['train_loss'].append(train_loss)
            self.epoch_data['train_accuracy'].append(train_accuracy)
            self.epoch_data['test_accuracy'].append(test_accuracy)
            self.epoch_data['precision'].append(precision)
            self.epoch_data['recall'].append(recall)
            self.epoch_data['f1_score'].append(f1)

            scheduler.step(train_loss)

            if test_accuracy > best_test_accuracy:
                best_test_accuracy = test_accuracy
                counter = 0
            else:
                counter += 1

            # Stop only when the accuracy has stalled AND is already good enough.
            if counter >= self.early_stop_patience and best_test_accuracy >= self.early_stop_threshold:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Always pass explicit labels so the matrix keeps its shape even when
        # some classes do not appear in this split, and so rows/columns line
        # up with the names used by get_cm / get_cmn. cmn is the
        # row-normalised matrix.
        cm_labels = self._cm_label_ids()
        self.cm = confusion_matrix(all_labels, all_predicted, labels=cm_labels)
        self.cmn = confusion_matrix(all_labels, all_predicted, labels=cm_labels, normalize='true')
        print(self.cm)

        return

    def fit(self, epoch_times=100):
        """Train the model for up to ``epoch_times`` epochs.

        Builds the network first if the subclass has not done so, and resets
        the per-epoch metrics so repeated calls do not accumulate.
        """
        if not self._model_built:
            self.build_model(input_shape=self.X_train.shape[1:], num_classes=self.num_classes)
            self._model_built = True

        self.epoch_data = self._new_epoch_data()

        train_loader, test_loader = self._prepare_data()
        self._train_model(train_loader, test_loader, epochs_times=epoch_times)
        return

    def get_PCA(self):
        """Return 2D and 3D PCA projections of the training features.

        Returns:
            ``(df_pca2d, df_pca3d)``: DataFrames with the principal-component
            columns (``PC1``..) plus a ``labels`` column holding the encoded
            training labels.
        """
        # 2D projection: keep two principal components.
        pca_2d = PCA(n_components=2)
        principalComponents = pca_2d.fit_transform(self.X_train)
        df_pca2d = pd.DataFrame(data=principalComponents, columns=['PC1', 'PC2'])
        df_pca2d['labels'] = self.y_train

        # 3D projection: keep three principal components.
        pca_3d = PCA(n_components=3)
        principalComponents = pca_3d.fit_transform(self.X_train)
        df_pca3d = pd.DataFrame(data=principalComponents, columns=['PC1', 'PC2', 'PC3'])
        df_pca3d['labels'] = self.y_train

        return df_pca2d, df_pca3d

    def get_cm(self):
        """Return the confusion matrix as a DataFrame labelled with class names."""
        label_names = self._cm_label_names()
        return pd.DataFrame(self.cm, columns=label_names, index=label_names)

    def get_cmn(self):
        """Return the row-normalised confusion matrix as a labelled DataFrame."""
        label_names = self._cm_label_names()
        return pd.DataFrame(self.cmn, columns=label_names, index=label_names)

    def get_epoch_data(self):
        """Return the per-epoch metrics of the last training run as a DataFrame."""
        return pd.DataFrame(self.epoch_data)

View File

@ -1,2 +0,0 @@
# Qtorch/__init__.py
from .Models import Qnn, Qmlp, Qcnn

318
README.md
View File

@ -1,318 +0,0 @@
# Deeplearning 使用说明
## 1. Conda 环境迁移
环境文件在 `conda_env/`
- `conda_env/environment.portable.yml`:通用迁移(推荐)
- `conda_env/environment.lock.txt`:精确锁定(同系统/同架构优先)
- `conda_env/env.yml`:历史文件
### 1.1 创建环境
```bash
# 方式1(推荐):通用创建
conda env create -f conda_env/environment.portable.yml
conda activate Deeplearning
# 方式2精确复现
conda create -n Deeplearning --file conda_env/environment.lock.txt
conda activate Deeplearning
# 验证
python -V
python -c "import torch; print(torch.__version__)"
```
### 1.2 同名环境已存在时
```bash
# 方式A保留旧环境改名创建
conda env create -f conda_env/environment.portable.yml -n Deeplearning_v2
conda activate Deeplearning_v2
# 或者lock 方式)
conda create -n Deeplearning_v2 --file conda_env/environment.lock.txt
conda activate Deeplearning_v2
```
```bash
# 方式B删除旧环境后重建谨慎
conda env remove -n Deeplearning
conda env create -f conda_env/environment.portable.yml
conda activate Deeplearning
```
### 1.3 重新导出环境
```bash
conda env export -n Deeplearning --no-builds > conda_env/environment.portable.yml
conda list -n Deeplearning --explicit > conda_env/environment.lock.txt
```
### 1.4 主要依赖包
训练与数据处理核心依赖:
- Python 3.12
- pytorch / torchvision / torchaudio
- pandas / numpy / scipy
- scikit-learn
- matplotlib / seaborn
- openpyxl / xlrd用于 xls/xlsx 读写)
说明:项目仓库已提供完整环境文件,优先用 `conda_env/environment.portable.yml` 或 `conda_env/environment.lock.txt` 创建环境。
### 1.5 依赖安装方式Conda / pip
方式 A(推荐,Conda 一步到位):
```bash
conda env create -f conda_env/environment.portable.yml
conda activate Deeplearning
```
方式 B(Conda 最小安装,适合自定义环境):
```bash
conda create -n Deeplearning python=3.12 -y
conda activate Deeplearning
# GPU 机器CUDA 12.4
conda install -y pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
# CPU 机器(无 CUDA
# conda install -y pytorch torchvision torchaudio cpuonly -c pytorch
conda install -y pandas numpy scipy scikit-learn matplotlib seaborn openpyxl xlrd
```
方式 C(pip 安装,适合已有虚拟环境):
```bash
pip install torch torchvision torchaudio
pip install pandas numpy scipy scikit-learn matplotlib seaborn openpyxl xlrd
```
可选开发工具:
```bash
pip install black autopep8 basedpyright
```
## 2. 项目约定
### 2.1 输入数据格式
每一类数据支持 `xls/xlsx/csv`。读取时默认取偶数列(索引 1,3,5...)作为特征,奇数列内容可忽略。
示意:
| 任意值 | 特征值 | 任意值 | 特征值 |
|---|---|---|---|
| arbitrary value | value | arbitrary value | value |
### 2.2 目录约定
训练数据放在 `Static/`,输出结果放在 `Result/`
推荐目录:
```text
.
├─ Static/
│ └─ 20241009MaterialDiv/
└─ Result/
```
## 3. 快速开始
### 3.1 准备数据
1. 将数据目录命名为 `日期+项目名`,例如 `20241009MaterialDiv`
2. 准备 `label_names`(建议英文或数字)。
3. 将数据目录放入 `Static/`
### 3.2 数据目录模板
单文件模式(每个标签一个文件):
```text
Static/
20241009MaterialDiv/
Acrlic.xlsx
Ecoflex.xlsx
PDMS.xlsx
PLA.xlsx
Wood.xlsx
```
多子特征模式(每个标签一个子目录,目录下可有多个文件):
```text
Static/
20241009MaterialDiv/
Acrlic/
sample_01.xlsx
sample_02.xlsx
Ecoflex/
sample_01.xlsx
sample_02.xlsx
PDMS/
sample_01.xlsx
sample_02.xlsx
PLA/
sample_01.xlsx
sample_02.xlsx
Wood/
sample_01.xlsx
sample_02.xlsx
```
命名规则(重要):
- `label_names` 中每一项必须与文件名(单文件模式)或子文件夹名(多子特征模式)一致。
- 标签编码由 `LabelEncoder` 按类别值排序后分配(而不是按 `label_names` 的书写顺序);请按排序后的顺序书写 `label_names`,这样训练结果和混淆矩阵的展示顺序才与编码一致。
示例:
```python
label_names = ['Acrlic', 'Ecoflex', 'PDMS', 'PLA', 'Wood']
```
对应关系:
```text
Acrlic <-> Acrlic.xlsx 或 Acrlic/
Ecoflex <-> Ecoflex.xlsx 或 Ecoflex/
PDMS <-> PDMS.xlsx 或 PDMS/
PLA <-> PLA.xlsx 或 PLA/
Wood <-> Wood.xlsx 或 Wood/
```
### 3.3 通用:数据导入
```python
from Qfunctions.loadData import load_data
projet_name = '20241009MaterialDiv'
label_names = ['Acrlic', 'Ecoflex', 'PDMS', 'PLA', 'Wood']
# 自动识别数据模式(支持 xls/xlsx/csv
data = load_data(projet_name, label_names)
```
### 3.4 模型调用
#### 3.4.1 MLP
```python
from Qtorch.Models.Qmlp import Qmlp
model = Qmlp(
data=data,
labels=label_names,
hidden_layers=[128, 256, 128],
test_size=0.3,
dropout_rate=0,
)
model.fit(300)
```
#### 3.4.2 1D CNN
```python
from Qtorch.Models.Qcnn import QCNN
model = QCNN(
data=data,
labels=label_names,
conv_channels=(16, 32),
kernel_size=3,
hidden_size=128,
test_size=0.3,
dropout_rate=0,
)
model.fit(300)
```
### 3.5 通用:结果获取与图表导出
```python
from Qfunctions.saveToXlsx import save_to_xlsx
pca_2d, pca_3d = model.get_PCA()
cm = model.get_cm()
cmn = model.get_cmn()
epoch_data = model.get_epoch_data()
save_to_xlsx(project_name=projet_name, file_name='pca_2d', data=pca_2d)
save_to_xlsx(project_name=projet_name, file_name='pca_3d', data=pca_3d)
save_to_xlsx(project_name=projet_name, file_name='cm', data=cm)
save_to_xlsx(project_name=projet_name, file_name='cmn', data=cmn)
save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data)
```
### 3.6 每次运行参数与结果 YAML 存档
当前 `main.py` 会为每次运行创建时间戳目录:
- `Result/<project_name>/<YYYYMMDD_HHMMSS>/`
并在目录下自动保存:
- `run_params.yaml`:本次运行参数快照。
- `run_result.yaml`:本次运行结果摘要(最佳轮次与最后轮次指标)。
这样可以直接对比不同运行的参数与结果变化。
## 4. load_data 参数说明
| 参数 | 类型 | 默认值 | 说明 |
|---|---|---|---|
| folder | str | 必填 | `Static/` 下的数据目录名 |
| labelNames | list | 必填 | 类别名称列表,用于读取和排序标签 |
自动识别规则:
- 若每个 `label` 都对应 `folder/label/*.(xlsx|xls|csv)`,识别为多子特征模式。
- 若每个 `label` 都对应 `folder/label.(xlsx|xls|csv)`,识别为单文件模式。
- 超出支持范围的文件格式(仅支持 xls/xlsx/csv会报错。
- 若两种都成立(同名文件和同名子目录同时存在),会报错并提示只保留一种目录结构。
- 若两种都不成立,会报错并提示检查目录结构或 `label_names`
读取路径规则:
- 单文件模式:`./Static/folder/labelNames[i].(xlsx|xls|csv)`
- 多子特征模式:`./Static/folder/labelNames[i]/*.(xlsx|xls|csv)`
## 5. 常见问题
### 5.1 找不到文件
优先检查:
- `label_names` 与文件/文件夹是否同名
- 文件后缀是否为 `.xls`、`.xlsx` 或 `.csv`(其他格式将报错)
## 6. TODO(后续计划)
### 阶段一:基础稳定与可维护性
- [ ] 固化模型基类契约:`build_model`、输入变换钩子、统一训练配置。
- [ ] 封装 `QDL` 包结构(`qdl.data / qdl.models / qdl.export / qdl.api`)。
- [ ] 增加兼容层:保留旧导入路径,逐步迁移到新包路径。
- [ ] 增加最小测试集:`load_data`、`Qmlp.fit(1)`、`QCNN.fit(1)`、导出函数。
### 阶段二:高维与复合模型能力
- [ ] 将 batch 结构升级为字典(`x/y/lengths/mask/meta`),支持复合输入。
- [ ] 在训练框架中加入钩子:`prepare_batch`、`forward_step`、`compute_loss`。
- [ ] 引入编码器层Encoder抽象按策略扩展MLP/CNN/LSTM/GNN/Transformer不按维度硬扩类。
- [ ] 支持可变长时间序列(`collate_fn + lengths + mask`),作为后续复合模型基础。
- [ ] 增加多分支融合模板(时序分支 + 静态分支),预留多任务损失组合。
### 阶段三:对外能力与发布
- [ ] 统一对外入口:提供高层 API例如 `train_mlp`、`train_cnn1d`、`train_hybrid`)。
- [ ] 在模型文档中预留扩展位:`3.4.3 LSTM`、`3.4.4 GNN`、`3.4.5 Hybrid`。
- [ ] 完成打包配置(`pyproject.toml`)与本地可编辑安装说明。
- [ ] 发布前回归:在 Conda 与 pip 环境分别跑通最小端到端流程。

View File

@ -1,49 +0,0 @@
#!/bin/bash
# Collect every Python file below the current directory into one Markdown
# document: one heading per directory, one sub-heading plus a fenced code
# block per .py file.

# Name of the generated Markdown file.
output_file="python_files_output.md"

# Truncate the output file, creating it if necessary.
: > "$output_file"

# process_directory DIR DEPTH
#   Append DIR's heading and its .py files to the output file, then recurse
#   into its sub-directories. DEPTH is the Markdown heading level of DIR.
process_directory() {
    # Sub-directory paths come from the "*/" glob and end in "/". Strip that
    # slash, otherwise "${dir##*/}" expands to an empty heading.
    local dir="${1%/}"
    local depth="$2"
    # Keep the loop variables local so recursive calls cannot clobber them.
    local file subdir

    # Directory name as a heading with DEPTH '#' characters.
    echo "$(printf '%0.s#' $(seq 1 $depth)) ${dir##*/}" >> "$output_file"
    echo "" >> "$output_file"

    # Every .py file directly inside this directory.
    for file in "$dir"/*.py; do
        # When nothing matches, the glob stays literal; skip that case.
        if [ -f "$file" ]; then
            # File name as a sub-heading, one level below the directory.
            echo "$(printf '%0.s#' $(seq 1 $((depth + 1)))) ${file##*/}" >> "$output_file"
            echo "" >> "$output_file"

            # File content inside a fenced python block.
            echo '```python' >> "$output_file"
            cat "$file" >> "$output_file"
            echo '```' >> "$output_file"
            echo "" >> "$output_file"
        fi
    done

    # Recurse into sub-directories.
    for subdir in "$dir"/*/; do
        if [ -d "$subdir" ]; then
            process_directory "$subdir" $((depth + 1))
        fi
    done
}

# Start at the current directory with a level-1 heading.
process_directory "." 1

echo "Markdown文件已生成: $output_file"

View File

@ -1,631 +0,0 @@
#!/usr/bin/env python3
"""
Data quality check script
=========================
Runs completeness, statistics, balance and outlier checks on a data directory under Static/.
Produces a detailed report that is printed to the terminal and can optionally be saved as a text file.
Usage:
python Scripts/check_data.py --folder 20260319Numbers --labels 0 1 2 3 4 5 6 7 8 9
python Scripts/check_data.py --folder "20260408 grap" --labels 1 2 3 4 5 6 7 8 9 --output report.txt
python Scripts/check_data.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
Requirements:
Run from the project root (Deeplearning/), or pass --root to point at the project root.
"""
import os
import sys
import argparse
import unicodedata
from pathlib import Path
import numpy as np
import pandas as pd
# ============================================================
# 工具函数(与 loadData.py 中逻辑保持一致)
# ============================================================
DEFAULT_FILE_CLASSES = ("xlsx", "xls", "csv")
def _has_supported_extension(filename: str, file_classes=DEFAULT_FILE_CLASSES) -> bool:
    """Return True when *filename* ends in one of the extensions in *file_classes*.

    The extension is lower-cased and stripped of its leading dot before the
    membership test, so ``"A.XLSX"`` matches the entry ``"xlsx"``.
    """
    _, suffix = os.path.splitext(filename)
    return suffix.lower().lstrip(".") in file_classes
def _read_data_file(file_path: str) -> pd.DataFrame:
    """Read a ``.csv``, ``.xls`` or ``.xlsx`` file into a DataFrame.

    The reader is chosen from the lower-cased file extension.

    Raises:
        ValueError: If the extension is none of the three supported ones.
    """
    suffix = os.path.splitext(file_path)[1].lower()
    readers = {
        ".csv": pd.read_csv,
        ".xls": pd.read_excel,
        ".xlsx": pd.read_excel,
    }
    reader = readers.get(suffix)
    if reader is None:
        raise ValueError(
            f"Unsupported file format: {suffix}. Only .xls, .xlsx, and .csv are supported. "
            f"File: {file_path}"
        )
    return reader(file_path)
def _strip_zero_width(s: str) -> str:
    """Delete the code points U+200B, U+200C, U+200D and U+FEFF from *s*.

    Values that are not ``str`` are returned unchanged.
    """
    if isinstance(s, str):
        deletions = dict.fromkeys((0x200B, 0x200C, 0x200D, 0xFEFF))
        return s.translate(deletions)
    return s
def _canonicalize_name(name: str) -> str:
    """Return *name* NFKC-normalized and with zero-width characters removed."""
    return _strip_zero_width(unicodedata.normalize("NFKC", name))
def _normalize_for_compare(name: str) -> str:
    """Build a relaxed comparison key for a file name.

    The canonical form of *name* has underscores turned into spaces, runs of
    whitespace collapsed to a single space, and is lower-cased.
    """
    canonical = _canonicalize_name(name)
    words = canonical.replace("_", " ").split()
    return " ".join(words).lower()
def _find_matching_file(folder: str, expected_name: str):
    """Find the directory entry of *folder* that corresponds to *expected_name*.

    Three comparisons are tried in order, and the first entry matching the
    earliest successful comparison is returned:

    1. canonical form,
    2. canonical form, case-insensitive,
    3. relaxed form (underscores as spaces, whitespace collapsed, lower-case).

    Returns:
        The entry name as listed by ``os.listdir``, or ``None`` when nothing
        matches or *folder* does not exist.
    """
    expected = _canonicalize_name(expected_name)
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        return None

    # (key function, target key) pairs, from strictest to most relaxed.
    passes = (
        (_canonicalize_name, expected),
        (lambda n: _canonicalize_name(n).lower(), expected.lower()),
        (_normalize_for_compare, _normalize_for_compare(expected_name)),
    )
    for key, target in passes:
        for entry in entries:
            if key(entry) == target:
                return entry
    return None
def _find_matching_file_by_label(folder: str, label_name, file_classes):
    """Return the first file in *folder* named ``<label_name>.<ext>``.

    Extensions are tried in the order given by *file_classes*; ``None`` is
    returned when no extension produces a match.
    """
    candidates = (
        _find_matching_file(folder, f"{label_name}.{ext}") for ext in file_classes
    )
    return next((found for found in candidates if found is not None), None)
# ============================================================
# 报告生成工具
# ============================================================
class ReportBuffer:
    """Collect report lines, echoing each one to stdout as it is added.

    ``save`` writes the collected lines to ``output_path`` when one was given.
    """

    def __init__(self, output_path=None):
        self.output_path = output_path
        self.lines: list[str] = []

    def add(self, text: str = ""):
        """Print *text* and remember it for ``save``."""
        print(text)
        self.lines.append(text)

    def save(self):
        """Write all collected lines to ``output_path``; no-op when it is unset."""
        if not self.output_path:
            return
        content = "\n".join(self.lines) + "\n"
        with open(self.output_path, "w", encoding="utf-8") as handle:
            handle.write(content)
        print(f"\n报告已保存到: {self.output_path}")
# ============================================================
# 核心检查逻辑
# ============================================================
def _extract_features(df: pd.DataFrame, source: str) -> pd.DataFrame:
    """Return every second column of *df* (0-based positions 1, 3, 5, ...) as numeric features.

    The selected columns keep their original labels.  Each one is converted
    with ``pd.to_numeric(errors="coerce")``, so cells that cannot be parsed
    become NaN.

    Args:
        df: Raw table as read from a data file.
        source: File name, used only in the error message.

    Raises:
        ValueError: If *df* has fewer than two columns.
    """
    feature_cols = list(df.columns[1::2])
    if not feature_cols:
        raise ValueError(f"没有找到偶数列(特征列)。请检查文件: {source}")
    features = df[feature_cols].copy()
    for name in features.columns:
        features[name] = pd.to_numeric(features[name], errors="coerce")
    return features
def check_tabular_project(root: str, folder: str, labels: list[str], rp: ReportBuffer):
    """Detect the data layout of ``<root>/Static/<folder>`` and run the matching check.

    Two layouts are recognised: one data file per label directly in the
    directory, or one sub-directory per label containing data files.  Exactly
    one of them must hold for all labels; otherwise an error is reported and
    nothing is checked.

    Args:
        root: Project root directory.
        folder: Name of the data directory below ``Static``.
        labels: Class labels to look for.
        rp: Report sink; every finding is written through ``rp.add``.
    """
    data_dir = os.path.join(root, "Static", folder)
    if not os.path.isdir(data_dir):
        rp.add(f"[ERROR] 目录不存在: {data_dir}")
        rp.add("请确认 --folder 参数正确。")
        return

    banner = "=" * 64
    rp.add(banner)
    rp.add(" Deeplearning 数据质量检查报告")
    rp.add(banner)
    rp.add(f" 数据目录 : {data_dir}")
    rp.add(f" 标签数量 : {len(labels)}")
    rp.add(f" 标签列表 : {labels}")
    rp.add()

    # ---- Step 1: detect the data layout ----
    def _is_data_subfolder(label) -> bool:
        """True when the label has a sub-directory with at least one data file."""
        sub = os.path.join(data_dir, str(label))
        return os.path.isdir(sub) and any(
            _has_supported_extension(f) for f in os.listdir(sub)
        )

    has_all_subfolders = all(_is_data_subfolder(lbl) for lbl in labels)
    has_all_files = all(
        _find_matching_file_by_label(data_dir, lbl, DEFAULT_FILE_CLASSES) is not None
        for lbl in labels
    )

    # Both layouts present, or neither: the mode cannot be decided.
    if has_all_files == has_all_subfolders:
        rp.add("[ERROR] 无法自动检测数据模式,或两种模式同时存在。")
        rp.add(f" has_all_files : {has_all_files}")
        rp.add(f" has_all_subfolders: {has_all_subfolders}")
        rp.add("请确保每个 label 对应唯一的文件或唯一的子目录。")
        return

    if has_all_files:
        _check_single_file_mode(data_dir, labels, rp)
    else:
        _check_multi_folder_mode(data_dir, labels, rp)

    rp.add()
    rp.add(banner)
    rp.add(" 检查完成。")
    rp.add(banner)
def _check_single_file_mode(data_dir: str, labels: list[str], rp: ReportBuffer):
    """Check a data directory laid out as one data file per label.

    For every label the matching file is located, read, reduced to its
    feature columns and stripped of rows containing NaN.  The report then
    covers column-count consistency and per-class sample counts, followed by
    the balance, statistics and outlier analyses.

    Args:
        data_dir: Directory holding one ``<label>.<ext>`` file per label.
        labels: Class labels to look for.
        rp: Report sink; every finding is written through ``rp.add``.
    """
    rp.add()
    rp.add("── 数据模式: 单文件模式 ──")
    rp.add()

    # 1. Resolve the actual file name for each label.
    file_map: dict[str, str] = {}
    missing: list[str] = []
    for lbl in labels:
        match = _find_matching_file_by_label(data_dir, lbl, DEFAULT_FILE_CLASSES)
        if match:
            file_map[lbl] = match
        else:
            missing.append(lbl)
    if missing:
        rp.add(f"[WARN] 以下标签找不到对应文件: {missing}")
        rp.add(f"当前目录内容: {sorted(os.listdir(data_dir))}")
    if not file_map:
        return
    labels = [lbl for lbl in labels if lbl in file_map]

    # 2. Load each class.
    all_features = []  # list of (label, cleaned feature DataFrame)
    per_class_info: dict[str, dict] = {}
    col_counts: dict[str, int] = {}
    for lbl in labels:
        fname = file_map[lbl]
        file_path = os.path.join(data_dir, fname)
        notes: list[str] = []
        info: dict[str, object] = {"label": lbl, "file": fname, "warnings": notes}
        per_class_info[lbl] = info

        try:
            raw = _read_data_file(file_path)
        except Exception as e:
            info["error"] = str(e)
            # Separator added: the path used to run straight into the error text.
            rp.add(f"[ERROR] 读取文件失败: {file_path}: {e}")
            continue
        info["raw_rows"] = raw.shape[0]
        info["raw_cols"] = raw.shape[1]

        # NaN cells present in the raw file.
        total_nan = raw.isna().sum().sum()
        if total_nan > 0:
            notes.append(f"原始文件含 {total_nan} 个 NaN 单元格")

        try:
            features = _extract_features(raw, fname)
        except ValueError as e:
            info["error"] = str(e)
            rp.add(f"[ERROR] 特征提取失败: {file_path}: {e}")
            continue

        # Count samples after dropping rows that contain NaN.
        clean = features.dropna()
        info["feature_cols"] = features.shape[1]
        info["samples_after_dropna"] = clean.shape[0]
        info["dropped_nan_rows"] = features.shape[0] - clean.shape[0]
        col_counts[lbl] = features.shape[1]
        if clean.shape[0] == 0:
            notes.append("去除 NaN 后无有效样本")
        else:
            all_features.append((lbl, clean))

    # Column-count consistency.  Nothing is claimed when no file could be
    # loaded (previously this printed "[OK] ...: 0").
    distinct_counts = set(col_counts.values())
    if len(distinct_counts) > 1:
        rp.add()
        rp.add("[WARN] 各标签的特征列数不一致!")
        for lbl, cc in col_counts.items():
            rp.add(f" {lbl}: {cc}")
        rp.add("这会导致 load_data 时补零逻辑产生差异。")
    elif distinct_counts:
        rp.add(f"[OK] 所有标签特征列数一致: {next(iter(distinct_counts))}")

    # Per-class sample counts.
    rp.add()
    rp.add("── 各类别样本数 ──")
    sample_counts: dict[str, int] = {}
    for lbl in labels:
        info = per_class_info.get(lbl, {})
        if "error" in info:
            rp.add(f" [{lbl}] 加载失败: {info['error']}")
            continue
        n = info.get("samples_after_dropna", 0)
        sample_counts[lbl] = n
        class_notes = info.get("warnings", [])
        # Warnings are set off from the count line instead of being glued to it.
        wflag = f" [WARN] {'; '.join(class_notes)}" if class_notes else ""
        rp.add(f" [{lbl}] {n} 行 (文件: {info.get('file','?')}, "
               f"原始 {info.get('raw_rows','?')} 行, "
               f"丢弃 NaN 行 {info.get('dropped_nan_rows',0)}){wflag}")

    # Balance analysis.
    _analyze_balance(sample_counts, rp)
    # Statistics and outliers.
    _analyze_statistics(all_features, rp)
    _analyze_outliers(all_features, rp)
def _check_multi_folder_mode(data_dir: str, labels: list[str], rp: ReportBuffer):
    """Check a data directory laid out as one sub-directory of data files per label.

    All supported files of a label are read, reduced to their feature columns,
    stripped of rows containing NaN and stacked into one table per class.
    Files with fewer feature columns are zero-padded on the right, by position,
    to the widest file of the class.

    Args:
        data_dir: Directory holding one sub-directory per label.
        labels: Class labels to look for.
        rp: Report sink; every finding is written through ``rp.add``.
    """
    rp.add()
    rp.add("── 数据模式: 多子特征模式 ──")
    rp.add()

    all_features = []
    per_class_info: dict[str, dict] = {}
    col_counts: dict[str, int] = {}
    for lbl in labels:
        sub = os.path.join(data_dir, str(lbl))
        if not os.path.isdir(sub):
            per_class_info[lbl] = {"error": f"子目录不存在: {sub}"}
            rp.add(f"[ERROR] {lbl}: 子目录不存在")
            continue
        files = sorted(f for f in os.listdir(sub) if _has_supported_extension(f))
        if not files:
            per_class_info[lbl] = {"error": f"子目录下无支持的文件: {sub}"}
            rp.add(f"[ERROR] {lbl}: 子目录下无 .xlsx/.xls/.csv 文件")
            continue

        class_frame_list = []
        single_file_cols = set()
        total_raw = 0
        total_dropped = 0
        failed_files = []
        for fname in files:
            file_path = os.path.join(sub, fname)
            try:
                raw = _read_data_file(file_path)
            except Exception as e:
                failed_files.append(f" {fname}: {e}")
                continue
            total_raw += raw.shape[0]
            try:
                features = _extract_features(raw, f"{lbl}/{fname}")
            except ValueError as e:
                failed_files.append(f" {fname}: {e}")
                continue
            single_file_cols.add(features.shape[1])
            clean = features.dropna()
            total_dropped += features.shape[0] - clean.shape[0]
            if clean.shape[0] > 0:
                class_frame_list.append(clean)

        notes: list[str] = []
        info: dict[str, object] = {
            "label": lbl,
            "num_files": len(files),
            "raw_rows_total": total_raw,
            "dropped_nan_rows": total_dropped,
            "warnings": notes,
        }
        if failed_files:
            notes.append(f"{len(failed_files)} 个文件加载失败")
            for ff in failed_files:
                rp.add(f" [WARN] {ff}")
        if len(single_file_cols) > 1:
            notes.append(f"子文件间列数不一致: {sorted(single_file_cols)}")
        col_counts[lbl] = max(single_file_cols, default=0)

        if class_frame_list:
            # Stack by position.  pd.concat would align on column labels and
            # fill NaN wherever files differ in column names or count, which
            # breaks the zero-fill promise and turns the statistics into NaN.
            width = max(frame.shape[1] for frame in class_frame_list)
            blocks = []
            for frame in class_frame_list:
                values = frame.to_numpy()
                short_by = width - values.shape[1]
                if short_by:
                    pad = np.zeros((values.shape[0], short_by), dtype=values.dtype)
                    values = np.hstack([values, pad])
                blocks.append(values)
            # Keep the labels of the widest file so equal-layout files yield
            # the same table as before.
            widest = max(class_frame_list, key=lambda frame: frame.shape[1])
            combined = pd.DataFrame(np.vstack(blocks), columns=list(widest.columns))
            info["samples_after_dropna"] = combined.shape[0]
            info["feature_cols"] = combined.shape[1]
            all_features.append((lbl, combined))
        else:
            info["samples_after_dropna"] = 0
            notes.append("无有效样本")
        per_class_info[lbl] = info

    # Column-count consistency across labels (labels without any readable
    # feature file are ignored).
    non_zero = {l: c for l, c in col_counts.items() if c > 0}
    if non_zero and len(set(non_zero.values())) > 1:
        rp.add()
        rp.add("[WARN] 各标签的特征列数不一致(将使用零填充对齐):")
        for lbl, cc in col_counts.items():
            rp.add(f" {lbl}: {cc}")
    elif non_zero:
        rp.add(f"[OK] 所有标签特征列数一致: {next(iter(non_zero.values()))}")

    # Per-class sample counts.
    rp.add()
    rp.add("── 各类别样本数 ──")
    sample_counts: dict[str, int] = {}
    for lbl in labels:
        info = per_class_info.get(lbl, {})
        if "error" in info:
            rp.add(f" [{lbl}] 加载失败: {info['error']}")
            continue
        n = info.get("samples_after_dropna", 0)
        sample_counts[lbl] = n
        class_notes = info.get("warnings")
        # Warnings are set off from the count line instead of being glued to it.
        wflag = f" [WARN] {'; '.join(class_notes)}" if class_notes else ""
        rp.add(f" [{lbl}] {n} 行 "
               f"(来自 {info.get('num_files','?')} 个文件, "
               f"原始 {info.get('raw_rows_total','?')} 行, "
               f"丢弃 NaN 行 {info.get('dropped_nan_rows',0)}){wflag}")

    _analyze_balance(sample_counts, rp)
    _analyze_statistics(all_features, rp)
    _analyze_outliers(all_features, rp)
# ============================================================
# 分析子模块
# ============================================================
def _analyze_balance(counts: dict[str, int], rp: ReportBuffer):
    """Report class balance figures for the per-class sample counts.

    Writes total, mean, smallest and largest class, the max/min ratio and the
    coefficient of variation, a verdict based on the ratio (> 5 and > 3 are
    flagged), and a per-class train/test estimate for a 20 % test share.

    Args:
        counts: Mapping of label to number of usable samples.
        rp: Report sink; every finding is written through ``rp.add``.
    """
    rp.add()
    rp.add("── 类别平衡性分析 ──")
    if not counts:
        rp.add("无有效样本,跳过。")
        return

    sizes = list(counts.values())
    total = sum(sizes)
    n_classes = len(sizes)
    avg = total / n_classes if n_classes else 0
    min_count, max_count = min(sizes), max(sizes)
    rp.add(f" 总样本数 : {total}")
    rp.add(f" 类别数 : {n_classes}")
    rp.add(f" 平均每类 : {avg:.1f}")
    rp.add(f" 最少样本类 : {min(counts, key=counts.get)} ({min_count})")
    rp.add(f" 最多样本类 : {max(counts, key=counts.get)} ({max_count})")
    if min_count == 0:
        rp.add(" [ERROR] 存在样本数为 0 的类别,训练将无法进行!")
        return

    ratio = max_count / min_count if min_count > 0 else float("inf")
    rp.add(f" 不平衡比例 : {ratio:.2f}:1 (max/min)")
    spread = float(np.std(sizes))
    cv = spread / avg if avg > 0 else 0
    rp.add(f" 变异系数(CV): {cv:.4f}")

    if ratio > 5:
        verdict = " [WARN] 类别严重不平衡 (>5:1),建议进行数据增强或使用类别权重。"
    elif ratio > 3:
        verdict = " [WARN] 类别较不平衡 (>3:1),可考虑采样策略。"
    else:
        verdict = " [OK] 类别基本平衡。"
    rp.add(verdict)

    # Estimated train/test split per class.
    rp.add()
    rp.add(" ── 训练/测试划分预估 (test_size=0.2, stratify) ──")
    for lbl, cnt in sorted(counts.items()):
        # At least one sample is assumed to land in the test set.
        test_n = max(1, int(cnt * 0.2))
        rp.add(f" [{lbl}] 训练 {cnt - test_n} / 测试 {test_n} (总计 {cnt})")
def _analyze_statistics(
    all_features: list[tuple[str, pd.DataFrame]],
    rp: ReportBuffer,
):
    """Report global, per-feature and per-class summary statistics.

    For the global figures every class table is zero-padded on the right to
    the widest class and all rows are stacked; per-feature figures are shown
    for at most the first 12 columns of that stacked array.  Per-class figures
    use each class table as is.

    Args:
        all_features: ``(label, feature table)`` pairs.
        rp: Report sink; every finding is written through ``rp.add``.
    """
    rp.add()
    rp.add("── 特征统计信息 ──")
    if not all_features:
        rp.add("无有效数据,跳过。")
        return

    # Global statistics over all classes, zero-padded to a common width.
    max_cols = max(table.shape[1] for _, table in all_features)
    stacked_parts = []
    for _, table in all_features:
        block = table.values
        short_by = max_cols - block.shape[1]
        if short_by > 0:
            filler = np.zeros((block.shape[0], short_by), dtype=block.dtype)
            block = np.hstack([block, filler])
        stacked_parts.append(block)
    all_values = np.vstack(stacked_parts)
    rp.add(f" 全局特征维度: {all_values.shape} (样本数 × 特征数, 零填充对齐到 {max_cols} 列)")
    rp.add(f" 全局均值 : {np.mean(all_values):.6f}")
    rp.add(f" 全局标准差 : {np.std(all_values):.6f}")
    rp.add(f" 全局最小值 : {np.min(all_values):.6f}")
    rp.add(f" 全局最大值 : {np.max(all_values):.6f}")
    rp.add(f" 全局中位数 : {np.median(all_values):.6f}")

    # Per-feature statistics for the leading columns.
    rp.add()
    n_cols = min(all_values.shape[1], 12)
    rp.add(f" ── 前 {n_cols} 个特征维度的分布 (均值 ± 标准差) ──")
    for number in range(1, n_cols + 1):
        col = all_values[:, number - 1]
        rp.add(
            f" feature{number}: "
            f"μ={np.mean(col):.4f} σ={np.std(col):.4f} "
            f"[{np.min(col):.4f}, {np.max(col):.4f}] "
            f"med={np.median(col):.4f}"
        )

    # Short summary per class.
    rp.add()
    rp.add(" ── 各类别特征统计 ──")
    for lbl, table in all_features:
        val = table.values
        rp.add(
            f" [{lbl}] "
            f"μ={np.mean(val):.4f} σ={np.std(val):.4f} "
            f"范围 [{np.min(val):.4f}, {np.max(val):.4f}]"
        )
def _analyze_outliers(
    all_features: list[tuple[str, pd.DataFrame]],
    rp: ReportBuffer,
):
    """Report the share of outlier samples per class and overall, using the IQR rule.

    Within each class a sample counts as an outlier when any of its features
    lies outside ``[Q1 - 1.5 * IQR, Q3 + 1.5 * IQR]`` of that feature's
    column.  Classes are tagged [OK] below 10 %, [WARN] below 25 % and
    [ERROR] otherwise.

    Args:
        all_features: ``(label, feature table)`` pairs.
        rp: Report sink; every finding is written through ``rp.add``.
    """
    rp.add()
    rp.add("── 离群值检测 (基于 IQR) ──")
    if not all_features:
        rp.add("无有效数据,跳过。")
        return

    def _status(pct):
        """Map an outlier percentage to its report tag."""
        if pct < 10:
            return "[OK]"
        if pct < 25:
            return "[WARN]"
        return "[ERROR]"

    total_outlier_samples = 0
    total_samples = 0
    for lbl, table in all_features:
        val = table.values
        n_rows = val.shape[0]
        total_samples += n_rows

        q1 = np.percentile(val, 25, axis=0)
        q3 = np.percentile(val, 75, axis=0)
        iqr = q3 - q1
        lower = q1 - 1.5 * iqr
        upper = q3 + 1.5 * iqr
        outside = (val < lower) | (val > upper)
        n_outliers = int(np.sum(np.any(outside, axis=1)))
        total_outlier_samples += n_outliers

        pct = n_outliers / n_rows * 100 if n_rows > 0 else 0
        rp.add(f" [{lbl}] 离群样本: {n_outliers}/{n_rows} ({pct:.1f}%) {_status(pct)}")

    overall_pct = total_outlier_samples / total_samples * 100 if total_samples > 0 else 0
    rp.add()
    rp.add(f" 整体离群比例: {total_outlier_samples}/{total_samples} ({overall_pct:.1f}%)")
    if overall_pct > 20:
        rp.add(" [WARN] 超过 20% 数据为离群值,请确认数据清洗是否正确。")
    elif overall_pct > 10:
        rp.add(" [INFO] 离群值比例偏高,训练时可能影响收敛。")
    else:
        rp.add(" [OK] 离群值比例正常。")
# ============================================================
# 命令行入口
# ============================================================
def main():
    """Parse the command line, run the data check and save the report if requested.

    The project root is ``--root`` when given, otherwise the parent of the
    directory containing this script.  The process exits with status 1 when
    that root is not a directory.
    """
    parser = argparse.ArgumentParser(
        description="Deeplearning 数据质量检查脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python Scripts/check_data.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
python Scripts/check_data.py -f "20260408 grap" -l 1 2 3 4 5 6 7 8 9 -o report.txt
python Scripts/check_data.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9 -r /path/to/project
""",
    )
    parser.add_argument(
        "-f", "--folder",
        required=True,
        help="Static/ 下的数据目录名(例如 20260319Numbers",
    )
    parser.add_argument(
        "-l", "--labels",
        nargs="+",
        required=True,
        help="类别标签列表,空格分隔(例如 0 1 2 3 4 或 A B C",
    )
    parser.add_argument(
        "-o", "--output",
        default=None,
        help="将报告保存到指定文件路径",
    )
    parser.add_argument(
        "-r", "--root",
        default=None,
        help="项目根目录(默认为脚本的上级目录,即 Deeplearning/",
    )
    args = parser.parse_args()

    # Resolve the project root.
    root = args.root if args.root else str(Path(__file__).resolve().parent.parent)
    root = os.path.abspath(root)
    if not os.path.isdir(root):
        print(f"[ERROR] 项目根目录不存在: {root}")
        sys.exit(1)

    report = ReportBuffer(output_path=args.output)
    check_tabular_project(root=root, folder=args.folder, labels=args.labels, rp=report)
    report.save()


if __name__ == "__main__":
    main()

View File

@ -1,171 +0,0 @@
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def draw_diagram():
    """Draw the acrylate + fluorinated alcohol reaction scheme and save it.

    The figure shows acrylic acid (top left), the fluorinated alcohol
    HO-CH2-CH2-CF2-CF3 (top right), a downward arrow, and the resulting
    fluorinated acrylate ester (bottom).  It is written to
    ``reaction_diagram.png`` in the current directory.
    """
    # Figure and a blank 10 x 8 canvas.
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 8)
    ax.axis('off')

    # Font settings.
    font_formula = {'family': 'sans-serif', 'weight': 'bold', 'size': 14}
    font_label = {'family': 'sans-serif', 'weight': 'bold', 'size': 16}

    def atom(x, y, symbol, **align):
        """Place an element or group symbol in the formula font."""
        ax.text(x, y, symbol, **font_formula, **align)

    def subscript(x, y, **align):
        """Place the small '2' of a CH2 group."""
        ax.text(x, y, "2", fontsize=10, weight='bold', **align)

    def bond(xs, ys):
        """Draw one bond line between the two given points."""
        ax.plot(xs, ys, 'k-', lw=2)

    # --- PART 1: top left, acrylic acid CH2=CH-C(=O)-OH drawn vertically ---
    ax.text(1.0, 7.5, "A", **font_label)
    ax.text(1.0, 4.5, "Acrylate\n(丙烯酸酯)", ha='center', va='center', fontsize=14, weight='bold')
    atom(1.5, 6.5, "CH", ha='right')
    subscript(1.6, 6.4, ha='left')
    bond([1.5, 1.5], [6.3, 5.8])  # C=C double bond, line 1
    bond([1.4, 1.4], [6.3, 5.8])  # C=C double bond, line 2
    atom(1.5, 5.6, "CH", ha='center')
    bond([1.5, 1.5], [5.4, 4.9])  # single bond
    atom(1.5, 4.7, "C", ha='center')
    # Carbonyl oxygen, double bond drawn sideways.
    bond([1.4, 1.1], [4.7, 4.7])
    bond([1.4, 1.1], [4.8, 4.8])
    atom(0.9, 4.7, "O", ha='center', va='center')
    # Hydroxyl group (the reactive site).
    bond([1.6, 1.9], [4.7, 4.7])
    atom(2.1, 4.7, "OH", ha='left', va='center')

    # --- PART 2: top right, fluorinated alcohol HO-CH2-CH2-CF2-CF3 ---
    ax.text(3.5, 5.5, "+", fontsize=40, weight='bold', color='#0070C0', ha='center')
    start_x = 4.5
    y_level = 5.5
    # HO-
    atom(start_x, y_level, "HO", ha='right')
    bond([start_x + 0.1, start_x + 0.5], [y_level, y_level])
    # -CH2-
    atom(start_x + 0.8, y_level, "CH", ha='center')
    subscript(start_x + 1.05, y_level-0.1)
    bond([start_x + 1.2, start_x + 1.6], [y_level, y_level])
    # -CH2-
    atom(start_x + 1.9, y_level, "CH", ha='center')
    subscript(start_x + 2.15, y_level-0.1)
    bond([start_x + 2.3, start_x + 2.7], [y_level, y_level])
    # -CF2- with F above and below.
    atom(start_x + 3.0, y_level, "C", ha='center')
    bond([start_x + 3.0, start_x + 3.0], [y_level + 0.2, y_level + 0.5])
    atom(start_x + 3.0, y_level + 0.6, "F", ha='center')
    bond([start_x + 3.0, start_x + 3.0], [y_level - 0.2, y_level - 0.5])
    atom(start_x + 3.0, y_level - 0.8, "F", ha='center')
    bond([start_x + 3.3, start_x + 3.7], [y_level, y_level])
    # -CF3 with F above, below and to the right.
    atom(start_x + 4.0, y_level, "C", ha='center')
    bond([start_x + 4.0, start_x + 4.0], [y_level + 0.2, y_level + 0.5])
    atom(start_x + 4.0, y_level + 0.6, "F", ha='center')
    bond([start_x + 4.0, start_x + 4.0], [y_level - 0.2, y_level - 0.5])
    atom(start_x + 4.0, y_level - 0.8, "F", ha='center')
    bond([start_x + 4.2, start_x + 4.5], [y_level, y_level])
    atom(start_x + 4.7, y_level, "F", ha='center')

    # --- PART 3: reaction arrow ---
    ax.arrow(5.0, 4.0, 0, -1.0, head_width=0.3, head_length=0.3, fc='#0070C0', ec='#0070C0', lw=3)

    # --- PART 4: bottom, fluorinated acrylate monomer ---
    prod_y = 1.5
    # Acrylate part (left side of the product).
    atom(1.5, prod_y + 1.0, "CH", ha='right')
    subscript(1.6, prod_y + 0.9, ha='left')
    bond([1.5, 1.5], [prod_y + 0.8, prod_y + 0.3])
    bond([1.4, 1.4], [prod_y + 0.8, prod_y + 0.3])
    atom(1.5, prod_y + 0.1, "CH", ha='center')
    bond([1.5, 1.5], [prod_y - 0.1, prod_y - 0.6])
    atom(1.5, prod_y - 0.8, "C", ha='center')
    # Carbonyl oxygen.
    bond([1.4, 1.1], [prod_y - 0.8, prod_y - 0.8])
    bond([1.4, 1.1], [prod_y - 0.7, prod_y - 0.7])
    atom(0.9, prod_y - 0.8, "O", ha='center', va='center')
    # Ester oxygen, in the place the OH occupied in the acid.
    bond([1.7, 2.0], [prod_y - 0.8, prod_y - 0.8])
    atom(2.2, prod_y - 0.8, "O", ha='center', va='center')
    # Long bond linking to the spacer, sized for the layout.
    bond([2.4, 3.5], [prod_y - 0.8, prod_y - 0.8])
    # Fluorinated chain (right side of the product).
    # -CH2-
    atom(3.8, prod_y - 0.8, "CH", ha='center', va='center')
    subscript(4.05, prod_y - 0.9)
    bond([4.2, 4.6], [prod_y - 0.8, prod_y - 0.8])
    # -CH2-
    atom(4.9, prod_y - 0.8, "CH", ha='center', va='center')
    subscript(5.15, prod_y - 0.9)
    bond([5.3, 5.7], [prod_y - 0.8, prod_y - 0.8])
    # -CF2- with F above and below.
    atom(6.0, prod_y - 0.8, "C", ha='center', va='center')
    bond([6.0, 6.0], [prod_y - 0.6, prod_y - 0.3])
    atom(6.0, prod_y - 0.1, "F", ha='center')
    bond([6.0, 6.0], [prod_y - 1.0, prod_y - 1.3])
    atom(6.0, prod_y - 1.6, "F", ha='center')
    bond([6.3, 6.7], [prod_y - 0.8, prod_y - 0.8])
    # -CF3 with F above, below and to the right.
    atom(7.0, prod_y - 0.8, "C", ha='center', va='center')
    bond([7.0, 7.0], [prod_y - 0.6, prod_y - 0.3])
    atom(7.0, prod_y - 0.1, "F", ha='center')
    bond([7.0, 7.0], [prod_y - 1.0, prod_y - 1.3])
    atom(7.0, prod_y - 1.6, "F", ha='center')
    bond([7.2, 7.5], [prod_y - 0.8, prod_y - 0.8])
    atom(7.7, prod_y - 0.8, "F", ha='center', va='center')

    # Save the figure.
    plt.savefig("reaction_diagram.png", bbox_inches='tight', dpi=300)
    plt.close()


if __name__ == "__main__":
    draw_diagram()
    print("Diagram generated as reaction_diagram.png")

View File

@ -1,657 +0,0 @@
#!/usr/bin/env python3
"""
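Data visualization script
=========================
Loads a data directory under Static/ and generates several charts:
- class distribution bar chart
- feature distribution histograms (classes overlaid)
- feature box plots (first N features)
- PCA scatter plot with confidence ellipses
- t-SNE scatter plot
- per-class mean / standard deviation heatmaps
- global feature correlation heatmap
- global feature distribution overview
Usage:
    python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
    python Scripts/visualize.py -f "20260408 grap" -l 1 2 3 4 5 6 7 8 9 --max-features 20
    python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9 --no-tsne
Output directory: Visualizations/<folder>/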
"""
import os
import sys
import argparse
import unicodedata
from pathlib import Path
import warnings
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
# ============================================================
# 数据加载工具(与 check_data.py 保持一致)
# ============================================================
DEFAULT_FILE_CLASSES = ("xlsx", "xls", "csv")
def _has_supported_extension(filename: str, file_classes=DEFAULT_FILE_CLASSES) -> bool:
    """Tell whether *filename* carries one of the extensions listed in *file_classes*.

    Comparison is case-insensitive and ignores the leading dot.
    """
    extension = os.path.splitext(filename)[1]
    bare = extension.lower().lstrip(".")
    return bare in file_classes
def _read_data_file(file_path: str) -> pd.DataFrame:
    """Load a ``.csv``, ``.xls`` or ``.xlsx`` file, chosen by its lower-cased extension.

    Raises:
        ValueError: If the extension is none of the three supported ones.
    """
    suffix = os.path.splitext(file_path)[1].lower()
    if suffix in (".xls", ".xlsx"):
        return pd.read_excel(file_path)
    if suffix != ".csv":
        raise ValueError(f"Unsupported file format: {suffix}: {file_path}")
    return pd.read_csv(file_path)
def _strip_zero_width(s: str) -> str:
    """Drop U+200B, U+200C, U+200D and U+FEFF from *s*; non-strings are returned as is."""
    if isinstance(s, str):
        return s.translate(dict.fromkeys((0x200B, 0x200C, 0x200D, 0xFEFF)))
    return s
def _canonicalize_name(name: str) -> str:
    """Apply NFKC normalization to *name*, then remove zero-width characters."""
    normalized = unicodedata.normalize("NFKC", name)
    stripped = _strip_zero_width(normalized)
    return stripped
def _normalize_for_compare(name: str) -> str:
    """Return a relaxed key: canonical, underscores as spaces, single-spaced, lower-case."""
    spaced = _canonicalize_name(name).replace("_", " ")
    return " ".join(spaced.split()).lower()
def _find_matching_file(folder: str, expected_name: str):
    """Locate the entry of *folder* that matches *expected_name*.

    Matching gets progressively looser: exact canonical form first, then
    case-insensitive canonical form, then the relaxed comparison key.  The
    first hit of the first successful stage wins.

    Returns:
        The matching entry name, or ``None`` when there is none or *folder*
        does not exist.
    """
    expected = _canonicalize_name(expected_name)
    try:
        entries = os.listdir(folder)
    except FileNotFoundError:
        return None

    def _first(key, target):
        """First entry whose key equals *target*, else None."""
        return next((entry for entry in entries if key(entry) == target), None)

    hit = _first(_canonicalize_name, expected)
    if hit is None:
        hit = _first(lambda entry: _canonicalize_name(entry).lower(), expected.lower())
    if hit is None:
        hit = _first(_normalize_for_compare, _normalize_for_compare(expected_name))
    return hit
def _find_matching_file_by_label(folder: str, label_name, file_classes):
    """Look for ``<label_name>.<ext>`` in *folder*, trying *file_classes* in order.

    Returns the first matching entry name, or ``None``.
    """
    for extension in file_classes:
        found = _find_matching_file(folder, f"{label_name}.{extension}")
        if found is None:
            continue
        return found
    return None
def _extract_features(df: pd.DataFrame, source: str) -> pd.DataFrame:
    """Select the columns at 0-based positions 1, 3, 5, ... and coerce them to numbers.

    Column labels are kept; cells that cannot be parsed become NaN.

    Raises:
        ValueError: If *df* has fewer than two columns; *source* is named in
            the message.
    """
    picked = [label for position, label in enumerate(df.columns) if position % 2]
    if len(picked) == 0:
        raise ValueError(f"没有找到偶数列(特征列)。请检查文件: {source}")
    features = df[picked].copy()
    for label in features.columns:
        features[label] = pd.to_numeric(features[label], errors="coerce")
    return features
# ============================================================
# 数据加载
# ============================================================
def load_all_data(root: str, folder: str, labels: list[str]):
    """Load the data under ``<root>/Static/<folder>`` using the layout it is stored in.

    If every label has a sub-directory with data files and not every label
    has a data file, the multi-folder loader is used.  In every other case
    the single-file loader is used, with a warning when that choice was not
    unambiguous.  Exits with status 1 when the data directory is missing.
    """
    data_dir = os.path.join(root, "Static", folder)
    if not os.path.isdir(data_dir):
        print(f"[ERROR] 目录不存在: {data_dir}")
        sys.exit(1)

    def _has_file(label) -> bool:
        return _find_matching_file_by_label(data_dir, label, DEFAULT_FILE_CLASSES) is not None

    def _has_subfolder(label) -> bool:
        sub = os.path.join(data_dir, str(label))
        return os.path.isdir(sub) and any(
            _has_supported_extension(f) for f in os.listdir(sub)
        )

    has_all_files = all(_has_file(lbl) for lbl in labels)
    has_all_subfolders = all(_has_subfolder(lbl) for lbl in labels)

    if has_all_subfolders and not has_all_files:
        return _load_multi_folder_mode(data_dir, labels)
    if has_all_subfolders or not has_all_files:
        # Both layouts or neither: fall back to single-file mode.
        print("[WARN] 数据模式不明确,尝试单文件模式...")
    return _load_single_file_mode(data_dir, labels)
def _load_single_file_mode(data_dir: str, labels: list[str]):
    """Load one data file per label and hand the cleaned tables to ``_build_arrays``.

    Labels whose file is missing, unreadable, has no feature columns or has no
    rows left after dropping NaN are skipped with a printed message.
    """
    loaded = []
    widths = {}
    kept_labels = []
    for lbl in labels:
        fname = _find_matching_file_by_label(data_dir, lbl, DEFAULT_FILE_CLASSES)
        if fname is None:
            print(f"[WARN] 标签 {lbl} 找不到文件,跳过")
            continue

        file_path = os.path.join(data_dir, fname)
        try:
            raw = _read_data_file(file_path)
        except Exception as e:
            print(f"[ERROR] 读取 {file_path} 失败: {e}")
            continue
        try:
            features = _extract_features(raw, fname)
        except ValueError as e:
            print(f"[ERROR] {e}")
            continue

        clean = features.dropna()
        if clean.shape[0] == 0:
            print(f"[WARN] 标签 {lbl} 去除 NaN 后无样本,跳过")
            continue

        widths[lbl] = clean.shape[1]
        loaded.append((lbl, clean))
        kept_labels.append(lbl)
    return _build_arrays(loaded, kept_labels, widths)
def _load_multi_folder_mode(data_dir: str, labels: list[str]):
    """Load all data files from each label's sub-directory and pass them to ``_build_arrays``.

    The files of one label are cleaned (feature columns only, NaN rows
    dropped), zero-padded on the right to the widest file of that label and
    stacked by position.  Labels without a sub-directory, without supported
    files or without any valid rows are skipped with a printed message.
    """
    all_features = []
    col_counts = {}
    label_names = []
    for lbl in labels:
        sub = os.path.join(data_dir, str(lbl))
        if not os.path.isdir(sub):
            print(f"[WARN] 标签 {lbl} 子目录不存在,跳过")
            continue
        files = sorted(f for f in os.listdir(sub) if _has_supported_extension(f))
        if not files:
            print(f"[WARN] 标签 {lbl} 子目录无文件,跳过")
            continue

        frames = []
        for fname in files:
            file_path = os.path.join(sub, fname)
            try:
                raw = _read_data_file(file_path)
            except Exception as e:
                print(f"[WARN] 读取 {file_path} 失败: {e}")
                continue
            try:
                feat = _extract_features(raw, f"{lbl}/{fname}")
            except ValueError as e:
                print(f"[WARN] {e}")
                continue
            clean = feat.dropna()
            if clean.shape[0] > 0:
                frames.append(clean)
        if not frames:
            print(f"[WARN] 标签 {lbl} 无有效样本,跳过")
            continue

        # Pad and stack as plain arrays.  Concatenating DataFrames aligns on
        # column labels, so a padded frame's "_pad_i" columns never lined up
        # with the wider frame's real columns and the result was NaN-filled
        # and wider than intended.
        widest = max(frames, key=lambda frame: frame.shape[1])
        max_cols_in_class = widest.shape[1]
        blocks = []
        for frame in frames:
            values = frame.to_numpy()
            short_by = max_cols_in_class - values.shape[1]
            if short_by:
                pad = np.zeros((values.shape[0], short_by), dtype=values.dtype)
                values = np.hstack([values, pad])
            blocks.append(values)
        # Column labels come from the widest file, so files sharing one layout
        # produce the same table as before.
        combined = pd.DataFrame(np.vstack(blocks), columns=list(widest.columns))

        col_counts[lbl] = combined.shape[1]
        all_features.append((lbl, combined))
        label_names.append(lbl)
    return _build_arrays(all_features, label_names, col_counts)
def _build_arrays(all_features, label_names, col_counts):
    """Stack the per-class tables into one feature matrix and one label vector.

    Every class table is zero-padded on the right to the largest value in
    *col_counts*.  The label of a row is the position of its class in
    *all_features*.  Exits with status 1 when *all_features* is empty.

    Returns:
        ``(X, y, label_names, all_features, col_counts)`` where ``X`` holds
        all rows and ``y`` the integer class index of each row.
    """
    if not all_features:
        print("[ERROR] 没有加载到任何有效数据")
        sys.exit(1)

    max_cols = max(col_counts.values())
    feature_blocks = []
    label_blocks = []
    for class_index, (_, table) in enumerate(all_features):
        block = table.values
        short_by = max_cols - block.shape[1]
        if short_by > 0:
            filler = np.zeros((block.shape[0], short_by), dtype=block.dtype)
            block = np.hstack([block, filler])
        feature_blocks.append(block)
        label_blocks.append(np.full(block.shape[0], class_index, dtype=int))

    X = np.vstack(feature_blocks)
    y = np.concatenate(label_blocks)
    return X, y, label_names, all_features, col_counts
# ============================================================
# 可视化函数
# ============================================================
TAB10 = plt.cm.tab10.colors
def _ensure_dir(path: str):
    """Create directory *path*, including missing parents; an existing directory is fine."""
    os.makedirs(name=path, exist_ok=True)
def plot_class_distribution(y, label_names, out_dir: str):
    """Save a bar chart of samples per class as ``01_class_distribution.png`` in *out_dir*.

    Args:
        y: Class index of every sample; index ``i`` refers to ``label_names[i]``.
        label_names: Class labels, used as bar names.
        out_dir: Existing directory that receives the image.
    """
    n_classes = len(label_names)
    counts = [int(np.count_nonzero(y == i)) for i in range(n_classes)]
    colors = [TAB10[i % 10] for i in range(n_classes)]

    fig, ax = plt.subplots(figsize=(max(8, n_classes * 0.6), 5))
    bars = ax.bar(label_names, counts, color=colors, edgecolor="white", linewidth=0.8)
    # Write each count just above its bar.
    for bar, cnt in zip(bars, counts):
        x_mid = bar.get_x() + bar.get_width() / 2
        y_top = bar.get_height() + max(counts) * 0.01
        ax.text(x_mid, y_top, str(cnt), ha="center", va="bottom", fontsize=9)
    ax.set_xlabel("类别")
    ax.set_ylabel("样本数")
    ax.set_title(f"类别分布 (总计 {sum(counts)} 样本, {n_classes} 类)")
    fig.tight_layout()

    path = os.path.join(out_dir, "01_class_distribution.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
def plot_feature_histograms(all_features, out_dir: str, max_features: int = 12):
    """Save overlaid per-class histograms of the leading features as ``02_feature_histograms.png``.

    One subplot is drawn per feature, four per row.  The number of features is
    the smaller of *max_features* and the column count of the first class.
    With up to 8 features each subplot carries its own legend; beyond that a
    single figure-level legend is used.

    Args:
        all_features: ``(label, feature table)`` pairs.
        out_dir: Existing directory that receives the image.
        max_features: Upper bound on the number of features plotted.
    """
    n_features = min(max_features, all_features[0][1].shape[1])
    n_cols = 4
    n_rows = (n_features + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
    colors = [TAB10[i % 10] for i in range(len(all_features))]
    per_axis_legend = n_features <= 8

    for j in range(n_features):
        ax = axes[j]
        for idx, (lbl, df) in enumerate(all_features):
            if j >= df.shape[1]:
                # This class has no such column.
                continue
            ax.hist(df.iloc[:, j].values, bins=40, density=True, alpha=0.4,
                    color=colors[idx], label=f"{lbl}")
        ax.set_title(f"Feature {j+1}")
        ax.set_xlabel("")
        ax.set_ylabel("密度")
        if per_axis_legend:
            ax.legend(fontsize=7, loc="upper right")

    # Hide the unused cells of the grid.
    for spare in axes[n_features:]:
        spare.set_visible(False)

    if not per_axis_legend:
        handles = [plt.Rectangle((0, 0), 1, 1, color=colors[i], alpha=0.4)
                   for i in range(len(all_features))]
        fig.legend(handles, [lbl for lbl, _ in all_features],
                   loc="lower center", ncol=min(10, len(all_features)), fontsize=7)

    fig.suptitle(f"特征分布直方图(各类叠加,前 {n_features} 维)", fontsize=13, y=1.01)
    fig.tight_layout()
    path = os.path.join(out_dir, "02_feature_histograms.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {path}")
def plot_feature_boxplots(all_features, out_dir: str, max_features: int = 20):
    """Save per-class box plots of the leading features as ``03_feature_boxplots.png``.

    One subplot is drawn per feature, four per row, with one box per class
    that has the feature.  The number of features is the smaller of
    *max_features* and the column count of the first class.

    Args:
        all_features: ``(label, feature table)`` pairs.
        out_dir: Existing directory that receives the image.
        max_features: Upper bound on the number of features plotted.
    """
    n_features = min(max_features, all_features[0][1].shape[1])
    n_cols = 4
    n_rows = (n_features + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4.5 * n_cols, 3.5 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]

    for j in range(n_features):
        ax = axes[j]
        # (class index, label, column values) for every class owning column j.
        present = [
            (idx, lbl, df.iloc[:, j].values)
            for idx, (lbl, df) in enumerate(all_features)
            if j < df.shape[1]
        ]
        class_ids = [idx for idx, _, _ in present]
        positions = [idx + 1 for idx in class_ids]
        bp = ax.boxplot([values for _, _, values in present], positions=positions,
                        patch_artist=True, widths=0.6, showfliers=True,
                        flierprops={"marker": ".", "markersize": 2, "alpha": 0.3})
        # Tick labels are set on the axis: boxplot's ``labels`` keyword is
        # deprecated since Matplotlib 3.9 (renamed ``tick_labels``), and this
        # form works on both sides of the rename.
        ax.set_xticks(positions)
        ax.set_xticklabels([str(lbl) for _, lbl, _ in present])
        # Colour by class index, so a class keeps its colour even when an
        # earlier class lacks this feature.
        for patch, idx in zip(bp["boxes"], class_ids):
            patch.set_facecolor(TAB10[idx % 10])
            patch.set_alpha(0.6)
        ax.set_title(f"Feature {j+1}")
        ax.set_xlabel("类别")
        ax.tick_params(axis="x", rotation=0, labelsize=8)

    # Hide the unused cells of the grid.
    for j in range(n_features, len(axes)):
        axes[j].set_visible(False)

    fig.suptitle(f"特征箱线图(各类别对比,前 {n_features} 维)", fontsize=13, y=1.01)
    fig.tight_layout()
    path = os.path.join(out_dir, "03_feature_boxplots.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {path}")
def _plot_confidence_ellipse(ax, mean, cov, color, alpha=0.2, n_std=1.0):
    """Add an ellipse describing a 2-D covariance to *ax*.

    The axes of the ellipse follow the eigenvectors of *cov*; their full
    lengths are ``2 * n_std * sqrt(eigenvalue)``.

    Args:
        ax: Matplotlib axes to draw on.
        mean: Centre of the ellipse, ``(x, y)``.
        cov: 2 x 2 covariance matrix.
        color: Face and edge colour.
        alpha: Opacity of the ellipse.
        n_std: Number of standard deviations covered by each semi-axis.
    """
    from matplotlib.patches import Ellipse

    vals, vecs = np.linalg.eigh(cov)
    # Largest eigenvalue first, so that width is the major axis.
    order = vals.argsort()[::-1]
    vals = vals[order]
    vecs = vecs[:, order]
    angle = np.degrees(np.arctan2(vecs[1, 0], vecs[0, 0]))
    # A degenerate covariance can give eigenvalues a hair below zero through
    # round-off; clip them so the axis lengths do not become NaN.
    width, height = 2 * n_std * np.sqrt(np.clip(vals, 0.0, None))
    ellipse = Ellipse(xy=mean, width=width, height=height, angle=angle,
                      facecolor=color, alpha=alpha, edgecolor=color, linewidth=0.8)
    ax.add_patch(ellipse)
def plot_pca(X, y, label_names, out_dir: str):
    """Save the PCA plots ``04_pca.png`` and ``04_pca_variance.png`` in *out_dir*.

    ``04_pca.png`` has two panels on the first two principal components: a
    scatter plot of all samples, and the class centroids with a 1-sigma
    ellipse for every class that has more than two samples.
    ``04_pca_variance.png`` shows individual and cumulative explained variance
    for up to 30 components.

    Args:
        X: Feature matrix, one row per sample.
        y: Class index of every sample; index ``i`` refers to ``label_names[i]``.
        label_names: Class labels.
        out_dir: Existing directory that receives the images.
    """
    from sklearn.decomposition import PCA

    pca = PCA(n_components=2, random_state=42)
    X_pca = pca.fit_transform(X)
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    colors = [TAB10[i % 10] for i in range(len(label_names))]

    # Scatter plot.
    ax = axes[0]
    for i, lbl in enumerate(label_names):
        mask = y == i
        ax.scatter(X_pca[mask, 0], X_pca[mask, 1], c=[colors[i]], label=f"{lbl}",
                   alpha=0.5, s=3, edgecolors="none")
    ax.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)")
    ax.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)")
    ax.set_title("PCA 降维散点图")
    ax.legend(fontsize=7, markerscale=3, loc="best")

    # Centroids and ellipses.
    ax2 = axes[1]
    for i, lbl in enumerate(label_names):
        mask = y == i
        class_points = X_pca[mask]
        mean = class_points.mean(axis=0)
        ax2.scatter(mean[0], mean[1], c=[colors[i]], s=80, marker="X",
                    edgecolors="black", linewidths=0.8, zorder=5)
        ax2.annotate(str(lbl), (mean[0], mean[1]), fontsize=8, ha="center", va="bottom",
                     fontweight="bold", xytext=(0, 4), textcoords="offset points")
        # A covariance estimate needs more than two points to be meaningful here.
        if class_points.shape[0] > 2:
            cov = np.cov(class_points.T)
            _plot_confidence_ellipse(ax2, mean, cov, color=colors[i], alpha=0.25)
    ax2.set_xlabel(f"PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)")
    ax2.set_ylabel(f"PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)")
    ax2.set_title("PCA 质心 + 1σ 置信椭圆")
    fig.suptitle("PCA 降维分析", fontsize=14)
    fig.tight_layout()
    path = os.path.join(out_dir, "04_pca.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")

    # Explained variance.  The 2-component model above can only ever report
    # two ratios, so the chart is built from a separate fit with as many
    # components as the data allows, capped at 30.  The scatter panels keep
    # the 2-component fit and are therefore unchanged.
    n_scree = min(30, X.shape[0], X.shape[1])
    scree = PCA(n_components=n_scree, random_state=42).fit(X)
    ratios = scree.explained_variance_ratio_
    n = len(ratios)
    cumsum = np.cumsum(ratios)
    fig2, ax = plt.subplots(figsize=(8, 4))
    ax.bar(range(1, n + 1), ratios, alpha=0.6, color="steelblue", label="个体")
    ax.plot(range(1, n + 1), cumsum, "ro-", markersize=4, label="累计")
    ax.set_xlabel("主成分")
    ax.set_ylabel("方差解释率")
    ax.set_title("PCA 方差解释率")
    ax.legend()
    fig2.tight_layout()
    path2 = os.path.join(out_dir, "04_pca_variance.png")
    fig2.savefig(path2, dpi=150)
    plt.close(fig2)
    print(f"[OK] {path2}")
def plot_tsne(X, y, label_names, out_dir: str, max_samples: int = 5000):
    """Embed the data with t-SNE and save a per-class scatter plot.

    When there are more than ``max_samples`` rows, each class is subsampled to
    at most ``max_samples // len(label_names)`` rows before the embedding is
    computed.

    Args:
        X: 2-D feature array (samples x features).
        y: Integer class index per sample; index ``i`` refers to
            ``label_names[i]``.
        label_names: Class labels, in class-index order.
        out_dir: Directory that receives ``05_tsne.png``.
        max_samples: Sample count above which subsampling kicks in.
    """
    from sklearn.manifold import TSNE

    if X.shape[0] > max_samples:
        print(f"[INFO] t-SNE: 样本过多 ({X.shape[0]}), 分层抽样至 {max_samples}")
        per_class = max_samples // len(label_names)
        indices = []
        for i in range(len(label_names)):
            idx_i = np.where(y == i)[0]
            if len(idx_i) > per_class:
                # Re-seeded for every class so the selection is reproducible.
                rng = np.random.RandomState(42)
                idx_i = rng.choice(idx_i, per_class, replace=False)
            indices.extend(idx_i.tolist())
        # Explicit integer dtype keeps the array usable as an index even if
        # no sample was selected.
        indices = np.array(indices, dtype=int)
        X_sub = X[indices]
        y_sub = y[indices]
    else:
        X_sub = X
        y_sub = y

    print("[INFO] 正在计算 t-SNE可能需要一些时间...")
    n_sub = X_sub.shape[0]
    perplexity = min(50, max(5, n_sub // 3))
    # scikit-learn rejects perplexity >= n_samples, which the lower bound of 5
    # would violate for very small data sets.
    if perplexity >= n_sub:
        perplexity = max(1, n_sub - 1)
    # The iteration count is left at the library default (1000): the keyword
    # was renamed from n_iter to max_iter in scikit-learn 1.5, so passing it
    # explicitly is not portable across versions.
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity,
                verbose=0)
    X_tsne = tsne.fit_transform(X_sub)

    colors = [TAB10[i % 10] for i in range(len(label_names))]
    fig, ax = plt.subplots(figsize=(10, 8))
    for i, lbl in enumerate(label_names):
        mask = y_sub == i
        ax.scatter(X_tsne[mask, 0], X_tsne[mask, 1], c=[colors[i]], label=f"{lbl}",
                   alpha=0.5, s=3, edgecolors="none")
    ax.set_xlabel("t-SNE 1")
    ax.set_ylabel("t-SNE 2")
    ax.set_title(f"t-SNE 降维散点图 (n={X_sub.shape[0]}, perplexity={perplexity})")
    ax.legend(fontsize=7, markerscale=3, loc="best")
    fig.tight_layout()
    path = os.path.join(out_dir, "05_tsne.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
def plot_class_mean_std_heatmap(all_features, label_names, out_dir: str, max_features: int = 30):
    """Save side-by-side heatmaps of the per-class feature means and standard deviations.

    Args:
        all_features: Sequence of ``(label, DataFrame)`` pairs; the i-th pair
            fills row ``i`` of both heatmaps.
        label_names: Row labels of the heatmaps.
        out_dir: Directory that receives ``06_class_mean_std_heatmap.png``.
        max_features: Upper bound on the number of leading features shown.
    """
    n_features = min(max_features, all_features[0][1].shape[1])
    n_classes = len(label_names)

    # Cells of features a class does not have simply keep their initial 0.0.
    mean_matrix = np.zeros((n_classes, n_features))
    std_matrix = np.zeros((n_classes, n_features))
    for row, (_, frame) in enumerate(all_features):
        usable = min(n_features, frame.shape[1])
        for col_idx in range(usable):
            column = frame.iloc[:, col_idx].values
            mean_matrix[row, col_idx] = np.mean(column)
            std_matrix[row, col_idx] = np.std(column)

    # Settings shared by both heatmaps: tick labels only for up to 30 features,
    # numeric annotations only for up to 20.
    annotate = n_features <= 20
    x_ticks = [f"F{i+1}" for i in range(n_features)] if n_features <= 30 else False
    shared = {
        "xticklabels": x_ticks,
        "yticklabels": label_names,
        "annot": annotate,
        "fmt": ".3f" if annotate else "",
        "linewidths": 0.5,
    }

    fig_size = (max(10, n_features * 0.35), max(5, n_classes * 0.5))
    fig, (ax_mean, ax_std) = plt.subplots(1, 2, figsize=fig_size)

    sns.heatmap(mean_matrix, ax=ax_mean, cmap="RdBu_r", center=0,
                cbar_kws={"label": "均值", "shrink": 0.8}, **shared)
    ax_mean.set_title("各类别特征均值")
    ax_mean.set_xlabel("特征维度")

    sns.heatmap(std_matrix, ax=ax_std, cmap="YlOrRd",
                cbar_kws={"label": "标准差", "shrink": 0.8}, **shared)
    ax_std.set_title("各类别特征标准差")
    ax_std.set_xlabel("特征维度")

    fig.suptitle("各类别特征统计对比", fontsize=13)
    fig.tight_layout()
    path = os.path.join(out_dir, "06_class_mean_std_heatmap.png")
    fig.savefig(path, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"[OK] {path}")
def plot_correlation_heatmap(X, out_dir: str, max_features: int = 30):
    """Save a heatmap of the Pearson correlation between the leading features.

    Args:
        X: 2-D feature array (samples x features).
        out_dir: Directory that receives ``07_correlation_heatmap.png``.
        max_features: Upper bound on the number of leading features included.
    """
    n_features = min(max_features, X.shape[1])
    # corrcoef expects variables in rows, hence the transpose.
    leading = X[:, :n_features]
    corr = np.corrcoef(leading.T)

    # Tick labels are dropped when there are too many features to read them.
    tick_labels = [f"F{i+1}" for i in range(n_features)] if n_features <= 30 else False

    fig_width = max(10, n_features * 0.5)
    fig_height = max(8, n_features * 0.45)
    fig, ax = plt.subplots(figsize=(fig_width, fig_height))
    sns.heatmap(
        corr,
        ax=ax,
        cmap="RdBu_r",
        center=0,
        vmin=-1,
        vmax=1,
        xticklabels=tick_labels,
        yticklabels=tick_labels,
        linewidths=0.1,
        cbar_kws={"label": "Pearson r", "shrink": 0.8},
    )
    ax.set_title(f"特征相关性矩阵 (前 {n_features} 维)")
    fig.tight_layout()

    path = os.path.join(out_dir, "07_correlation_heatmap.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
def plot_global_distribution(X, out_dir: str):
    """Save an overview figure of all feature values.

    The left panel is a histogram of every value in ``X`` with the mean and
    median marked; the right panel shows mean +/- standard deviation for up to
    the first 50 feature columns.

    Args:
        X: 2-D feature array (samples x features).
        out_dir: Directory that receives ``08_global_distribution.png``.
    """
    fig, (ax_hist, ax_stats) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: histogram over all values, regardless of feature.
    all_vals = X.flatten()
    overall_mean = np.mean(all_vals)
    overall_median = np.median(all_vals)
    ax_hist.hist(all_vals, bins=100, color="steelblue", alpha=0.7, edgecolor="white",
                 linewidth=0.3)
    ax_hist.axvline(overall_mean, color="red", linestyle="--",
                    label=f"均值={overall_mean:.4f}")
    ax_hist.axvline(overall_median, color="orange", linestyle="--",
                    label=f"中位数={overall_median:.4f}")
    ax_hist.set_xlabel("特征值")
    ax_hist.set_ylabel("频数")
    ax_hist.set_title("全局特征值分布")
    ax_hist.legend()

    # Right panel: per-feature mean with the standard deviation as error bar.
    n_features = min(X.shape[1], 50)
    means = np.mean(X, axis=0)[:n_features]
    stds = np.std(X, axis=0)[:n_features]
    ax_stats.errorbar(range(1, n_features + 1), means, yerr=stds,
                      fmt="o", markersize=3, capsize=2, color="steelblue", alpha=0.7)
    ax_stats.set_xlabel("特征维度")
    ax_stats.set_ylabel("均值 ± 标准差")
    ax_stats.set_title(f"各维度均值与标准差 (前 {n_features} 维)")

    fig.suptitle("全局特征概览", fontsize=13)
    fig.tight_layout()
    path = os.path.join(out_dir, "08_global_distribution.png")
    fig.savefig(path, dpi=150)
    plt.close(fig)
    print(f"[OK] {path}")
# ============================================================
# Command-line entry point
# ============================================================
def main():
    """Parse the command line, load the data set and write every chart to disk.

    Charts are written to ``<root>/Visualizations/<folder>``. PCA and t-SNE
    can be skipped with ``--no-pca`` / ``--no-tsne``.
    """
    parser = argparse.ArgumentParser(
        description="Deeplearning 数据可视化脚本",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9
python Scripts/visualize.py -f "20260408 grap" -l 1 2 3 4 5 6 7 8 9
python Scripts/visualize.py -f 20260319Numbers -l 0 1 2 3 4 5 6 7 8 9 --no-tsne --max-features 15
""",
    )
    parser.add_argument("-f", "--folder", required=True,
                        help="Static/ 下的数据目录名")
    parser.add_argument("-l", "--labels", nargs="+", required=True,
                        help="类别标签列表,空格分隔")
    parser.add_argument("-r", "--root", default=None,
                        help="项目根目录(默认为 Deeplearning/")
    parser.add_argument("--max-features", type=int, default=20,
                        help="可视化中显示的最大特征维度数 (默认 20)")
    parser.add_argument("--no-tsne", action="store_true",
                        help="跳过 t-SNE 计算")
    parser.add_argument("--no-pca", action="store_true",
                        help="跳过 PCA 计算")
    parser.add_argument("--tsne-max-samples", type=int, default=5000,
                        help="t-SNE 最大抽样数 (默认 5000)")
    args = parser.parse_args()

    # Without --root, fall back to the directory two levels above this script.
    root = args.root if args.root else str(Path(__file__).resolve().parent.parent)
    root = os.path.abspath(root)

    print(f"加载数据: {root}/Static/{args.folder}")
    X, y, label_names, all_features, col_counts = load_all_data(root, args.folder, args.labels)
    print(f" 样本数: {X.shape[0]}, 特征维度: {X.shape[1]}, 类别数: {len(label_names)}")
    # y holds class indices, so count by position, as the plot functions do.
    for i, lbl in enumerate(label_names):
        cnt = int(np.sum(y == i))
        print(f"{lbl}: {cnt} 样本, {col_counts.get(lbl, X.shape[1])}")

    out_dir = os.path.join(root, "Visualizations", args.folder)
    _ensure_dir(out_dir)
    print(f"\n输出目录: {out_dir}\n")

    # Every chart title and label is Chinese, but DejaVu Sans has no CJK
    # glyphs. Prefer a CJK-capable font when one is installed and keep
    # DejaVu Sans as the last resort, so nothing changes on machines
    # without such a font.
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = [
        "Noto Sans CJK SC",
        "WenQuanYi Micro Hei",
        "Microsoft YaHei",
        "SimHei",
        "PingFang SC",
        "DejaVu Sans",
    ]
    # Several CJK fonts lack the Unicode minus sign; use the ASCII hyphen.
    plt.rcParams["axes.unicode_minus"] = False

    print("生成可视化图表...\n")
    max_features = min(args.max_features, X.shape[1])
    plot_class_distribution(y, label_names, out_dir)
    plot_feature_histograms(all_features, out_dir, max_features=max_features)
    plot_feature_boxplots(all_features, out_dir, max_features=max_features)
    if not args.no_pca:
        plot_pca(X, y, label_names, out_dir)
    if not args.no_tsne:
        plot_tsne(X, y, label_names, out_dir, max_samples=args.tsne_max_samples)
    plot_class_mean_std_heatmap(all_features, label_names, out_dir,
                                max_features=max_features)
    plot_correlation_heatmap(X, out_dir, max_features=min(30, X.shape[1]))
    plot_global_distribution(X, out_dir)
    print(f"\n全部图表已保存到: {out_dir}")


if __name__ == "__main__":
    main()

View File

@ -1,180 +0,0 @@
name: Deeplearning
channels:
- pytorch
- nvidia
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- blas=1.0=mkl
- bottleneck=1.3.7=py312ha883a20_0
- brotli=1.0.9=h5eee18b_8
- brotli-bin=1.0.9=h5eee18b_8
- brotli-python=1.0.9=py312h6a678d5_8
- bzip2=1.0.8=h5eee18b_6
- ca-certificates=2024.11.26=h06a4308_0
- certifi=2024.12.14=py312h06a4308_0
- charset-normalizer=3.3.2=pyhd3eb1b0_0
- contourpy=1.2.0=py312hdb19cb5_0
- cuda-cudart=12.4.127=0
- cuda-cupti=12.4.127=0
- cuda-libraries=12.4.0=0
- cuda-nvrtc=12.4.127=0
- cuda-nvtx=12.4.127=0
- cuda-opencl=12.4.127=0
- cuda-runtime=12.4.0=0
- cudatoolkit=11.5.1=hcf5317a_9
- cycler=0.11.0=pyhd3eb1b0_0
- cyrus-sasl=2.1.28=h52b45da_1
- dbus=1.13.18=hb2f20db_0
- debugpy=1.6.7=py312h6a678d5_0
- et_xmlfile=1.1.0=py312h06a4308_1
- expat=2.6.2=h6a678d5_0
- ffmpeg=4.3=hf484d3e_0
- filelock=3.13.1=py312h06a4308_0
- fontconfig=2.14.1=h55d465d_3
- fonttools=4.51.0=py312h5eee18b_0
- freetype=2.12.1=h4a9f257_0
- glib=2.78.4=h6a678d5_0
- glib-tools=2.78.4=h6a678d5_0
- gmp=6.2.1=h295c915_3
- gnutls=3.6.15=he1e5248_0
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- icu=73.1=h6a678d5_0
- idna=3.7=py312h06a4308_0
- intel-openmp=2023.1.0=hdb19cb5_46306
- jinja2=3.1.4=py312h06a4308_0
- joblib=1.4.2=py312h06a4308_0
- jpeg=9e=h5eee18b_3
- kiwisolver=1.4.4=py312h6a678d5_0
- krb5=1.20.1=h143b758_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libbrotlicommon=1.0.9=h5eee18b_8
- libbrotlidec=1.0.9=h5eee18b_8
- libbrotlienc=1.0.9=h5eee18b_8
- libclang=14.0.6=default_hc6dbbc7_1
- libclang13=14.0.6=default_he11475f_1
- libcublas=12.4.2.65=0
- libcufft=11.2.0.44=0
- libcufile=1.9.1.3=0
- libcups=2.4.2=h2d74bed_1
- libcurand=10.3.5.147=0
- libcusolver=11.6.0.99=0
- libcusparse=12.3.0.142=0
- libdeflate=1.17=h5eee18b_1
- libedit=3.1.20230828=h5eee18b_0
- libffi=3.4.4=h6a678d5_1
- libgcc-ng=11.2.0=h1234567_1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libglib=2.78.4=hdc74915_0
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h5eee18b_3
- libidn2=2.3.4=h5eee18b_0
- libjpeg-turbo=2.0.0=h9bf148f_0
- libllvm14=14.0.6=hdb19cb5_3
- libnpp=12.2.5.2=0
- libnvfatbin=12.4.127=0
- libnvjitlink=12.4.99=0
- libnvjpeg=12.3.1.89=0
- libpng=1.6.39=h5eee18b_0
- libpq=12.17=hdbd6064_0
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libwebp-base=1.3.2=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=h097e994_2
- libxml2=2.13.1=hfdd30dd_2
- llvm-openmp=14.0.6=h9e868ea_0
- lz4-c=1.9.4=h6a678d5_1
- markupsafe=2.1.3=py312h5eee18b_0
- matplotlib=3.8.4=py312h06a4308_0
- matplotlib-base=3.8.4=py312h526ad5a_0
- mkl=2023.1.0=h213fc3f_46344
- mkl-service=2.4.0=py312h5eee18b_1
- mkl_fft=1.3.8=py312h5eee18b_0
- mkl_random=1.2.4=py312hdb19cb5_0
- mpmath=1.3.0=py312h06a4308_0
- mysql=5.7.24=h721c034_2
- ncurses=6.4=h6a678d5_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.3=py312h06a4308_0
- numexpr=2.8.7=py312hf827012_0
- numpy=1.26.4=py312hc5e2394_0
- numpy-base=1.26.4=py312h0da6c21_0
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.5.2=he7f1fd0_0
- openpyxl=3.1.5=py312h5eee18b_0
- openssl=3.0.15=h5eee18b_0
- packaging=24.1=py312h06a4308_0
- pandas=2.2.2=py312h526ad5a_0
- pcre2=10.42=hebb0a14_1
- pillow=10.4.0=py312h5eee18b_0
- pip=24.2=py312h06a4308_0
- ply=3.11=py312h06a4308_1
- pybind11-abi=5=hd3eb1b0_0
- pyopengl=3.1.1a1=py312h06a4308_0
- pyparsing=3.0.9=py312h06a4308_0
- pyqt=5.15.10=py312h6a678d5_0
- pyqt5-sip=12.13.0=py312h5eee18b_0
- pysocks=1.7.1=py312h06a4308_0
- python=3.12.4=h5148396_1
- python-dateutil=2.9.0post0=py312h06a4308_2
- python-tzdata=2023.3=pyhd3eb1b0_0
- pytorch=2.4.0=py3.12_cuda12.4_cudnn9.1.0_0
- pytorch-cuda=12.4=hc786d27_6
- pytorch-mutex=1.0=cuda
- pytz=2024.1=py312h06a4308_0
- pyyaml=6.0.1=py312h5eee18b_0
- qt-main=5.15.2=h53bd1ea_10
- readline=8.2=h5eee18b_0
- requests=2.32.3=py312h06a4308_0
- scikit-learn=1.5.1=py312h526ad5a_0
- scipy=1.13.1=py312hc5e2394_0
- seaborn=0.13.2=py312h06a4308_0
- setuptools=72.1.0=py312h06a4308_0
- sip=6.7.12=py312h6a678d5_0
- six=1.16.0=pyhd3eb1b0_1
- sqlite=3.45.3=h5eee18b_0
- sympy=1.12=py312h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- threadpoolctl=3.5.0=py312he106c6f_0
- tk=8.6.14=h39e8969_0
- torchaudio=2.4.0=py312_cu124
- torchtriton=3.0.0=py312
- torchvision=0.19.0=py312_cu124
- tornado=6.4.1=py312h5eee18b_0
- tqdm=4.66.5=py312he106c6f_0
- typing_extensions=4.11.0=py312h06a4308_0
- tzdata=2024a=h04d1e81_0
- unicodedata2=15.1.0=py312h5eee18b_0
- urllib3=2.2.2=py312h06a4308_0
- wheel=0.43.0=py312h06a4308_0
- xlrd=2.0.1=pyhd3eb1b0_1
- xz=5.4.6=h5eee18b_1
- yaml=0.2.5=h7b6447c_0
- zlib=1.2.13=h5eee18b_1
- zstd=1.5.5=hc292b87_2
- pip:
- autopep8==2.3.1
- basedpyright==1.16.0
- black==24.8.0
- click==8.1.7
- fsspec==2024.6.1
- graphviz==0.20.3
- greenlet==3.0.3
- msgpack==1.0.8
- mypy-extensions==1.0.0
- nodejs-wheel-binaries==20.16.0
- pathspec==0.12.1
- platformdirs==4.2.2
- pycodestyle==2.12.1
- pynvim==0.5.0
prefix: /home/qyhhh/.miniconda3/envs/Deeplearning

View File

@ -1,163 +0,0 @@
# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: linux-64
# created-by: conda 26.1.1
@EXPLICIT
https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda
https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.conda
https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2024.11.26-h06a4308_0.conda
https://conda.anaconda.org/nvidia/linux-64/cuda-cudart-12.4.127-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cuda-cupti-12.4.127-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cuda-nvrtc-12.4.127-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cuda-nvtx-12.4.127-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/cuda-opencl-12.4.127-0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/ld_impl_linux-64-2.38-h1181459_1.conda
https://conda.anaconda.org/nvidia/linux-64/libcublas-12.4.2.65-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libcufft-11.2.0.44-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libcufile-1.9.1.3-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libcurand-10.3.5.147-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libcusolver-11.6.0.99-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libcusparse-12.3.0.142-0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libgfortran5-11.2.0-h1234567_1.conda
https://conda.anaconda.org/nvidia/linux-64/libnpp-12.2.5.2-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libnvfatbin-12.4.127-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libnvjitlink-12.4.99-0.tar.bz2
https://conda.anaconda.org/nvidia/linux-64/libnvjpeg-12.3.1.89-0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-11.2.0-h1234567_1.conda
https://repo.anaconda.com/pkgs/main/noarch/pybind11-abi-5-hd3eb1b0_0.conda
https://conda.anaconda.org/pytorch/noarch/pytorch-mutex-1.0-cuda.tar.bz2
https://repo.anaconda.com/pkgs/main/noarch/tzdata-2024a-h04d1e81_0.conda
https://conda.anaconda.org/nvidia/linux-64/cuda-libraries-12.4.0-0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-11.2.0-h00389a5_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libgomp-11.2.0-h1234567_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/_openmp_mutex-5.1-1_gnu.conda
https://conda.anaconda.org/nvidia/linux-64/cuda-runtime-12.4.0-0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-11.2.0-h1234567_1.conda
https://conda.anaconda.org/pytorch/linux-64/pytorch-cuda-12.4-hc786d27_6.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/bzip2-1.0.8-h5eee18b_6.conda
https://conda.anaconda.org/nvidia/linux-64/cudatoolkit-11.5.1-hcf5317a_9.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/expat-2.6.2-h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/gmp-6.2.1-h295c915_3.conda
https://repo.anaconda.com/pkgs/main/linux-64/icu-73.1-h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9e-h5eee18b_3.conda
https://repo.anaconda.com/pkgs/main/linux-64/lame-3.100-h7b6447c_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/lerc-3.0-h295c915_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libbrotlicommon-1.0.9-h5eee18b_8.conda
https://repo.anaconda.com/pkgs/main/linux-64/libdeflate-1.17-h5eee18b_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.4.4-h6a678d5_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libiconv-1.16-h5eee18b_3.conda
https://conda.anaconda.org/pytorch/linux-64/libjpeg-turbo-2.0.0-h9bf148f_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/libtasn1-4.19.0-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libunistring-0.9.10-h27cfd23_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libuuid-1.41.5-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libwebp-base-1.3.2-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.15-h7f8727e_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.4-h6a678d5_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.4-h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/openh264-2.1.1-h4ff587b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/openssl-3.0.15-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/tbb-2021.8.0-hdb19cb5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/xz-5.4.6-h5eee18b_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.2.5-h7b6447c_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.13-h5eee18b_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2023.1.0-hdb19cb5_46306.conda
https://repo.anaconda.com/pkgs/main/linux-64/libbrotlidec-1.0.9-h5eee18b_8.conda
https://repo.anaconda.com/pkgs/main/linux-64/libbrotlienc-1.0.9-h5eee18b_8.conda
https://repo.anaconda.com/pkgs/main/linux-64/libcups-2.4.2-h2d74bed_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20230828-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libidn2-2.3.4-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libllvm14-14.0.6-hdb19cb5_3.conda
https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.39-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.13.1-hfdd30dd_2.conda
https://repo.anaconda.com/pkgs/main/linux-64/llvm-openmp-14.0.6-h9e868ea_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/nettle-3.7.3-hbbd107a_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/pcre2-10.42-hebb0a14_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/readline-8.2-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.14-h39e8969_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.5.5-hc292b87_2.conda
https://repo.anaconda.com/pkgs/main/linux-64/brotli-bin-1.0.9-h5eee18b_8.conda
https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.12.1-h4a9f257_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/gnutls-3.6.15-he1e5248_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/krb5-1.20.1-h143b758_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libclang13-14.0.6-default_he11475f_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libglib-2.78.4-hdc74915_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.5.1-h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libxkbcommon-1.0.1-h097e994_2.conda
https://repo.anaconda.com/pkgs/main/linux-64/mkl-2023.1.0-h213fc3f_46344.conda
https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.45.3-h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/brotli-1.0.9-h5eee18b_8.conda
https://repo.anaconda.com/pkgs/main/linux-64/cyrus-sasl-2.1.28-h52b45da_1.conda
https://conda.anaconda.org/pytorch/linux-64/ffmpeg-4.3-hf484d3e_0.tar.bz2
https://repo.anaconda.com/pkgs/main/linux-64/fontconfig-2.14.1-h55d465d_3.conda
https://repo.anaconda.com/pkgs/main/linux-64/glib-tools-2.78.4-h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.12-h3be6417_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/libclang-14.0.6-default_hc6dbbc7_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/libpq-12.17-hdbd6064_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/openjpeg-2.5.2-he7f1fd0_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/python-3.12.4-h5148396_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/brotli-python-1.0.9-py312h6a678d5_8.conda
https://repo.anaconda.com/pkgs/main/linux-64/certifi-2024.12.14-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/noarch/charset-normalizer-3.3.2-pyhd3eb1b0_0.conda
https://repo.anaconda.com/pkgs/main/noarch/cycler-0.11.0-pyhd3eb1b0_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/debugpy-1.6.7-py312h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/et_xmlfile-1.1.0-py312h06a4308_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/filelock-3.13.1-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/glib-2.78.4-h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/idna-3.7-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/joblib-1.4.2-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.4.4-py312h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/markupsafe-2.1.3-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.4.0-py312h5eee18b_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/mpmath-1.3.0-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/mysql-5.7.24-h721c034_2.conda
https://repo.anaconda.com/pkgs/main/linux-64/networkx-3.3-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/packaging-24.1-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pillow-10.4.0-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/ply-3.11-py312h06a4308_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/pyopengl-3.1.1a1-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pyparsing-3.0.9-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pyqt5-sip-12.13.0-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.7.1-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/noarch/python-tzdata-2023.3-pyhd3eb1b0_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pytz-2024.1-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pyyaml-6.0.1-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/setuptools-72.1.0-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/noarch/six-1.16.0-pyhd3eb1b0_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/threadpoolctl-3.5.0-py312he106c6f_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/tornado-6.4.1-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/tqdm-4.66.5-py312he106c6f_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/typing_extensions-4.11.0-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/unicodedata2-15.1.0-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/wheel-0.43.0-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/noarch/xlrd-2.0.1-pyhd3eb1b0_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.18-hb2f20db_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/fonttools-4.51.0-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.1-h5eee18b_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/jinja2-3.1.4-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.26.4-py312h0da6c21_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/openpyxl-3.1.5-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pip-24.2-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/python-dateutil-2.9.0post0-py312h06a4308_2.conda
https://repo.anaconda.com/pkgs/main/linux-64/sip-6.7.12-py312h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/sympy-1.12-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/urllib3-2.2.2-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.1-h6a678d5_1.conda
https://repo.anaconda.com/pkgs/main/linux-64/requests-2.32.3-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/qt-main-5.15.2-h53bd1ea_10.conda
https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.15.10-py312h6a678d5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/bottleneck-1.3.7-py312ha883a20_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/contourpy-1.2.0-py312hdb19cb5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.8.4-py312h06a4308_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-base-3.8.4-py312h526ad5a_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.3.8-py312h5eee18b_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/mkl_random-1.2.4-py312hdb19cb5_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.26.4-py312hc5e2394_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/numexpr-2.8.7-py312hf827012_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.13.1-py312hc5e2394_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/pandas-2.2.2-py312h526ad5a_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/scikit-learn-1.5.1-py312h526ad5a_0.conda
https://repo.anaconda.com/pkgs/main/linux-64/seaborn-0.13.2-py312h06a4308_0.conda
https://conda.anaconda.org/pytorch/linux-64/pytorch-2.4.0-py3.12_cuda12.4_cudnn9.1.0_0.tar.bz2
https://conda.anaconda.org/pytorch/linux-64/torchaudio-2.4.0-py312_cu124.tar.bz2
https://conda.anaconda.org/pytorch/linux-64/torchtriton-3.0.0-py312.tar.bz2
https://conda.anaconda.org/pytorch/linux-64/torchvision-0.19.0-py312_cu124.tar.bz2

View File

@ -1,179 +0,0 @@
name: Deeplearning
channels:
- defaults
- nvidia
- pytorch
dependencies:
- _libgcc_mutex=0.1
- _openmp_mutex=5.1
- blas=1.0
- bottleneck=1.3.7
- brotli=1.0.9
- brotli-bin=1.0.9
- brotli-python=1.0.9
- bzip2=1.0.8
- ca-certificates=2024.11.26
- certifi=2024.12.14
- charset-normalizer=3.3.2
- contourpy=1.2.0
- cuda-cudart=12.4.127
- cuda-cupti=12.4.127
- cuda-libraries=12.4.0
- cuda-nvrtc=12.4.127
- cuda-nvtx=12.4.127
- cuda-opencl=12.4.127
- cuda-runtime=12.4.0
- cudatoolkit=11.5.1
- cycler=0.11.0
- cyrus-sasl=2.1.28
- dbus=1.13.18
- debugpy=1.6.7
- et_xmlfile=1.1.0
- expat=2.6.2
- ffmpeg=4.3
- filelock=3.13.1
- fontconfig=2.14.1
- fonttools=4.51.0
- freetype=2.12.1
- glib=2.78.4
- glib-tools=2.78.4
- gmp=6.2.1
- gnutls=3.6.15
- gst-plugins-base=1.14.1
- gstreamer=1.14.1
- icu=73.1
- idna=3.7
- intel-openmp=2023.1.0
- jinja2=3.1.4
- joblib=1.4.2
- jpeg=9e
- kiwisolver=1.4.4
- krb5=1.20.1
- lame=3.100
- lcms2=2.12
- ld_impl_linux-64=2.38
- lerc=3.0
- libbrotlicommon=1.0.9
- libbrotlidec=1.0.9
- libbrotlienc=1.0.9
- libclang=14.0.6
- libclang13=14.0.6
- libcublas=12.4.2.65
- libcufft=11.2.0.44
- libcufile=1.9.1.3
- libcups=2.4.2
- libcurand=10.3.5.147
- libcusolver=11.6.0.99
- libcusparse=12.3.0.142
- libdeflate=1.17
- libedit=3.1.20230828
- libffi=3.4.4
- libgcc-ng=11.2.0
- libgfortran-ng=11.2.0
- libgfortran5=11.2.0
- libglib=2.78.4
- libgomp=11.2.0
- libiconv=1.16
- libidn2=2.3.4
- libjpeg-turbo=2.0.0
- libllvm14=14.0.6
- libnpp=12.2.5.2
- libnvfatbin=12.4.127
- libnvjitlink=12.4.99
- libnvjpeg=12.3.1.89
- libpng=1.6.39
- libpq=12.17
- libstdcxx-ng=11.2.0
- libtasn1=4.19.0
- libtiff=4.5.1
- libunistring=0.9.10
- libuuid=1.41.5
- libwebp-base=1.3.2
- libxcb=1.15
- libxkbcommon=1.0.1
- libxml2=2.13.1
- llvm-openmp=14.0.6
- lz4-c=1.9.4
- markupsafe=2.1.3
- matplotlib=3.8.4
- matplotlib-base=3.8.4
- mkl=2023.1.0
- mkl-service=2.4.0
- mkl_fft=1.3.8
- mkl_random=1.2.4
- mpmath=1.3.0
- mysql=5.7.24
- ncurses=6.4
- nettle=3.7.3
- networkx=3.3
- numexpr=2.8.7
- numpy=1.26.4
- numpy-base=1.26.4
- openh264=2.1.1
- openjpeg=2.5.2
- openpyxl=3.1.5
- openssl=3.0.15
- packaging=24.1
- pandas=2.2.2
- pcre2=10.42
- pillow=10.4.0
- pip=24.2
- ply=3.11
- pybind11-abi=5
- pyopengl=3.1.1a1
- pyparsing=3.0.9
- pyqt=5.15.10
- pyqt5-sip=12.13.0
- pysocks=1.7.1
- python=3.12.4
- python-dateutil=2.9.0post0
- python-tzdata=2023.3
- pytorch=2.4.0
- pytorch-cuda=12.4
- pytorch-mutex=1.0
- pytz=2024.1
- pyyaml=6.0.1
- qt-main=5.15.2
- readline=8.2
- requests=2.32.3
- scikit-learn=1.5.1
- scipy=1.13.1
- seaborn=0.13.2
- setuptools=72.1.0
- sip=6.7.12
- six=1.16.0
- sqlite=3.45.3
- sympy=1.12
- tbb=2021.8.0
- threadpoolctl=3.5.0
- tk=8.6.14
- torchaudio=2.4.0
- torchtriton=3.0.0
- torchvision=0.19.0
- tornado=6.4.1
- tqdm=4.66.5
- typing_extensions=4.11.0
- tzdata=2024a
- unicodedata2=15.1.0
- urllib3=2.2.2
- wheel=0.43.0
- xlrd=2.0.1
- xz=5.4.6
- yaml=0.2.5
- zlib=1.2.13
- zstd=1.5.5
- pip:
- autopep8==2.3.1
- basedpyright==1.16.0
- black==24.8.0
- click==8.1.7
- fsspec==2024.6.1
- graphviz==0.20.3
- greenlet==3.0.3
- msgpack==1.0.8
- mypy-extensions==1.0.0
- nodejs-wheel-binaries==20.16.0
- pathspec==0.12.1
- platformdirs==4.2.2
- pycodestyle==2.12.1
- pynvim==0.5.0

46
init.lua Normal file
View File

@ -0,0 +1,46 @@
-- Neovim entry point: loads the shared helper table G, bootstraps the
-- lazy.nvim plugin manager on first start and then loads key mappings,
-- options and the plugin specification.
-- G is assigned as a global on purpose: the other modules use it without
-- requiring it themselves.
G = require('G')
-- Everything below only runs when vim.g.vscode is not set.
if not G.g.vscode then
-- Location of the lazy.nvim checkout inside Neovim's data directory.
local lazypath = G.fn.stdpath("data") .. "/lazy/lazy.nvim"
-- First start: the checkout does not exist yet, so clone the stable branch.
if not G.loop.fs_stat(lazypath) then
G.fn.system({
"git",
"clone",
"--filter=blob:none",
"https://github.com/folke/lazy.nvim.git",
"--branch=stable",
lazypath,
})
-- NOTE(review): this changes the user's *global* git configuration as a
-- side effect of the bootstrap; the "store" helper keeps credentials
-- unencrypted on disk. Confirm that this is intended.
G.fn.system({
"git",
"config",
"--global",
"credential.helper",
"store",
})
end
G.opt.rtp:prepend(lazypath)
require("keymap")
-- NOTE(review): this test sits inside `if not G.g.vscode`, so it can never be
-- true here and require('vscode') is unreachable; under VSCode neither the
-- key mappings nor the vscode module are loaded. Confirm the intended nesting.
if G.g.vscode then
require('vscode')
else
require("options")
-- Plugins are cloned over SSH or HTTPS depending on G.use_ssh.
local clone_prefix = G.use_ssh and "git@github.com:%s.git" or "https://github.com/%s.git"
require("lazy").setup(
require('plugs'), {
-- Keep the lockfile in the data directory rather than next to this config.
lockfile = G.fn.stdpath("data") .. "/lazy/lazy-lock.json",
git = {
url_format = clone_prefix,
}
}
)
end
end

59
lua/G.lua Normal file
View File

@ -0,0 +1,59 @@
-- Shared helper table: short aliases for frequently used parts of the Neovim
-- Lua API plus small wrappers for defining mappings, commands and autocmds.
local G = {}

-- When true, plugins are cloned over SSH instead of HTTPS (read by init.lua).
G.use_ssh = false

-- Aliases for the Neovim API namespaces.
G.g = vim.g
G.b = vim.b
G.o = vim.o
G.fn = vim.fn
G.api = vim.api
G.opt = vim.opt
G.loop = vim.loop
G.lb = vim.lsp.buf
G.dic = vim.diagnostic
-- nvim_create_augroup lives in vim.api; vim.nvim_create_augroup does not
-- exist, so the previous alias was always nil.
G.cgp = vim.api.nvim_create_augroup

-- Define several key mappings at once.
-- Each entry is { mode, lhs, rhs } or { mode, lhs, rhs, opts }; without opts
-- the mapping is created with { noremap = true }. Entries of any other
-- length are reported and skipped.
function G.map(maps)
    for _, map in pairs(maps) do
        if #map == 3 then
            vim.keymap.set(map[1], map[2], map[3], { noremap = true })
        elseif #map == 4 then
            vim.keymap.set(map[1], map[2], map[3], map[4])
        else
            print("太多变量了")
        end
    end
end

-- Remove several key mappings at once.
-- Each entry is { mode, lhs } or { mode, lhs, opts }; entries of any other
-- length are reported and skipped.
function G.delmap(maps)
    for _, map in pairs(maps) do
        if #map == 2 then
            vim.keymap.del(map[1], map[2], {})
        elseif #map == 3 then
            vim.keymap.del(map[1], map[2], map[3])
        else
            print("太多变量了")
        end
    end
end

-- Run a single Ex command.
function G.cmd(cmd)
    G.api.nvim_command(cmd)
end

-- Run a (possibly multi-line) block of Vim script.
-- nvim_exec needs a second `output` argument and raised an error when called
-- with the source alone; vim.cmd runs the block without capturing output.
function G.exec(c)
    vim.cmd(c)
end

-- Evaluate a Vim script expression and return its value.
function G.eval(c)
    return G.api.nvim_eval(c)
end

-- Create an autocommand for the given event(s); returns the autocommand id.
function G.au(even, opts)
    return G.api.nvim_create_autocmd(even, opts)
end

-- <leader> is the space bar.
G.g.mapleader = ' '

return G

46
lua/keymap.lua Normal file
View File

@ -0,0 +1,46 @@
-- Base key mappings. Every binding is registered with the same
-- non-recursive option table.
local noremap = { noremap = true }

-- { mode, lhs, rhs } triples, in registration order.
local bindings = {
  -- search highlight off / start a substitution
  { 'n', '<leader>nh', ':nohlsearch<CR>' },
  { 'n', '<leader>rp', ':%s/' },
  -- H / L jump to the start / end of the line
  { 'v', 'L', '$' },
  { 'v', 'H', '^' },
  { 'n', 'L', '$' },
  { 'n', 'H', '^' },
  -- indenting; the visual variants re-select with gv
  { 'v', '>', '>gv' },
  { 'v', '<', '<gv' },
  { 'n', '>', '>>' },
  { 'n', '<', '<<' },
  -- turn hlsearch on before searching
  { 'n', '?', ':set hlsearch<CR>?' },
  { 'n', '/', ':set hlsearch<CR>/' },
  -- next / previous tab page
  { 'n', '<A-l>', ':tabn<CR>' },
  { 'n', '<A-h>', ':tabp<CR>' },
  -- window navigation
  { 'n', '<c-j>', '<c-w>j' },
  { 'n', '<c-h>', '<c-w>h' },
  { 'n', '<c-k>', '<c-w>k' },
  { 'n', '<c-l>', '<c-w>l' },
  -- quit / write through sudo / force quit
  { 'n', '<c-c>', ':q<CR>' },
  { 'n', '<c-S>', ':w !sudo tee %<CR>' },
  { 'n', '<c-q>', ':q!<CR>' },
  -- clipboard and whole-buffer shortcuts
  { 'v', '<cs-y>', '"+y' },
  { 'n', '<leader>y', 'ggyG' },
  { 'n', '<leader>p', 'ggpG' },
  { 'n', '<leader>v', 'ggVG' },
  -- arrow keys resize the current window
  { 'n', '<up>', ':res -5<CR>' },
  { 'n', '<down>', ':res +5<CR>' },
  { 'n', '<left>', ':vertical resize -5<CR>' },
  { 'n', '<right>', ':vertical resize +5<CR>' },
}

local maps = {}
for index, binding in ipairs(bindings) do
  maps[index] = { binding[1], binding[2], binding[3], noremap }
end
G.map(maps)

22
lua/lsp/basedpyright.lua Normal file
View File

@ -0,0 +1,22 @@
-- basedpyright settings plus an :R command that runs the current file.
return {
  -- capabilities = require("cmp_nvim_lsp").default_capabilities(),
  settings = {
    basedpyright = {
      analysis = {
        autoSearchPaths = true,
        diagnosticMode = "openFilesOnly",
        useLibraryCodeForTypes = true,
        typeCheckingMode = "standard"
      },
    },
  },
  -- on_attach receives (client, bufnr). The command is created on the buffer
  -- so that :R defined by another language server cannot replace it.
  on_attach = function(_, bufnr)
    G.api.nvim_buf_create_user_command(bufnr or 0, 'R', function()
      -- Open a 10-line split below and run the file with python3.
      G.cmd [[set splitbelow]]
      G.cmd [[sp]]
      G.cmd [[term python3 %]]
      G.cmd [[resize 10]]
      G.cmd [[startinsert]]
    end, {})
  end
}

11
lua/lsp/bash.lua Normal file
View File

@ -0,0 +1,11 @@
-- Shell-script server config: adds an :R command that runs the current file.
return {
  -- on_attach receives (client, bufnr). The command is created on the buffer
  -- so that :R defined by another language server cannot replace it.
  on_attach = function(_, bufnr)
    G.api.nvim_buf_create_user_command(bufnr or 0, 'R', function()
      -- Open a 10-line split below and run the file with sh.
      G.cmd [[set splitbelow]]
      G.cmd [[sp]]
      G.cmd [[term sh %]]
      G.cmd [[resize 10]]
      G.cmd [[startinsert]]
    end, {})
  end
}

11
lua/lsp/c.lua Normal file
View File

@ -0,0 +1,11 @@
-- C/C++ server config: adds an :R command that builds and runs the current file.
return {
  -- on_attach receives (client, bufnr). The command is created on the buffer
  -- so that :R defined by another language server cannot replace it.
  on_attach = function(_, bufnr)
    G.api.nvim_buf_create_user_command(bufnr or 0, 'R', function()
      -- Open a 10-line split below; compile with g++ to "%<" (the file name
      -- without extension), run the binary, then delete it.
      G.cmd [[set splitbelow]]
      G.cmd [[sp]]
      G.cmd [[term g++ "%" -o "%<" && ./"%<" && rm "%<"]]
      G.cmd [[resize 10]]
      G.cmd [[startinsert]]
    end, {})
  end
}

18
lua/lsp/go.lua Normal file
View File

@ -0,0 +1,18 @@
-- Go server config: adds :R (run the current file) and :Rgin (run ./main.go).

-- Create a buffer-local command `name` that opens a 10-line split below and
-- runs `shell_cmd` in a terminal.
local function term_command(bufnr, name, shell_cmd)
  G.api.nvim_buf_create_user_command(bufnr, name, function()
    G.cmd [[set splitbelow]]
    G.cmd [[sp]]
    G.cmd('term ' .. shell_cmd)
    G.cmd [[resize 10]]
    G.cmd [[startinsert]]
  end, {})
end

return {
  -- on_attach receives (client, bufnr). The commands are created on the buffer
  -- so that :R defined by another language server cannot replace them.
  on_attach = function(_, bufnr)
    bufnr = bufnr or 0
    term_command(bufnr, 'R', 'go run %')
    term_command(bufnr, 'Rgin', 'go run ./main.go')
  end
}

4
lua/lsp/json.lua Normal file
View File

@ -0,0 +1,4 @@
-- JSON server config: no custom settings. on_attach is an empty function, so
-- nothing extra happens when the server attaches to a buffer.
return {
on_attach = function()
end
}

32
lua/lsp/lua.lua Normal file
View File

@ -0,0 +1,32 @@
-- Lua language server config for editing this Neovim configuration.

-- Module search path: package.path entries plus the two local lua/ patterns.
local runtime_path = vim.split(package.path, ';')
for _, pattern in ipairs({ "lua/?.lua", "lua/?/init.lua" }) do
  runtime_path[#runtime_path + 1] = pattern
end

local lua_settings = {
  runtime = {
    -- Lua version reported to the language server.
    version = 'LuaJIT',
    path = runtime_path,
  },
  diagnostics = {
    -- Globals the server should not flag as undefined.
    globals = { 'vim', 'G', 'yield', 'Candidate' },
  },
  workspace = {
    -- Make the server aware of the Neovim runtime files.
    library = vim.api.nvim_get_runtime_file("", true),
  },
  -- Do not send telemetry data.
  telemetry = {
    enable = false
  },
}

return {
  capabilities = require('cmp_nvim_lsp').default_capabilities(),
  settings = {
    Lua = lua_settings,
  },
}

2
lua/lsp/markdown.lua Normal file
View File

@ -0,0 +1,2 @@
-- Markdown server config: an empty table, i.e. no overrides.
return {
}

11
lua/lsp/pyright.lua Normal file
View File

@ -0,0 +1,11 @@
-- pyright config: adds an :R command that runs the current file.
return {
  -- on_attach receives (client, bufnr). The command is created on the buffer
  -- so that :R defined by another language server cannot replace it.
  on_attach = function(_, bufnr)
    G.api.nvim_buf_create_user_command(bufnr or 0, 'R', function()
      -- Open a 10-line split below and run the file with python3.
      G.cmd [[set splitbelow]]
      G.cmd [[sp]]
      G.cmd [[term python3 %]]
      G.cmd [[resize 10]]
      G.cmd [[startinsert]]
    end, {})
  end
}

12
lua/lsp/yaml.lua Normal file
View File

@ -0,0 +1,12 @@
-- YAML language server config.
return {
  capabilities = require('cmp_nvim_lsp').default_capabilities(),
  settings = {
    yaml = {
      -- schema location -> glob of the files it applies to.
      -- Only the real GitHub workflow schema is kept; the two placeholder
      -- paths ("../path/relative/to/file.yml", "/path/from/root/of/project")
      -- pointed at no schema and were mapped to the same glob.
      schemas = {
        ["https://json.schemastore.org/github-workflow.json"] = "/.github/workflows/*",
      },
    },
  }
}

91
lua/options.lua Normal file
View File

@ -0,0 +1,91 @@
-- Editor options, applied in the listed order as { name, value } pairs.
local settings = {
  -- key-code timeout
  { "ttimeout", true },
  { "ttimeoutlen", 100 },
  -- line numbers
  { "nu", true },
  { "rnu", true },
  { "scrolloff", 999 },
  -- automatic saving
  { "autowrite", true },
  { "autowriteall", true },
  -- tab key / indentation
  { "sw", 2 },
  { "ts", 2 },
  { "softtabstop", 2 },
  { "smarttab", true },
  { "expandtab", true },
  { "autoindent", true },
  -- cursor
  { "cursorline", true },
  -- window splits
  { "splitright", true },
  { "splitbelow", true },
  -- search
  { "ignorecase", true },
  { "incsearch", true },
  -- no line wrapping
  { "textwidth", 999 },
  { "wrap", false },
}
for _, setting in ipairs(settings) do
  G.opt[setting[1]] = setting[2]
end

-- filetype detection, plugins and indent rules
G.cmd("filetype plugin indent on")

-- Comment continuation, re-applied on every BufEnter.
local function fix_formatoptions()
  local flags = G.opt.formatoptions
  flags = flags - "o" -- O and o, don't continue comments
  flags = flags + "r" -- But do continue when pressing enter.
  G.opt.formatoptions = flags
end
G.au({ "BufEnter" }, { pattern = { "*" }, callback = fix_formatoptions })

-- Entering insert mode switches search highlighting off.
G.au({ "InsertEnter" }, {
  pattern = { "*" },
  callback = function()
    G.opt.hlsearch = false
  end,
})

-- *.code-snippets files are handled as JSON.
G.au({ "VimEnter", "BufEnter" }, {
  pattern = { "*.code-snippets" },
  callback = function()
    G.cmd("setfiletype json")
  end,
})
-- G.au({
-- {"CmdlineEnter"},
-- {
-- if index
-- }
-- })
-- Python provider selection: prefer the interpreter of the active conda
-- environment, but never override a value that is already set.

-- True for nil and for the empty string.
local function isempty(s)
  if s == nil then
    return true
  end
  return s == ""
end

-- Returns `val` unless it is nil or false, in which case `fallback` is returned.
local function use_if_defined(val, fallback)
  if val then
    return val
  end
  return fallback
end

local conda_prefix = os.getenv("CONDA_PREFIX")
local python_default, python3_default = "python", "python3"
if not isempty(conda_prefix) then
  -- Both providers point at the same conda interpreter.
  python_default = conda_prefix .. "/bin/python"
  python3_default = conda_prefix .. "/bin/python"
end
vim.g.python_host_prog = use_if_defined(vim.g.python_host_prog, python_default)
vim.g.python3_host_prog = use_if_defined(vim.g.python3_host_prog, python3_default)

32
lua/plugs.lua Normal file
View File

@ -0,0 +1,32 @@
-- Top-level plugin list: one spec module per area, loaded in this order.
local spec_modules = {
  'plugs.nvim-lspconfig',
  'plugs.nvimtree', -- nvimtree
  'plugs.theme',    -- theme
  'plugs.edit-plugs',
  'plugs.dev',
}

local specs = {}
for _, module in ipairs(spec_modules) do
  specs[#specs + 1] = require(module)
end

-- Disabled: LeetCode practice plugin.
-- {
--   "kawre/leetcode.nvim",
--   build = ":TSUpdate html",
--   dependencies = {
--     "nvim-telescope/telescope.nvim",
--     "nvim-lua/plenary.nvim", -- required by telescope
--     "MunifTanjim/nui.nvim",
--     -- optional
--     "nvim-treesitter/nvim-treesitter",
--     "rcarriga/nvim-notify",
--     "nvim-tree/nvim-web-devicons",
--   },
--   opts = {
--     -- configuration goes here
--     cn = {
--       enabled = true,
--     },
--   },
-- },

return specs

51
lua/plugs/dev.lua Normal file
View File

@ -0,0 +1,51 @@
-- Development helper plugins.

-- conda environment support
local conda = {
  "kmontocam/nvim-conda",
  dependencies = { "nvim-lua/plenary.nvim" },
}

-- terminal
local toggleterm = {
  'akinsho/toggleterm.nvim',
  version = "*",
  config = function()
    local term_opts = {
      -- size can be a number or function which is passed the current terminal
      size = 10,
      open_mapping = [[<c-t>]],
      hide_numbers = true, -- hide the number column in toggleterm buffers
      shade_filetypes = {},
      shade_terminals = true,
      shading_factor = 1, -- the degree by which to darken to terminal colour, default: 1 for dark backgrounds, 3 for light
      start_in_insert = true,
      insert_mappings = true, -- whether or not the open mapping applies in insert mode
      persist_size = true,
      direction = 'horizontal',
    }
    require("toggleterm").setup(term_opts)
  end
}

-- Disabled specs, kept for reference:
-- {
--   "rest-nvim/rest.nvim",
--   dependencies = { { "nvim-lua/plenary.nvim" } },
--   config = function()
--     require("rest-nvim").setup({
--       --- Get the same options from Packer setup
--     })
--   end
-- },
-- {
--   -- Go development
--   "ray-x/go.nvim",
--   dependencies = { -- optional packages
--     "ray-x/guihua.lua",
--     "neovim/nvim-lspconfig",
--     "nvim-treesitter/nvim-treesitter",
--   },
--   config = function()
--     require("go").setup()
--   end,
--   -- event = { "CmdlineEnter" },
--   ft = { "go", 'gomod' },
--   build = ':lua require("go.install").update_all_sync()' -- if you need to install/update all binaries
-- },

return {
  'lilydjwg/colorizer', -- colour code highlighting
  conda,
  toggleterm,
}

109
lua/plugs/edit-plugs.lua Normal file
View File

@ -0,0 +1,109 @@
-- Editing plugins. The nesting of the returned tables is kept as in the
-- original spec.

-- surround and wildfire work well together; hlchunk draws bracket arrows
local text_objects = {
  'tpope/vim-surround',
  'gcmt/wildfire.vim',
  'yaocccc/nvim-hlchunk',
}

-- alignment
local easy_align = {
  'junegunn/vim-easy-align',
  config = function()
    local align_maps = {}
    for _, lhs in ipairs({ "ga", "=" }) do
      align_maps[#align_maps + 1] = { "v", lhs, ":EasyAlign<CR>", { noremap = true } }
    end
    G.map(align_maps)
  end
}

local autopairs = {
  'windwp/nvim-autopairs',
  event = "InsertEnter",
  opts = {}, -- kept from the original spec; the setup call is made in config below
  config = function()
    require('nvim-autopairs').setup({
      disable_filetype = { "vim" },
    })
  end
}

-- Configure hop and bind f/F/t/T to single-character hints on the current line.
local function setup_hop()
  require("hop").setup { keys = 'asdfghjkl;' }
  local hop = require('hop')
  local directions = require('hop.hint').HintDirection

  -- Build a mapping callback; `offset` is nil for f/F and -1 for t/T.
  local function char_hint(direction, offset)
    return function()
      hop.hint_char1({
        direction = direction,
        current_line_only = true,
        hint_offset = offset
      })
    end
  end

  G.map({
    { "n", "<c-f>", ":HopChar2MW<CR>", { noremap = true } },
    { "n", "f", char_hint(directions.AFTER_CURSOR), { noremap = true } },
    { "n", "F", char_hint(directions.BEFORE_CURSOR), { noremap = true } },
    { "n", "t", char_hint(directions.AFTER_CURSOR, -1), { noremap = true } },
    { "n", "T", char_hint(directions.BEFORE_CURSOR, -1), { noremap = true } },
  })
end

local hop_spec = {
  "phaazon/hop.nvim",
  branch = "v2",
  keys = {
    "f", "F", "t", "T",
    "<c-f>"
  },
  lazy = true,
  config = setup_hop
}

-- Disabled: folding plugin.
-- {
--   'kevinhwang91/nvim-ufo',
--   dependencies = {
--     'kevinhwang91/promise-async'
--   },
--   config = function ()
--     require("ufo").setup()
--   end
-- },

return {
  'vijaymarupudi/nvim-fzf', -- fzf
  text_objects,
  { 'terryma/vim-multiple-cursors' }, -- multiple cursors
  { easy_align },
  { { 'tpope/vim-commentary' } }, -- commenting
  { 'github/copilot.vim' }, -- github copilot
  autopairs,
  hop_spec,
}

View File

@ -0,0 +1,352 @@
return {
{
  -- LSP client configuration
  "neovim/nvim-lspconfig",
  dependencies = {
    "folke/neodev.nvim",
  },
  config = function()
    require('neodev').setup({})

    -- Servers to enable and the module holding each server's config.
    local server_modules = {
      { 'lua_ls', 'lsp.lua' },
      { 'clangd', 'lsp.c' },
      { 'bashls', 'lsp.bash' },
      { 'basedpyright', 'lsp.basedpyright' },
      { 'yamlls', 'lsp.yaml' },
      { 'gopls', 'lsp.go' },
      { 'jsonls', 'lsp.json' },
    }
    local servers = {}
    for _, entry in ipairs(server_modules) do
      servers[entry[1]] = require(entry[2])
    end

    -- Register and enable each server through vim.lsp.config / vim.lsp.enable.
    for server, config in pairs(servers) do
      vim.lsp.config(server, config)
      vim.lsp.enable(server)
    end

    -- Normal-mode keys, each calling vim.lsp.buf.<action>().
    local actions = {
      { '<leader>rn', 'rename' },
      { 'gd', 'definition' },
      { 'gh', 'hover' },
      { 'gD', 'declaration' },
      { 'gi', 'implementation' },
      { 'gr', 'references' },
      { '<cs-i>', 'format' },
    }
    local maps = {}
    for _, action in ipairs(actions) do
      maps[#maps + 1] = { 'n', action[1], '<cmd>lua vim.lsp.buf.' .. action[2] .. '()<CR>' }
    end
    G.map(maps)
  end
},
{
  -- language server installation
  "williamboman/mason-lspconfig.nvim",
  dependencies = {
    "williamboman/mason.nvim", -- server downloader
  },
  config = function()
    local mason_ui = {
      icons = {
        package_installed = "",
        package_pending = "",
        package_uninstalled = ""
      }
    }
    require("mason").setup({ ui = mason_ui })

    -- Servers installed automatically ("basedpyright" is left out).
    local auto_installed = {
      "bashls",
      "lua_ls",
      "jsonls",
      "yamlls",
    }
    require("mason-lspconfig").setup({ ensure_installed = auto_installed })
  end
},
{ -- symbols-outline.nvim: symbol outline window, toggled with <cs-o>
'simrat39/symbols-outline.nvim',
config = function()
-- Options handed to require("symbols-outline").setup below.
local opts = {
highlight_hovered_item = true,
show_guides = true,
auto_preview = false,
position = 'right',
relative_width = true,
width = 25,
auto_close = false,
show_numbers = false,
show_relative_numbers = false,
show_symbol_details = true,
preview_bg_highlight = 'Pmenu',
autofold_depth = nil,
auto_unfold_hover = true,
fold_markers = { '', '' },
wrap = false,
keymaps = { -- These keymaps can be a string or a table for multiple keys
close = { "<Esc>", "q" },
goto_location = "<Cr>",
focus_location = "h",
hover_symbol = "<C-space>",
toggle_preview = "K",
rename_symbol = "r",
code_actions = "a",
fold = "o",
unfold = "l",
fold_all = "W",
unfold_all = "E",
fold_reset = "R",
},
lsp_blacklist = {},
symbol_blacklist = {},
-- Per symbol kind: the icon shown and the highlight group used for it.
symbols = {
File = { icon = "", hl = "@text.uri" },
Module = { icon = "", hl = "@namespace" },
Namespace = { icon = "", hl = "@namespace" },
Package = { icon = "", hl = "@namespace" },
Class = { icon = "𝓒", hl = "@type" },
Method = { icon = "ƒ", hl = "@method" },
Property = { icon = "", hl = "@method" },
Field = { icon = "", hl = "@field" },
Constructor = { icon = "", hl = "@constructor" },
Enum = { icon = "", hl = "@type" },
Interface = { icon = "", hl = "@type" },
Function = { icon = "", hl = "@function" },
Variable = { icon = "", hl = "@constant" },
Constant = { icon = "", hl = "@constant" },
String = { icon = "𝓐", hl = "@string" },
Number = { icon = "#", hl = "@number" },
Boolean = { icon = "", hl = "@boolean" },
Array = { icon = "", hl = "@constant" },
Object = { icon = "⦿", hl = "@type" },
Key = { icon = "🔐", hl = "@type" },
Null = { icon = "NULL", hl = "@type" },
EnumMember = { icon = "", hl = "@field" },
Struct = { icon = "𝓢", hl = "@type" },
Event = { icon = "🗲", hl = "@type" },
Operator = { icon = "+", hl = "@operator" },
TypeParameter = { icon = "𝙏", hl = "@parameter" },
Component = { icon = "", hl = "@function" },
Fragment = { icon = "", hl = "@constant" },
},
}
require("symbols-outline").setup(opts)
-- <cs-o> in normal mode runs :SymbolsOutline.
G.map({
{ "n", "<cs-o>", "<cmd>SymbolsOutline<cr>", { noremap = true } },
})
end
},
{
-- lsp补全
{
  -- Completion engine and its sources.
  "hrsh7th/nvim-cmp",
  dependencies = {
    'hrsh7th/cmp-nvim-lsp', -- { name = 'nvim_lsp' }
    'hrsh7th/cmp-buffer',   -- { name = 'buffer' }
    'hrsh7th/cmp-path',     -- { name = 'path' }
    'hrsh7th/cmp-cmdline',  -- { name = 'cmdline' }
    {
      'hrsh7th/vim-vsnip',
      config = function()
        -- Snippets are read from <config dir>/snippets.
        G.g.vsnip_snippet_dir = G.fn.stdpath("config") .. "/snippets"
      end
    },
    'hrsh7th/cmp-vsnip',
    'onsails/lspkind-nvim',
  },
  config = function()
    -- True when the character before the cursor exists and is not whitespace.
    local has_words_before = function()
      -- Kept local so that no global `unpack` is (re)assigned.
      local unpack_fn = unpack or table.unpack
      local line, col = unpack_fn(G.api.nvim_win_get_cursor(0))
      return col ~= 0 and G.api.nvim_buf_get_lines(0, line - 1, line, true)[1]:sub(col, col):match('%s') == nil
    end

    -- Feed `key` (with termcodes replaced) as if it had been typed.
    local feedkey = function(key, mode)
      vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes(key, true, true, true), mode, true)
    end

    local cmp = require('cmp')
    local cmp_opt = {
      snippet = {
        expand = function(args)
          vim.fn["vsnip#anonymous"](args.body) -- For `vsnip` users.
        end,
      },
      -- cmp.config.sources takes one list of sources per group. `buffer` and
      -- `path` were passed as bare source tables instead of a list, so they
      -- were never registered; they now form the second group.
      sources = cmp.config.sources({
          { name = 'nvim_lsp' },
          { name = 'vsnip' },
        },
        {
          { name = 'buffer' },
          { name = 'path' },
        }
      ),
      mapping = cmp.mapping.preset.insert({
        -- Enter confirms the selection; in command-line mode only while the
        -- menu is visible.
        ["<CR>"] = cmp.mapping({
          i = cmp.mapping.confirm({ behavior = cmp.ConfirmBehavior.Replace, select = true }),
          c = function(fallback)
            if cmp.visible() then
              cmp.confirm({ behavior = cmp.ConfirmBehavior.Replace, select = true })
            else
              fallback()
            end
          end,
        }),
        -- Tab: next item, else expand/jump in a snippet, else open the menu
        -- after a word character, else the key's normal behaviour.
        ["<Tab>"] = cmp.mapping(function(fallback)
          if cmp.visible() then
            cmp.select_next_item()
          elseif vim.fn["vsnip#available"](1) == 1 then
            feedkey("<Plug>(vsnip-expand-or-jump)", "")
          elseif has_words_before() then
            cmp.complete()
          else
            fallback()
          end
        end, { "i", "s" }),
        -- Shift-Tab: previous item, else jump back inside a snippet.
        ["<S-Tab>"] = cmp.mapping(function()
          if cmp.visible() then
            cmp.select_prev_item()
          elseif vim.fn["vsnip#jumpable"](-1) == 1 then
            feedkey("<Plug>(vsnip-jump-prev)", "")
          end
        end, { "i", "s" }),
      }),
      window = {
        completion = {
          winhighlight = "Normal:Pmenu,FloatBorder:Pmenu,Search:None",
          col_offset = -3,
          side_padding = 0,
          border = "rounded",
          scrollbar = false,
        },
        documentation = {
          winhighlight = "Normal:Pmenu,FloatBorder:Pmenu,Search:None",
          border = "rounded",
          scrollbar = false,
        },
      },
      formatting = {
        fields = { "kind", "abbr", "menu" },
        -- Split lspkind's "symbol text" label: the symbol goes into the kind
        -- column, the text (in parentheses) into the menu column.
        format = function(entry, vim_item)
          local kind = require("lspkind").cmp_format({ mode = "symbol_text", maxwidth = 50, })(entry, vim_item)
          local strings = vim.split(kind.kind, "%s", { trimempty = true })
          kind.kind = " " .. (strings[1] or "") .. " "
          kind.menu = " (" .. (strings[2] or "") .. ")"
          return kind
        end,
      },
    }
    cmp.setup(cmp_opt)
  end,
},
{
  -- Syntax highlighting
  'nvim-treesitter/nvim-treesitter',
  config = function()
    local treesitter_opt = {
      -- Nothing is preinstalled; candidates that were listed but disabled:
      -- "c", "cpp", "python", "java", "lua", "bash", "vimdoc"
      ensure_installed = {},
      indent = { enable = true },
      ignore_install = {
        "txt",
        "go"
      },
      sync_install = false,
      auto_install = true,
      highlight = {
        enable = true,
        -- Highlighting is disabled for files larger than 100 KB.
        disable = function(_, buf)
          local max_filesize = 100 * 1024 -- 100 KB
          local ok, stats = pcall(G.loop.fs_stat, G.api.nvim_buf_get_name(buf))
          if ok and stats and stats.size > max_filesize then
            return true
          end
        end,
        additional_vim_regex_highlighting = false,
      },
      parsers = {
        html = {
          install_info = {
            url = "https://github.com/ikatyang/tree-sitter-vue",
            files = { "src/parser.c" },
            branch = "main"
          }
        }
      }
    }
    require 'nvim-treesitter'.setup(treesitter_opt)
    require 'nvim-treesitter.install'.prefer_git = true

    -- With G.use_ssh, point every parser repository at the SSH remote.
    if G.use_ssh then
      local parsers = require 'nvim-treesitter.parsers'.get_parser_configs()
      for _, p in pairs(parsers) do
        local info = p.install_info
        if info and info.url then
          -- The pattern used to read "https://github/com/" and never matched.
          -- "%." matches a literal dot.
          info.url = info.url:gsub("https://github%.com/", "git@github.com:")
        end
      end
    end
  end
},
},
{
-- winbar.nvim: configuration passed to require('winbar').setup
'fgheng/winbar.nvim',
config = function()
require('winbar').setup({
enabled = true, -- whether winbar is enabled
-- show_file_path = true, -- whether to show the file path
show_symbols = true, -- whether to show function symbols
-- Colour settings; when empty, the default colours are used
colors = {
path = '#aaffff', -- colour of the path, e.g. #ababab
file_name = '#bbbbff', -- colour of the file name, e.g. #acacac
symbols = '#aaffaa', -- colour of function symbols
},
-- Icon settings
icons = {
seperator = '>', -- path separator symbol
editor_state = '',
lock_icon = '',
},
-- Filetypes for which winbar is turned off
exclude_filetype = {
'help',
'startify',
'dashboard',
'packer',
'neogitstatus',
'NvimTree',
'Trouble',
'alpha',
'lir',
'Outline',
'spectre_panel',
'toggleterm',
'qf',
}
})
end
},
}

40
lua/plugs/nvimtree.lua Normal file
View File

@ -0,0 +1,40 @@
-- File explorer (nvim-tree), loaded on <c-e>.

-- Per-buffer setup for the tree window: install the default mappings, add
-- v / s for vertical / horizontal opening and drop three default mappings.
local function tree_on_attach(bufnr)
  local api = require 'nvim-tree.api'
  api.config.mappings.default_on_attach(bufnr)

  -- A fresh option table per mapping, scoped to the tree buffer.
  local function buf_opts()
    return { buffer = bufnr }
  end

  -- override a default
  G.map({
    { 'n', 'v', api.node.open.vertical, buf_opts() },
    { 'n', 's', api.node.open.horizontal, buf_opts() },
  })

  local removed = {}
  for _, lhs in ipairs({ '<C-e>', '<C-v>', '<C-x>' }) do
    removed[#removed + 1] = { 'n', lhs, buf_opts() }
  end
  G.delmap(removed)
end

local function setup_tree()
  require('nvim-web-devicons').setup({})
  require('nvim-tree').setup({
    sort_by = "case_sensitive",
    view = {
      width = 30,
    },
    filters = { dotfiles = true, },
    git = { enable = true },
    on_attach = tree_on_attach
  })
  G.map({
    { "n", "<C-e>", ":NvimTreeToggle<CR>", { noremap = true } },
  })
end

return {
  "nvim-tree/nvim-tree.lua",
  dependencies = {
    'kyazdani42/nvim-web-devicons'
  },
  keys = {
    "<c-e>"
  },
  config = setup_tree
}

82
lua/plugs/theme.lua Normal file
View File

@ -0,0 +1,82 @@
return {
{
  'folke/tokyonight.nvim',
  config = function()
    -- Apply the colorscheme first, then set the dark background.
    local scheme = "tokyonight"
    G.cmd("colorscheme " .. scheme)
    G.opt.background = 'dark'
  end
},
{
  -- Statusline (lualine) and tab/buffer line (tabline.nvim).
  -- The spec used to carry two positional plugin names. lazy.nvim reads a
  -- table with more than one positional entry as a list of specs, so `config`
  -- was not attached to either plugin. tabline.nvim is now a dependency of
  -- lualine and `config` belongs to one plugin.
  'nvim-lualine/lualine.nvim',
  dependencies = {
    'kdheepak/tabline.nvim',
  },
  config = function()
    require('lualine').setup {
      options = {
        icons_enabled = true,
        theme = 'auto',
        component_separators = { left = '', right = '' },
        section_separators = { left = '', right = '' },
        disabled_filetypes = {
          statusline = {},
          winbar = {},
        },
        ignore_focus = {},
        always_divide_middle = true,
        globalstatus = false,
        refresh = {
          statusline = 1000,
          tabline = 1000,
          winbar = 1000,
        }
      },
      sections = {
        lualine_a = { 'mode' },
        lualine_b = { 'branch', 'diff', 'diagnostics' },
        lualine_c = { {
          'filename',
          file_status = false,
          path = 1
        } },
        lualine_x = { 'encoding', 'fileformat', 'filetype' },
        lualine_y = { 'progress' },
        lualine_z = {}
      },
      inactive_sections = {
        lualine_a = {},
        lualine_b = {},
        lualine_c = { 'filename' },
        lualine_x = { 'location' },
        lualine_y = {},
        lualine_z = {}
      },
      tabline = {},
      winbar = {},
      inactive_winbar = {},
      extensions = {}
    }
    require('tabline').setup {
      -- Defaults configuration options
      enable = true,
      options = {
        -- If lualine is installed tabline will use separators configured in lualine by default.
        -- These options can be used to override those settings.
        section_separators = { ' ', ' ' },
        component_separators = { '', '' },
        max_bufferline_percent = 66, -- set to nil by default, and it uses vim.o.columns * 2/3
        show_tabs_always = false, -- this shows tabs only when there are more than one tab or if the first tab is named
        show_devicons = true, -- this shows devicons in buffer section
        show_bufnr = false, -- this appends [bufnr] to buffer section,
        show_filename_only = true, -- shows base filename only instead of relative path in filename
        modified_icon = "+ ", -- change the default modified icon
        modified_italic = false, -- set to true by default; this determines whether the filename turns italic if modified
        show_tabs_only = false, -- this shows only tabs instead of tabs + buffers
      },
    }
    -- These two :set commands used to sit inside the setup table constructor;
    -- they are plain statements and run after setup.
    G.cmd("set guioptions-=e") -- Use showtabline in gui vim
    G.cmd("set sessionoptions+=tabpages,globals") -- store tabpages and globals in session
  end,
},
}

4
lua/vsc.lua Normal file
View File

@ -0,0 +1,4 @@
-- Reduced plugin set: only the editing plugins are handed to lazy.nvim.
local specs = {
  require('plugs.edit-plugs'),
  -- 'vijaymarupudi/nvim-fzf', -- fzf
}
require('lazy').setup(specs)

175
main.py
View File

@ -1,175 +0,0 @@
import datetime
from pathlib import Path
from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.loadData import load_data
from Qfunctions.saveToXlsx import save_to_xlsx as save_to_xlsx
def _to_builtin(value):
    """Recursively convert *value* into plain built-in Python objects.

    Dicts become dicts with ``str`` keys, lists and tuples become lists, and
    objects exposing ``.item()`` (for example NumPy scalars) become the scalar
    that call returns. Array-like objects whose ``.item()`` fails because they
    hold more than one element are converted through ``.tolist()`` instead of
    being turned into their string form. Anything else is returned unchanged.
    """
    if isinstance(value, dict):
        return {str(k): _to_builtin(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_builtin(v) for v in value]
    if hasattr(value, 'item'):
        # Best effort on purpose: the exception raised by .item() differs
        # between libraries, so any failure moves on to the next strategy.
        try:
            return value.item()
        except Exception:
            pass
        to_list = getattr(value, 'tolist', None)
        if callable(to_list):
            try:
                return _to_builtin(to_list())
            except Exception:
                pass
        # Last resort, as before: keep a readable string.
        return str(value)
    return value
# Characters that must be escaped inside a double-quoted YAML scalar.
_YAML_ESCAPES = str.maketrans({
    '\\': '\\\\',
    '"': '\\"',
    '\n': '\\n',
    '\r': '\\r',
    '\t': '\\t',
})


def _yaml_scalar(value):
    """Return the YAML text for a single scalar *value*.

    ``None`` becomes ``null``, booleans ``true``/``false``, and finite numbers
    their ``str()`` form. Non-finite floats use the YAML spellings ``.inf``,
    ``-.inf`` and ``.nan``. Everything else is written as a double-quoted
    string with backslashes, double quotes, newlines, carriage returns and
    tabs escaped, so the value reads back unchanged.
    """
    import math  # local import: this block cannot see the module's import lines

    if value is None:
        return 'null'
    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return 'true' if value else 'false'
    if isinstance(value, float) and not math.isfinite(value):
        if math.isnan(value):
            return '.nan'
        return '.inf' if value > 0 else '-.inf'
    if isinstance(value, (int, float)):
        return str(value)
    return '"' + str(value).translate(_YAML_ESCAPES) + '"'
def _yaml_lines(value, indent=0):
    """Render *value* as a list of YAML text lines indented by *indent* spaces.

    Dicts become ``key: value`` lines and lists/tuples become ``- item`` lines.
    A nested container is written on the following lines, two spaces deeper.
    Empty containers are written as ``{}`` or ``[]``. Anything else is
    formatted by ``_yaml_scalar``.
    """
    pad = ' ' * indent
    containers = (dict, list, tuple)

    if not isinstance(value, containers):
        return [pad + _yaml_scalar(value)]

    is_mapping = isinstance(value, dict)
    if not value:
        return [pad + ('{}' if is_mapping else '[]')]

    out = []
    if is_mapping:
        for name, child in value.items():
            label = str(name)
            if isinstance(child, containers):
                out.append(f'{pad}{label}:')
                out += _yaml_lines(child, indent + 2)
            else:
                out.append(f'{pad}{label}: {_yaml_scalar(child)}')
        return out

    for child in value:
        if isinstance(child, containers):
            out.append(f'{pad}-')
            out += _yaml_lines(child, indent + 2)
        else:
            out.append(f'{pad}- {_yaml_scalar(child)}')
    return out
def _save_yaml(file_path, data):
    """Write *data* to *file_path* as UTF-8 YAML text ending with a newline.

    The data is first reduced to built-in types with ``_to_builtin`` and then
    rendered line by line with ``_yaml_lines``.
    """
    rendered = _yaml_lines(_to_builtin(data))
    Path(file_path).write_text('\n'.join(rendered) + '\n', encoding='utf-8')
def main():
    """Train a Qmlp model on one data folder and save the run's outputs.

    Loads the data, fits the model, writes the PCA, confusion-matrix and
    per-epoch tables through ``save_to_xlsx``, and stores the run parameters
    and a result summary as YAML files in ``Result/<project>/<run id>``.
    """
    # Name of the folder holding the raw data.
    project_name = '20260512 Graps'
    # One name per class: upper-case A to I.
    label_names = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
    hidden_layers = [256, 256]
    test_size = 0.5
    dropout_rate = 0
    epochs = 300

    # Each run gets its own timestamped result directory.
    run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    run_project_name = f'{project_name}/{run_id}'
    run_dir = Path('Result') / project_name / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    print(label_names)
    data = load_data(project_name, label_names)
    model = Qmlp(
        data=data,
        labels=label_names,
        hidden_layers=hidden_layers,
        test_size=test_size,
        dropout_rate=dropout_rate
    )
    # Alternative model, currently disabled:
    # model = QCNN(
    #     data=data,
    #     labels=label_names,
    #     conv_channels=(16, 32),
    #     kernel_size=3,
    #     hidden_size=128,
    #     test_size=0.3,
    #     dropout_rate=0
    # )

    pca_2d, pca_3d = model.get_PCA()
    model.fit(epochs)
    cm = model.get_cm()
    cmn = model.get_cmn()
    epoch_data = model.get_epoch_data()

    # (file name, table) pairs, saved in this order.
    tables = (
        ('pca_2d', pca_2d),
        ('pca_3d', pca_3d),
        ('cm', cm),
        ('cmn', cmn),
        ('acc_and_loss', epoch_data),
    )
    for file_name, table in tables:
        save_to_xlsx(project_name=run_project_name, file_name=file_name, data=table)

    run_params = {
        'run_id': run_id,
        'project_name': project_name,
        'label_names': label_names,
        'model': {
            'type': 'Qmlp',
            'hidden_layers': hidden_layers,
            'dropout_rate': dropout_rate,
        },
        'train': {
            'epochs': epochs,
            'test_size': test_size,
        },
        'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
    }

    run_result = {
        'run_id': run_id,
        'result_dir': str(run_dir),
    }
    if epoch_data.empty:
        run_result['message'] = 'No epoch data generated.'
    else:
        # Best epoch = row with the highest test accuracy.
        best_idx = int(epoch_data['test_accuracy'].idxmax())
        best_row = epoch_data.loc[best_idx].to_dict()
        last_row = epoch_data.iloc[-1].to_dict()

        generated_files = []
        for stem in ('pca_2d', 'pca_3d', 'cm', 'cmn'):
            generated_files += [f'{stem}.xlsx', f'{stem}.png']
        generated_files += [
            'acc_and_loss.xlsx',
            'acc_and_loss_epoch.png',
            'acc_and_loss_last_epoch_bar.png',
        ]

        run_result['best_epoch'] = int(best_row.get('epoch', 0))
        run_result['best_metrics'] = best_row
        run_result['last_epoch_metrics'] = last_row
        run_result['generated_files'] = generated_files
    run_result['created_at'] = datetime.datetime.now().isoformat(timespec='seconds')

    _save_yaml(run_dir / 'run_params.yaml', run_params)
    _save_yaml(run_dir / 'run_result.yaml', run_result)
    print(f'Done. Result saved in: {run_dir}')


if __name__ == '__main__':
    main()

26
snippets/c.json Normal file
View File

@ -0,0 +1,26 @@
{
"for": {
"prefix": "for",
"body": [
" for(int $1 = $2; $1 < $3; $1+=$4) {",
" $5",
" }"
]
},
"for1": {
"prefix": "for1",
"body": [
" for(int $1 = $2; $1 < $3; $1++) {",
" $4",
" }"
]
},
"fori": {
"prefix": "fori",
"body": [
" for(int i = $1; i < $2; i++) {",
" $3",
" }"
]
}
}

View File

@ -0,0 +1,303 @@
{
"Class_Array_tree": {
"prefix": "Class_Array_tree",
"body": [
"",
"template <class T>",
"class Array_tree {",
" public:",
" Array_tree() {}",
" Array_tree(int n) { this->n = n, tree = vector<T>(n + 1); }",
" void add(int id, T key) {",
" for (int i = id; i <= n; i += lowbit(i)) tree[i] += key;",
" }",
"",
" T get_sum(int id) {",
" T sum = 0;",
" for (int i = id; i; i -= lowbit(i)) sum += tree[i];",
" return sum;",
" }",
"",
" T get_sum(int l, int r) { return get_sum(r) - get_sum(l - 1); }",
"",
" private:",
" int n;",
" vector<T> tree;",
" int lowbit(int x) { return x & -x; }",
"};",
""
]
},
"Class_SGM_Tree": {
"prefix": "Class_SGM_Tree",
"body": [
"",
"class SGM_Tree {",
" public:",
" class point {",
" public:",
" ll sum, maxi, mini;",
" };",
"",
" vll a, lazy;",
" int n;",
" ll sum, maxi, mini;",
" vector<point> tree;",
"",
" SGM_Tree() {}",
" SGM_Tree(int n, vi v) {",
" // a下标默认从1开始,只需开n个点不需要n + 1",
" this->n = n;",
" lazy = vll(n * 4 + 1);",
" a.push_back(0);",
" for (int i = 1; i <= n; i++) a.push_back(v[i]);",
" tree = vector<point>(4 * n + 1), build(1, n, 1);",
" }",
" SGM_Tree(int n, vll v) {",
" // a下标默认从1开始,只需开n个点不需要n + 1",
" this->n = n;",
" lazy = vll(n * 4 + 1);",
" a.push_back(0);",
" for (int i = 1; i <= n; i++) a.push_back(v[i]);",
" tree = vector<point>(4 * n + 1), build(1, n, 1);",
" }",
" SGM_Tree(int n, int* v) {",
" // a下标默认从1开始,只需开n个点不需要n + 1",
" this->n = n;",
" lazy = vll(n * 4 + 1);",
" a.push_back(0);",
" for (int i = 1; i <= n; i++) a.push_back(v[i]);",
" tree = vector<point>(4 * n + 1), build(1, n, 1);",
" }",
"",
" void push_up(int k) {",
" int l = k * 2, r = k * 2 + 1;",
" tree[k].sum = tree[l].sum + tree[r].sum;",
" tree[k].maxi = max(tree[l].maxi, tree[r].maxi);",
" tree[k].mini = min(tree[l].mini, tree[r].mini);",
" }",
"",
" void push_down(int l, int r, int k) {",
" if (lazy[k]) {",
" int mid = l + r >> 1;",
" lazy[k * 2] += lazy[k];",
" lazy[k * 2 + 1] += lazy[k];",
" tree[k * 2].sum += lazy[k] * (mid - l + 1);",
" tree[k * 2 + 1].sum += lazy[k] * (r - mid);",
" tree[k * 2].maxi += lazy[k];",
" tree[k * 2 + 1].maxi += lazy[k];",
" tree[k * 2].mini += lazy[k];",
" tree[k * 2 + 1].mini += lazy[k];",
" lazy[k] = 0;",
" }",
" }",
"",
" void get_updata(int l, int r, int k, ll value) {",
" tree[k].sum += value * (r - l + 1);",
" tree[k].maxi += value;",
" tree[k].mini += value;",
" lazy[k] += value;",
" }",
"",
" void get(int k) {",
" sum += tree[k].sum;",
" maxi = max(maxi, tree[k].maxi);",
" mini = min(mini, tree[k].mini);",
" }",
"",
" void build(int l, int r, int k) {",
" if (l == r) {",
" tree[k].maxi = tree[k].mini = tree[k].sum = a[l];",
" return;",
" }",
" int mid = l + r >> 1;",
" build(l, mid, k * 2);",
" build(mid + 1, r, k * 2 + 1);",
" push_up(k);",
" }",
"",
" void updata(int l, int r, int L, int R, int k, ll value) {",
" if (L <= l && r <= R) {",
" get_updata(l, r, k, value);",
" return;",
" }",
" push_down(l, r, k);",
" int mid = l + r >> 1;",
" if (L <= mid) updata(l, mid, L, R, k * 2, value);",
" if (R > mid) updata(mid + 1, r, L, R, k * 2 + 1, value);",
" push_up(k);",
" }",
"",
" void query(int l, int r, int L, int R, int k) {",
" if (L <= l && r <= R) {",
" get(k);",
" return;",
" }",
" push_down(l, r, k);",
" int mid = l + r >> 1;",
" if (mid >= L) query(l, mid, L, R, 2 * k);",
" if (mid < R) query(mid + 1, r, L, R, 2 * k + 1);",
" }",
"",
" ll get_sum(int L, int R) {",
" sum = 0;",
" query(1, n, L, R, 1);",
" return sum;",
" }",
"",
" ll get_max(int L, int R) {",
" maxi = -inf;",
" query(1, n, L, R, 1);",
" return maxi;",
" }",
"",
" ll get_min(int L, int R) {",
" mini = inf;",
" query(1, n, L, R, 1);",
" return mini;",
" }",
"};",
""
]
},
"Class_Dsu": {
"prefix": "Class_Dsu",
"body": [
"",
"class Dsu {",
" public:",
"",
" vll fa, num;",
"",
" Dsu(int n) { fa = vll(n + 1), num = vll(n + 1); }",
" int find(int x) {",
" if (!fa[x]) return x;",
" return fa[x] = find(fa[x]);",
" }",
"",
" bool Dunion(int p, int q) {",
" int v = find(p), u = find(q);",
" if (v == u) return 0;",
" fa[u] = v;",
" num[v] += num[u];",
" num[u] = num[v];",
" return 1;",
" }",
" ",
"};",
"",
"ll num(int x) { return num[find(x)]; }",
""
]
},
"class_Stmap": {
"prefix": "Class_StMap",
"body": [
"",
"class st_map {",
" public:",
" st_map() {}",
" st_map(vll v) {",
" this->n = v.size(), this->a = v;",
" this->st = vector<array<ll, 31>>(n + 1);",
" st_init();",
" }",
" int query(int l, int r) {",
" int len = r - l + 1;",
" int k = log(len) / log(2);",
" return max(st[l][k], st[r - (1 << k) + 1][k]);",
" }",
"",
" private:",
" int n;",
" vll a;",
" vector<array<ll, 31>> st;",
" void st_init() {",
" for (int j = 0; j <= 17; j++) {",
" for (int i = 1; i + (1 << j) - 1 <= n; i++) {",
" if (j == 0)",
" st[i][j] = a[i];",
" else",
" st[i][j] = max(st[i][j - 1], st[i + (1 << j - 1)][j - 1]);",
" }",
" }",
" }",
"};",
""
]
},
"Class_HJT_tree": {
"prefix": "Class_HJT_tree",
"body": [
"",
"template <class T>",
"class HJT_tree {",
" //处理数据默认下标从1开始",
" public:",
" //构造函数",
" HJT_tree() {}",
" HJT_tree(vector<T> v) {",
" base = v, this->n = base.size() - 1;",
" tree = vector<node>(n * 32), root.push_back(build(1, n));",
" }",
"",
" void updata(int v, int x, T value) {",
" //插入函数(版本,修改位置,修改值)",
" root.push_back(insert(root[v], 1, n, x, value));",
" }",
"",
" T query(int v, int x) {",
" //查询函数(版本,查询位置)",
" return get_se(root[v], 1, n, x, x);",
" }",
"",
" T query(int v, int l, int r) {",
" //查询函数(版本,查询区间)",
" return get_se(root[v], 1, n, l, r);",
" }",
"",
" private:",
" vi root;",
" vector<T> base;",
" int n, idx = 0;",
" struct node {",
" int l, r;",
" T data;",
" };",
" vector<node> tree;",
" void pushup(int q) { tree[q].data = op(tree[q].l, tree[q].r); }",
" T op(int l, int r) { return max(tree[l].data, tree[r].data); }",
" T e() { return -inf; }",
" int build(int l, int r) {",
" int now = ++idx, mid = l + r >> 1;",
" if (l != r)",
" tree[now].l = build(l, mid), tree[now].r = build(mid + 1, r), pushup(now);",
" return now;",
" }",
" int insert(int old, int l, int r, int x, int value) {",
" int now = ++idx, mid = l + r >> 1;",
" tree[now] = tree[old];",
" if (l == r)",
" tree[now].data = value;",
" else {",
" if (x <= mid)",
" tree[now].l = insert(tree[old].l, l, mid, x, value);",
" else",
" tree[now].r = insert(tree[old].r, mid + 1, r, x, value);",
" pushup(now);",
" }",
" return now;",
" }",
" T get_se(int v, int l, int r, int L, int R) {",
" if (L <= l && r <= R) return tree[v].data;",
" ll mid = l + r >> 1;",
" T res = e();",
" if (L <= mid) res = max(res, get_se(tree[v].l, l, mid, L, R));",
" if (R > mid) res = max(res, get_se(tree[v].r, mid + 1, r, L, R));",
" return res;",
" }",
"};",
""
]
}
}

View File

@ -0,0 +1,359 @@
{
"Graph": {
"prefix": "Class_Graph",
"body": [
"",
"template <class T>",
"class Graph {",
" public:",
" struct edge {",
" int next, to;",
" T w;",
" };",
" int n;",
" vector<vector<T>> maps;",
" vector<T> dis;",
" vector<edge> e;",
" vi head, bj;",
" Graph() {}",
"",
" Graph(int n) {",
" this->n = n, this->m = n * (n - 1);",
" head = vi(n + 1, -1), e = vector<edge>(m * 2 + 1);",
" }",
"",
" Graph(int n, int m) {",
" this->n = n, this->m = m, head = vi(n + 1, -1), e = vector<edge>(m * 2 + 1);",
" }",
"",
" void add(int u, int v, T w) {",
" e[cnt].to = v, e[cnt].next = head[u], e[cnt].w = w, head[u] = cnt++;",
" }",
"",
" void add(int u, int v) {",
" e[cnt].to = v, e[cnt].next = head[u], head[u] = cnt++;",
" }",
"",
" private:",
" int m, cnt = 0;",
"};",
"#define e g.e",
"#define head g.head",
"#define maps g.maps",
"#define add g.add",
""
]
},
"Class_Graph_Dijkstra": {
"prefix": "Class_Graph_Dijkstra",
"body": [
"",
"#define dis g.dis",
"#define bj g.bj",
"ll dijkstra(Graph<ll> g, int st, int en) {",
" dis = vll(g.n + 1, inf),",
" bj = vi(g.n + 1);",
" priority_queue<pll, vector<pll>, greater<pll>> q;",
" dis[st] = 0;",
" q.push({0, st});",
" while (!q.empty()) {",
" int u = q.top().y;",
" q.pop();",
" if (bj[u]) continue;",
" bj[u] = 1;",
" for (int i = head[u]; ~i; i = e[i].next) {",
" int v = e[i].to;",
" if (dis[v] > dis[u] + e[i].w) {",
" dis[v] = dis[u] + e[i].w;",
" q.push({dis[v], v});",
" }",
" }",
" }",
" if (dis[en] != inf)",
" return dis[en];",
" else",
" return -1;",
"}",
""
]
},
"Class_Graph_SPFA": {
"prefix": "Class_Graph_SPFA",
"body": [
"",
"ll spfa(Graph<ll> g, int st, int en) {",
" vll bj(g.n + 1), dis(g.n + 1, inf), num(g.n + 1);",
" queue<int> q;",
" dis[st] = 0;",
" bj[st] = 1;",
" q.push(st);",
" while (q.size()) {",
"    int u = q.front();",
"    q.pop();",
"    if (!bj[u]) continue;",
"    if (num[u] > g.n + 1) return inf;//判断是否产生负环",
"    bj[u] = 0, num[u]++;",
"    for (int i = head[u]; ~i; i = e[i].next) {",
"      ll w = e[i].w;",
"      int v = e[i].to;",
"      if (dis[v] > dis[u] + w) {",
"        dis[v] = dis[u] + w;",
"        if (!bj[v]) q.push(v), bj[v] = 1;",
" }",
" }",
" }",
" if(dis[en] == inf) return -1;",
" return dis[en];",
"}",
""
]
},
"Class_Graph_Kruskal": {
"prefix": "Class_Graph_Kruskal",
"body": [
"",
"ll kruskal(Graph<ll> g, Dsu d) {",
" ll res = 0, cnt = 0;",
" sort(e.begin(), e.end());",
" for (auto i : e) {",
" int u = i.u, v = i.v, w = i.w;",
" int q = d.find(i.u), p = d.find(i.v);",
" if (d.Dunion(q, p)) cnt++, res += w;",
" if (cnt == n - 1) break;",
" }",
" return res;",
"}",
""
]
},
"Class_Graph_Prim": {
"prefix": "Class_Graph_Prim",
"body": []
},
"Class_Graph_TreeDiam": {
"prefix": "Class_Graph_TreeDiam",
"description": "树的直径",
"body": [
"ll TreeDiam(Graph<ll> g) {",
" vll dis1(n + 1), dis2(n + 1), p(n + 1), up(n + 1);",
" function<ll(ll, ll)> dfs = [&](ll u, ll fa) {",
" for (int i = head[u]; ~i; i = e[i].next) {",
" int v = e[i].to, w = 1;",
" if (v == fa) continue;",
" ll x = dfs(v, u) + 1;",
" if (x >= dis1[u])",
" dis2[u] = dis1[u], dis1[u] = x, p[u] = v;",
" else if (x >= dis2[u])",
" dis2[u] = x;",
" }",
" return dis1[u];",
" };",
" function<void(ll, ll)> dfs0 = [&](ll u, ll fa) {",
" for (int i = head[u]; ~i; i = e[i].next) {",
" int v = e[i].to, w = 1;",
" if (v == fa) continue;",
" if (p[u] == v)",
" up[v] = max(dis2[u], up[u]) + w;",
" else",
" up[v] = max(dis1[u], up[u]) + w;",
" dfs0(v, u);",
" }",
" };",
" dfs(1, -1), dfs0(1, -1);",
" ll ans = -1;",
" for (int i = 1; i <= n + m; i++) {",
" if (dis1[i] == dis2[i] && dis1[i] == 0)",
" ans = max(ans, up[i]);",
" else",
" ans = max(ans, max(up[i], dis1[i]));",
" }",
" return ans;",
"}"
]
},
"Class_Graph_SCC": {
"prefix": "Class_Graph_SCC",
"body": [
"",
"class SCC {",
" public:",
" stack<ll> stk;",
" int timestamp = 0, scc_cnt = 0, n, m;",
" vll dfn, low, in_stk, Size, id;",
" Graph<ll> g;",
"",
" SCC() {}",
"",
" SCC(Graph<ll> g) {",
" this->n = g.n, this->m = g.m, this->g = g;",
" dfn = vll(n + 1, 0);",
" low = vll(n + 1, 0);",
" in_stk = vll(n + 1, 0);",
" Size = vll(n + 1, 0);",
" id = vll(n + 1, 0);",
" for (int i = 1; i <= n; i++)",
" if (!dfn[i]) tarjan(i);",
" }",
"",
" void tarjan(int u) {",
" dfn[u] = low[u] = ++timestamp;",
" stk.push(u), in_stk[u] = 1;",
" for (int i = head[u]; ~i; i = e[i].next) {",
" int v = e[i].to;",
" if (!dfn[v]) {",
" tarjan(v);",
" low[u] = min(low[u], low[v]);",
" } else if (in_stk[v])",
" low[u] = min(low[u], dfn[v]);",
" }",
" if (low[u] == dfn[u]) {",
" ++scc_cnt;",
" int v;",
" do {",
" v = stk.top();",
" stk.pop();",
" in_stk[v] = 0;",
" id[v] = scc_cnt;",
" Size[scc_cnt]++;",
" } while (v != u);",
" }",
" return;",
" }",
"};",
""
]
},
"Class_Graph_Euler": {
"prefix": "Class_Graph_Euler",
"body": [
"",
"vll din, dout;",
"",
"class Euler {",
" public:",
" int n, m;",
" Graph<ll> g;",
" vector<bool> used;",
" vll path;",
" Euler() {}",
" Euler(Graph<ll> g) {",
" this->n = g.n, this->m = g.m, this->g = g;",
" used = vector<bool>(n + 1);",
" }",
"",
" void dfs_u(int u) {",
" //无向图",
" for (long long &i = head[u]; i; i = e[i].next) {",
" long long j = i & 1 ? i + 1 : i - 1;",
" if(used[i]) {",
" i = e[i].next;",
" continue;",
" }",
" used[j] = used[i] = true;",
" int t = i / 2 + 1;",
" dfs_u(e[i].to);",
" path.push_back(t);",
" }",
" }",
"",
" void dfs_o(int u) {",
" // 有向图",
" for (long long &i = head[u]; i; i = e[i].next) {",
" if(used[i]) {",
" i = e[i].next;",
" continue;",
" }",
" used[i] = true;",
" int t = i + 1;",
" dfs_o(e[i].to);",
" path.push_back(t);",
" }",
" }",
"",
"};",
"#define path eu.path",
""
]
},
"Class_Graph_LCA": {
"prefix": "Class_Graph_LCA",
"body": [
"",
"class LCA {",
" public:",
" int n;",
" vll depth, fa[33];",
"",
" LCA() {}",
" LCA(Graph<ll> g, int root) {",
" n = g.n, depth = vll(n + 1, inf);",
" for (int i = 0; i <= 32; i++) fa[i] = vll(n + 1);",
" bfs(g, root);",
" }",
"",
" void bfs(Graph<ll> g, int root) {",
" depth[0] = 0, depth[root] = 1;",
" queue<int> q;",
" q.push(root);",
" while (q.size()) {",
" auto u = q.front();",
" q.pop();",
" for (int i = head[u]; ~i; i = e[i].next) {",
" int v = e[i].to;",
" if (depth[v] > depth[u] + 1) {",
" depth[v] = depth[u] + 1;",
" q.push(v);",
" fa[0][v] = u;",
" for (int k = 1; k <= 32; k++) fa[k][v] = fa[k - 1][fa[k - 1][v]];",
" }",
" }",
" }",
" }",
"",
" int query(int a, int b) {",
" if (depth[a] < depth[b]) swap(a, b);",
" for (int i = 32; i >= 0; i--)",
" if (depth[fa[i][a]] >= depth[b]) a = fa[i][a];",
" if (a == b) return a;",
" for (int i = 32; i >= 0; i--)",
" if (fa[i][a] != fa[i][b]) a = fa[i][a], b = fa[i][b];",
" return fa[0][a];",
" }",
"};",
""
]
},
"Class_Graph_erfen": {
"prefix": "Class_Graph_erfen",
"body": [
"",
"class erfen_graph {",
" public:",
" ll res, n;",
" vi match, st;",
" Graph<ll> g;",
" erfen_graph(Graph<ll> g) {",
" this->g = g, this->n = g.n;",
" st = vector<int>(n + 1), match = vector<int>(n + 1);",
" res = 0;",
" for (int i = 1; i <= n; i++) {",
" st = vector<int>(n + 1);",
" if (find(i)) res++;",
" }",
" }",
"",
" bool find(int u) {",
" for (int i = head[u]; ~i; i = e[i].next) {",
" if (!st[e[i].to]) {",
" st[e[i].to] = true;",
" if (!match[e[i].to] || find(match[e[i].to])) {",
" match[e[i].to] = u;",
" return true;",
" }",
" }",
" }",
" return false;",
" }",
"};",
""
]
}
}

View File

@ -0,0 +1,108 @@
{
"Math_QuickPow": {
"prefix": "Math_QuickPow",
"body": [
"",
"ll quick_Pow(ll a, ll b, ll mod) {",
" // a的b次方模mod",
"  ll res = 1 % mod, t = a % mod;",
" while (b) {",
" if (b & 1) res = (res * t) % mod;",
" t = t * t % mod;",
" b >>= 1;",
" }",
" return res;",
"}",
"",
],
},
"Math_Fm": {
"prefix": "Math_Fm",
"body": [
"",
"ll quick_Pow(ll a, ll b, ll mod) {",
" // a的b次方模mod",
"  ll res = 1 % mod, t = a % mod;",
" while (b) {",
" if (b & 1) res = (res * t) % mod;",
" t = t * t % mod;",
" b >>= 1;",
" }",
" return res;",
"}",
"",
"ll Fm(ll a, ll mod) {",
" //费马小定理求逆元",
" return quick_Pow(a, mod - 2, mod);",
"}",
"",
],
},
"Math_C": {
"prefix": "Math_C",
"body": [
"",
"ll quick_Pow(ll a, ll b, ll mod) {",
" // a的b次方模mod",
"  ll res = 1 % mod, t = a % mod;",
" while (b) {",
" if (b & 1) res = (res * t) % mod;",
" t = t * t % mod;",
" b >>= 1;",
" }",
" return res;",
"}",
"",
"ll Fm(ll a, ll mod) {",
" //费马小定理求逆元",
" return quick_Pow(a, mod - 2, mod);",
"}",
"",
"ll C(ll n, ll m, ll mod) {",
" ll fz = 1, fm = 1;",
" for (ll i = n; i >= n - m + 1; i--) fz = fz * i % mod;",
" for (ll i = 1; i <= m; i++) fm = fm * i % mod;",
" return (fz * Fm(fm, mod)) % mod;",
"}",
"",
],
},
"Class_Math_Bignum": {
"prefix": "Class_Math_Bignum",
"body": [
"",
"class Math_Bignum {",
" public:",
" string Bignum;",
"",
" vll num;",
"",
" Math_Bignum(string s) {",
" Bignum = s;",
" for (auto i : s) num.push_back(i - '0');",
" }",
"",
" Math_Bignum(vll v) {",
" string s;",
" num = v;",
"    for (auto i : v) s.push_back(i + '0');",
"    Bignum = s;",
" }",
"",
" Math_Bignum(ll l) {",
" string s;",
" while (l) s.push_back(l % 10 + '0'), num.push_back(l % 10), l /= 10;",
" reverse(num.begin(), num.end());",
" reverse(s.begin(), s.end());",
" Bignum = s;",
" }",
"",
" ll get(ll l, ll r) {",
" ll res = 0;",
" for (int i = l - 1; i <= r - 1; i++) res = res * 10 + num[i];",
" return res;",
" }",
"};",
"",
],
},
}

43
snippets/go.json Normal file
View File

@ -0,0 +1,43 @@
{
"class": {
"prefix": "class",
"body": "type $1 struct{$2}\n$0"
},
"errnil": {
"prefix": "errnil",
"body": [
"if err != nil {",
" $1",
"}"
]
},
"init": {
"prefix": "init",
"body": [
"func init() {",
" $1",
"}"
]
},
"ctx": {
"prefix": "ctx",
"body": [
"ctx := context.Background()"
]
},
"cok": {
"prefix": "cok",
"body": [
"c.JSON(http.StatusOK, vo.Success($1))"
]
},
"cbad": {
"prefix": "cbad",
"body": [
"if err != nil {",
" c.JSON(http.StatusBadRequest, vo.Fail(err))",
" return",
"}"
]
}
}

13
snippets/json.json Normal file
View File

@ -0,0 +1,13 @@
{
"mod": {
"prefix": "mod",
"body": [
" \"$1\": {",
" \"prefix\": \"$1\",",
" \"body\": [",
" $2",
" ]",
" }"
]
}
}

3
snippets/lua.json Normal file
View File

@ -0,0 +1,3 @@
{
}