Deeplearning/main.py

176 lines
4.6 KiB
Python

import datetime
import math
from pathlib import Path

from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.loadData import load_data
from Qfunctions.saveToXlsx import save_to_xlsx as save_to_xlsx
def _to_builtin(value):
if isinstance(value, dict):
return {str(k): _to_builtin(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_to_builtin(v) for v in value]
if hasattr(value, 'item'):
try:
return value.item()
except Exception:
return str(value)
return value
def _yaml_scalar(value):
if value is None:
return 'null'
if isinstance(value, bool):
return 'true' if value else 'false'
if isinstance(value, (int, float)):
return str(value)
text = str(value).replace('"', '\\"')
return f'"{text}"'
# Plain keys that YAML 1.1 loaders resolve to null/booleans instead of strings
# (compared case-insensitively).
_YAML_RESERVED_KEYS = frozenset({'null', 'true', 'false', 'yes', 'no', 'on', 'off', 'y', 'n'})


def _yaml_key(key):
    """Return ``key`` as a YAML mapping key, quoted unless it is a plain identifier.

    Identifier-like keys (``run_id``, ``test_accuracy``, ...) are written bare;
    anything else (spaces, ':', '#', leading digits, reserved words, empty
    string) is emitted as a double-quoted scalar so the document stays valid
    and the key loads back as a string.
    """
    text = str(key)
    if text.isidentifier() and text.lower() not in _YAML_RESERVED_KEYS:
        return text
    return _yaml_scalar(text)


def _yaml_lines(value, indent=0):
    """Render ``value`` as a list of block-style YAML lines.

    Dicts become mappings and lists/tuples become sequences. A nested
    container starts on the line after its key or ``-`` marker, indented two
    more spaces. Empty containers are written as ``{}`` / ``[]``. Any other
    value is rendered with ``_yaml_scalar``.

    Args:
        value: Data to render, normally already passed through ``_to_builtin``.
        indent: Number of leading spaces for this nesting level.

    Returns:
        The rendered lines, without trailing newlines.
    """
    space = ' ' * indent
    if isinstance(value, dict):
        if not value:
            return [space + '{}']
        lines = []
        for k, v in value.items():
            key = _yaml_key(k)
            if isinstance(v, (dict, list, tuple)):
                lines.append(f'{space}{key}:')
                lines.extend(_yaml_lines(v, indent + 2))
            else:
                lines.append(f'{space}{key}: {_yaml_scalar(v)}')
        return lines
    if isinstance(value, (list, tuple)):
        if not value:
            return [space + '[]']
        lines = []
        for item in value:
            if isinstance(item, (dict, list, tuple)):
                lines.append(f'{space}-')
                lines.extend(_yaml_lines(item, indent + 2))
            else:
                lines.append(f'{space}- {_yaml_scalar(item)}')
        return lines
    return [space + _yaml_scalar(value)]
def _save_yaml(file_path, data):
    """Write ``data`` to ``file_path`` as a UTF-8 YAML document.

    The data is first normalised with ``_to_builtin`` and rendered with
    ``_yaml_lines``. The whole document is built before the file is opened,
    so a rendering error leaves no partial file behind. An existing file is
    overwritten. The output always ends with one trailing newline.
    """
    rendered = _yaml_lines(_to_builtin(data))
    payload = '\n'.join(rendered) + '\n'
    with open(file_path, mode='w', encoding='utf-8') as stream:
        stream.write(payload)
def _build_run_result(run_id, run_dir, epoch_data):
    """Summarise a finished training run as the dict stored in ``run_result.yaml``.

    Args:
        run_id: Timestamp identifier of the run.
        run_dir: Directory the run's YAML files are written to.
        epoch_data: Per-epoch metrics from ``model.get_epoch_data()``. It is
            accessed via ``.empty``, ``['test_accuracy']`` and ``.iloc``, so
            presumably a pandas DataFrame with a ``test_accuracy`` column
            (confirm against Qmlp).

    Returns:
        A dict with the best epoch by ``test_accuracy``, that epoch's metrics,
        the last epoch's metrics and the expected output files. When there is
        nothing to rank, the dict instead carries a message.
    """
    if epoch_data.empty:
        return {
            'run_id': run_id,
            'result_dir': str(run_dir),
            'message': 'No epoch data generated.',
            'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
        }
    last_row = epoch_data.iloc[-1].to_dict()
    # Drop the original index so idxmax() returns a row *position*. This is
    # safe for non-integer or duplicated indexes. NaN accuracies are skipped
    # so idxmax() always has a real maximum to return.
    accuracy = epoch_data['test_accuracy'].reset_index(drop=True).dropna()
    if accuracy.empty:
        return {
            'run_id': run_id,
            'result_dir': str(run_dir),
            'message': 'No valid test_accuracy values in epoch data.',
            'last_epoch_metrics': last_row,
            'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
        }
    best_row = epoch_data.iloc[int(accuracy.idxmax())].to_dict()
    return {
        'run_id': run_id,
        'result_dir': str(run_dir),
        'best_epoch': int(best_row.get('epoch', 0)),
        'best_metrics': best_row,
        'last_epoch_metrics': last_row,
        # NOTE(review): hard-coded list of expected outputs, not checked
        # against disk. The .png entries assume save_to_xlsx also renders
        # plots -- confirm in Qfunctions.saveToXlsx.
        'generated_files': [
            'pca_2d.xlsx',
            'pca_2d.png',
            'pca_3d.xlsx',
            'pca_3d.png',
            'cm.xlsx',
            'cm.png',
            'cmn.xlsx',
            'cmn.png',
            'acc_and_loss.xlsx',
            'acc_and_loss_epoch.png',
            'acc_and_loss_last_epoch_bar.png',
        ],
        'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
    }


def main():
    """Train a Qmlp classifier on one dataset and archive the run.

    Outputs:
      * ``Result/<project_name>/<run_id>/run_params.yaml`` -- the
        configuration. It is written before training so a run that crashes
        still records what it was.
      * The PCA projections, confusion matrices and per-epoch metrics, handed
        to ``save_to_xlsx`` under ``<project_name>/<run_id>`` (presumably the
        same directory -- confirm in Qfunctions.saveToXlsx).
      * ``Result/<project_name>/<run_id>/run_result.yaml`` -- a summary of
        the training history.
    """
    # Name of the input data folder passed to load_data.
    project_name = '20260512 Graps'
    # One name per class; here the classes are the capital letters A-I.
    label_names = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I']
    hidden_layers = [256, 256]
    test_size = 0.5
    dropout_rate = 0
    epochs = 300

    # Timestamped directory per run. It has second resolution, so two runs
    # started within the same second would share it.
    run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    run_project_name = f'{project_name}/{run_id}'
    run_dir = Path('Result') / project_name / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    run_params = {
        'run_id': run_id,
        'project_name': project_name,
        'label_names': label_names,
        'model': {
            'type': 'Qmlp',
            'hidden_layers': hidden_layers,
            'dropout_rate': dropout_rate,
        },
        'train': {
            'epochs': epochs,
            'test_size': test_size,
        },
        'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
    }
    # Saved before the (long) training so a failed run still leaves its config.
    _save_yaml(run_dir / 'run_params.yaml', run_params)

    print(label_names)
    data = load_data(project_name, label_names)
    model = Qmlp(
        data=data,
        labels=label_names,
        hidden_layers=hidden_layers,
        test_size=test_size,
        dropout_rate=dropout_rate,
    )
    # Alternative CNN model. Switching to it also requires importing QCNN and
    # updating the 'model' section of run_params (currently hard-coded to Qmlp).
    # model = QCNN(
    #     data=data,
    #     labels=label_names,
    #     conv_channels=(16, 32),
    #     kernel_size=3,
    #     hidden_size=128,
    #     test_size=0.3,
    #     dropout_rate=0
    # )

    pca_2d, pca_3d = model.get_PCA()
    model.fit(epochs)
    cm = model.get_cm()
    cmn = model.get_cmn()
    epoch_data = model.get_epoch_data()

    outputs = {
        'pca_2d': pca_2d,
        'pca_3d': pca_3d,
        'cm': cm,
        'cmn': cmn,
        'acc_and_loss': epoch_data,
    }
    for file_name, table in outputs.items():
        save_to_xlsx(project_name=run_project_name, file_name=file_name, data=table)

    run_result = _build_run_result(run_id, run_dir, epoch_data)
    _save_yaml(run_dir / 'run_result.yaml', run_result)
    print(f'Done. Result saved in: {run_dir}')
if __name__ == '__main__':
    # Start training only when run as a script (``python main.py``).
    # Importing this module just defines the functions above.
    main()