Deeplearning/main.py

176 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
import math
from pathlib import Path

from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.loadData import load_data
from Qfunctions.saveToXlsx import save_to_xlsx as save_to_xlsx
def _to_builtin(value):
if isinstance(value, dict):
return {str(k): _to_builtin(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_to_builtin(v) for v in value]
if hasattr(value, 'item'):
try:
return value.item()
except Exception:
return str(value)
return value
def _yaml_scalar(value):
if value is None:
return 'null'
if isinstance(value, bool):
return 'true' if value else 'false'
if isinstance(value, (int, float)):
return str(value)
text = str(value).replace('"', '\\"')
return f'"{text}"'
def _yaml_lines(value, indent=0):
    """Render *value* as a list of block-style YAML lines.

    Dict entries become ``key: scalar`` and list/tuple elements become
    ``- scalar``. A nested dict/list/tuple is written on the lines after its
    ``key:`` or ``-`` marker, indented two spaces deeper. Empty containers are
    written as ``{}`` or ``[]``. Any other value is rendered with
    ``_yaml_scalar``.

    Args:
        value: The data to render, normally already converted to builtin
            types.
        indent: Number of leading spaces for lines at this nesting level.

    Returns:
        The rendered lines, without trailing newlines.
    """
    prefix = indent * ' '
    nested_types = (dict, list, tuple)

    def emit(marker, child, sink):
        # Containers go on their own lines below the marker; scalars stay inline.
        if isinstance(child, nested_types):
            sink.append(prefix + marker)
            sink.extend(_yaml_lines(child, indent + 2))
        else:
            sink.append(f'{prefix}{marker} {_yaml_scalar(child)}')

    if isinstance(value, dict):
        if not value:
            return [prefix + '{}']
        rendered = []
        for key, child in value.items():
            emit(str(key) + ':', child, rendered)
        return rendered

    if isinstance(value, (list, tuple)):
        if not value:
            return [prefix + '[]']
        rendered = []
        for child in value:
            emit('-', child, rendered)
        return rendered

    return [prefix + _yaml_scalar(value)]
def _save_yaml(file_path, data):
    """Serialise *data* as YAML and write it to *file_path*.

    The data is first normalised with ``_to_builtin`` and rendered with
    ``_yaml_lines``. The file is opened in text mode with UTF-8 encoding,
    overwriting any existing content, and ends with a newline.
    """
    rendered_lines = _yaml_lines(_to_builtin(data))
    payload = '\n'.join(rendered_lines) + '\n'
    with open(file_path, mode='w', encoding='utf-8') as handle:
        handle.write(payload)
def main():
    """Train a Qmlp classifier on one data set and archive the run's outputs.

    Steps:
      1. Load the data for ``project_name`` restricted to ``label_names``.
      2. Build a ``Qmlp`` model, compute its PCA projections, then train it.
      3. Save the PCA projections, confusion matrices and per-epoch metrics
         with ``save_to_xlsx`` under the name ``<project_name>/<run_id>``.
      4. Write ``run_params.yaml`` (configuration) and ``run_result.yaml``
         (best / last epoch metrics) into ``Result/<project_name>/<run_id>``.
    """
    # Name of the input data folder; also used as the results sub-directory.
    project_name = '20260409 grap'
    # Ordered class labels passed to both the loader and the model.
    # Currently the integers 1..9 (the range end is exclusive).
    label_names = list(range(1, 10))
    hidden_layers = [256, 128, 128, 128]
    test_size = 0.5
    dropout_rate = 0
    epochs = 300

    # Timestamped sub-directory so that runs started in different seconds do
    # not overwrite each other.
    run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    # NOTE(review): run_dir assumes save_to_xlsx writes into
    # Result/<project_name>; TODO confirm against Qfunctions.saveToXlsx.
    run_project_name = f'{project_name}/{run_id}'
    run_dir = Path('Result') / project_name / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    print(label_names)
    data = load_data(project_name, label_names)
    model = Qmlp(
        data=data,
        labels=label_names,
        hidden_layers=hidden_layers,
        test_size=test_size,
        dropout_rate=dropout_rate,
    )
    # Alternative convolutional model. QCNN is not imported in this file, and
    # run_params['model'] below describes Qmlp; update both before enabling.
    # model = QCNN(
    #     data=data,
    #     labels=label_names,
    #     conv_channels=(16, 32),
    #     kernel_size=3,
    #     hidden_size=128,
    #     test_size=0.3,
    #     dropout_rate=0
    # )

    # PCA is requested before training, as in the original workflow.
    pca_2d, pca_3d = model.get_PCA()
    model.fit(epochs)
    cm = model.get_cm()
    cmn = model.get_cmn()
    epoch_data = model.get_epoch_data()

    # Saved in this order; the names must match generated_files below.
    outputs = {
        'pca_2d': pca_2d,
        'pca_3d': pca_3d,
        'cm': cm,
        'cmn': cmn,
        'acc_and_loss': epoch_data,
    }
    for file_name, table in outputs.items():
        save_to_xlsx(project_name=run_project_name, file_name=file_name, data=table)

    # A single timestamp keeps both YAML files' created_at values identical.
    created_at = datetime.datetime.now().isoformat(timespec='seconds')

    run_params = {
        'run_id': run_id,
        'project_name': project_name,
        'label_names': label_names,
        'model': {
            'type': 'Qmlp',  # keep in sync with the model class used above
            'hidden_layers': hidden_layers,
            'dropout_rate': dropout_rate,
        },
        'train': {
            'epochs': epochs,
            'test_size': test_size,
        },
        'created_at': created_at,
    }

    # epoch_data appears to be a pandas DataFrame (.empty/.loc/.iloc/idxmax).
    if epoch_data.empty:
        run_result = {
            'run_id': run_id,
            'result_dir': str(run_dir),
            'message': 'No epoch data generated.',
            'created_at': created_at,
        }
    else:
        # idxmax returns an index *label*; int() assumes an integer index,
        # presumably the default RangeIndex (TODO confirm in get_epoch_data).
        best_idx = int(epoch_data['test_accuracy'].idxmax())
        best_row = epoch_data.loc[best_idx].to_dict()
        last_row = epoch_data.iloc[-1].to_dict()
        run_result = {
            'run_id': run_id,
            'result_dir': str(run_dir),
            'best_epoch': int(best_row.get('epoch', 0)),
            'best_metrics': best_row,
            'last_epoch_metrics': last_row,
            # NOTE(review): the .png names assume save_to_xlsx also renders
            # plots with these names; TODO confirm against saveToXlsx.
            'generated_files': [
                'pca_2d.xlsx',
                'pca_2d.png',
                'pca_3d.xlsx',
                'pca_3d.png',
                'cm.xlsx',
                'cm.png',
                'cmn.xlsx',
                'cmn.png',
                'acc_and_loss.xlsx',
                'acc_and_loss_epoch.png',
                'acc_and_loss_last_epoch_bar.png',
            ],
            'created_at': created_at,
        }

    _save_yaml(run_dir / 'run_params.yaml', run_params)
    _save_yaml(run_dir / 'run_result.yaml', run_result)
    print(f'Done. Result saved in: {run_dir}')
# Run the experiment only when this file is executed directly, not on import.
if __name__ == '__main__':
    main()