176 lines
4.7 KiB
Python
176 lines
4.7 KiB
Python
import datetime
|
||
from pathlib import Path
|
||
|
||
from Qtorch.Models.Qmlp import Qmlp
|
||
from Qfunctions.loadData import load_data
|
||
from Qfunctions.saveToXlsx import save_to_xlsx as save_to_xlsx
|
||
|
||
|
||
def _to_builtin(value):
    """Recursively convert *value* into plain Python objects for YAML output.

    * dicts become dicts with ``str`` keys (values converted recursively);
    * lists and tuples become lists (items converted recursively);
    * objects exposing ``.item()`` (such as NumPy scalars) are unwrapped to a
      Python scalar;
    * if ``.item()`` fails (e.g. a multi-element array) but the object offers
      ``.tolist()``, that nested-list form is converted instead of falling back
      to a lossy string repr;
    * objects that still cannot be converted are stringified; every other
      value is returned unchanged.
    """
    if isinstance(value, dict):
        return {str(k): _to_builtin(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_to_builtin(v) for v in value]
    if not hasattr(value, 'item'):
        return value

    try:
        return value.item()
    except Exception:  # best-effort: array libraries raise different error types
        pass

    # Multi-element array-likes: keep the data instead of a truncated repr.
    to_list = getattr(value, 'tolist', None)
    if callable(to_list):
        try:
            return _to_builtin(to_list())
        except Exception:  # best-effort, fall through to the string form
            pass
    return str(value)
|
||
|
||
|
||
# Characters that must be escaped inside a YAML double-quoted scalar: the
# backslash and quote themselves, C0/DEL/C1 control characters, and the
# Unicode line/paragraph separators (treated as line breaks by YAML 1.1).
_YAML_DQ_ESCAPES = {
    **{code: f'\\x{code:02x}' for code in range(0x20)},
    **{code: f'\\x{code:02x}' for code in range(0x7F, 0xA0)},
    ord('\\'): '\\\\',
    ord('"'): '\\"',
    ord('\n'): '\\n',
    ord('\r'): '\\r',
    ord('\t'): '\\t',
    0x2028: '\\u2028',
    0x2029: '\\u2029',
}


def _yaml_scalar(value):
    """Render a single Python value as a YAML scalar token.

    * ``None`` -> ``null``; booleans -> ``true`` / ``false``;
    * ints are written bare;
    * finite floats are written bare, with a ``.0`` inserted into exponent
      forms such as ``1e-05`` so YAML 1.1 loaders still read them as floats;
    * NaN / infinities use YAML's ``.nan`` / ``.inf`` / ``-.inf`` spellings;
    * anything else is stringified and emitted as a double-quoted string with
      backslashes, quotes and control characters escaped (so e.g. Windows
      paths produce valid YAML).
    """
    import math

    if value is None:
        return 'null'
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return 'true' if value else 'false'
    if isinstance(value, int):
        return str(value)
    if isinstance(value, float):
        if math.isnan(value):
            return '.nan'
        if math.isinf(value):
            return '.inf' if value > 0 else '-.inf'
        text = str(value)
        mantissa, sep, exponent = text.partition('e')
        # Python prints e.g. 1e-05; YAML 1.1 loaders (PyYAML) only recognise
        # exponent floats whose mantissa contains a dot.
        if sep and '.' not in mantissa:
            text = f'{mantissa}.0e{exponent}'
        return text
    return '"' + str(value).translate(_YAML_DQ_ESCAPES) + '"'
|
||
|
||
|
||
# Plain words that YAML loaders would read as null/booleans rather than as
# strings if used unquoted as a mapping key.
_YAML_RESERVED_WORDS = frozenset(
    {'null', 'true', 'false', 'yes', 'no', 'on', 'off', 'y', 'n'}
)


def _yaml_key(key):
    """Render a mapping key, quoting it unless it is a safe plain identifier."""
    text = str(key)
    head = text[:1]
    is_plain = (
        (head.isalpha() or head == '_')
        and all(ch.isalnum() or ch in '_-.' for ch in text)
        and text.lower() not in _YAML_RESERVED_WORDS
    )
    return text if is_plain else _yaml_scalar(text)


def _yaml_lines(value, indent=0):
    """Render *value* as a list of block-style YAML lines.

    Args:
        value: A dict, list/tuple or scalar (normally already passed through
            ``_to_builtin``).
        indent: Number of leading spaces for this nesting level; nested
            containers are indented by two more spaces.

    Returns:
        The YAML lines, without trailing newlines. Empty containers are written
        as ``{}`` / ``[]`` and scalars via ``_yaml_scalar``. Keys that are not
        plain identifiers (containing ``:``/``#``/spaces, empty, digit-first,
        or YAML keywords such as ``null``) are quoted so they reload as the
        same strings.
    """
    space = ' ' * indent

    if isinstance(value, dict):
        if not value:
            return [space + '{}']
        lines = []
        for k, v in value.items():
            key = _yaml_key(k)
            if isinstance(v, (dict, list, tuple)):
                # Nested container: key on its own line, children indented.
                lines.append(f'{space}{key}:')
                lines.extend(_yaml_lines(v, indent + 2))
            else:
                lines.append(f'{space}{key}: {_yaml_scalar(v)}')
        return lines

    if isinstance(value, (list, tuple)):
        if not value:
            return [space + '[]']
        lines = []
        for item in value:
            if isinstance(item, (dict, list, tuple)):
                # Bare dash, then the nested container indented below it.
                lines.append(f'{space}-')
                lines.extend(_yaml_lines(item, indent + 2))
            else:
                lines.append(f'{space}- {_yaml_scalar(item)}')
        return lines

    return [space + _yaml_scalar(value)]
|
||
|
||
|
||
def _save_yaml(file_path, data):
    """Serialise *data* as YAML and (over)write it to *file_path* in UTF-8.

    The payload is normalised with ``_to_builtin``, rendered line by line with
    ``_yaml_lines``, and written with a single trailing newline.
    """
    rendered = _yaml_lines(_to_builtin(data))
    # An empty final element makes join() emit the trailing newline.
    rendered.append('')
    with open(file_path, 'w', encoding='utf-8') as handle:
        handle.write('\n'.join(rendered))
|
||
|
||
# Artefacts listed in run_result.yaml for a run with epoch data.
# NOTE(review): the .png names are presumably written by save_to_xlsx next to
# the .xlsx files — confirm against Qfunctions.saveToXlsx.
_GENERATED_FILES = (
    'pca_2d.xlsx',
    'pca_2d.png',
    'pca_3d.xlsx',
    'pca_3d.png',
    'cm.xlsx',
    'cm.png',
    'cmn.xlsx',
    'cmn.png',
    'acc_and_loss.xlsx',
    'acc_and_loss_epoch.png',
    'acc_and_loss_last_epoch_bar.png',
)


def _now_iso():
    """Return the current local time as an ISO-8601 string (second precision)."""
    return datetime.datetime.now().isoformat(timespec='seconds')


def _summarize_epochs(run_id, run_dir, epoch_data):
    """Build the ``run_result.yaml`` payload from the per-epoch metrics table.

    Args:
        run_id: Timestamp identifier of this run.
        run_dir: Directory holding this run's artefacts.
        epoch_data: Table returned by ``model.get_epoch_data()``; presumably a
            pandas DataFrame with a ``test_accuracy`` column and usually an
            ``epoch`` column — TODO confirm against Qmlp.

    Returns:
        A dict with the best-test-accuracy row and the last row, or a short
        message when no epochs were recorded.
    """
    if epoch_data.empty:
        return {
            'run_id': run_id,
            'result_dir': str(run_dir),
            'message': 'No epoch data generated.',
            'created_at': _now_iso(),
        }

    # idxmax() returns the index *label* of the best row; pass it straight to
    # .loc so this also works when the index is not a plain integer range.
    best_row = epoch_data.loc[epoch_data['test_accuracy'].idxmax()].to_dict()
    last_row = epoch_data.iloc[-1].to_dict()
    return {
        'run_id': run_id,
        'result_dir': str(run_dir),
        'best_epoch': int(best_row.get('epoch', 0)),
        'best_metrics': best_row,
        'last_epoch_metrics': last_row,
        'generated_files': list(_GENERATED_FILES),
        'created_at': _now_iso(),
    }


def main():
    """Train a Qmlp classifier on one dataset and archive the run's outputs.

    Loads the data folder ``project_name`` with ``load_data``, takes the 2-D
    and 3-D PCA projections from ``model.get_PCA()``, trains for ``epochs``
    epochs, exports the PCA projections, confusion matrices and per-epoch
    metrics through ``save_to_xlsx``, and records ``run_params.yaml`` and
    ``run_result.yaml`` in ``Result/<project_name>/<run_id>``.
    """
    # Name of the input data folder to load.
    project_name = '20260409 grap'
    # One entry per class, in order; here the classes are the integers 1-9.
    label_names = list(range(1, 10))
    hidden_layers = [256, 128, 128, 128]
    test_size = 0.5
    dropout_rate = 0
    epochs = 300

    # Every run gets its own timestamped sub-directory so runs never collide.
    run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
    # NOTE(review): presumably save_to_xlsx writes under Result/<project_name>,
    # i.e. into run_dir — confirm against Qfunctions.saveToXlsx.
    run_project_name = f'{project_name}/{run_id}'
    run_dir = Path('Result') / project_name / run_id
    run_dir.mkdir(parents=True, exist_ok=True)

    run_params = {
        'run_id': run_id,
        'project_name': project_name,
        'label_names': label_names,
        'model': {
            'type': 'Qmlp',
            'hidden_layers': hidden_layers,
            'dropout_rate': dropout_rate,
        },
        'train': {
            'epochs': epochs,
            'test_size': test_size,
        },
        'created_at': _now_iso(),
    }
    # Record the configuration up front so it survives a failure in training.
    _save_yaml(run_dir / 'run_params.yaml', run_params)

    print(label_names)
    data = load_data(project_name, label_names)

    model = Qmlp(
        data=data,
        labels=label_names,
        hidden_layers=hidden_layers,
        test_size=test_size,
        dropout_rate=dropout_rate,
    )
    # Alternative model. QCNN is not imported in this file; add the import
    # (and update run_params['model']) before enabling it.
    # model = QCNN(
    #     data=data,
    #     labels=label_names,
    #     conv_channels=(16, 32),
    #     kernel_size=3,
    #     hidden_size=128,
    #     test_size=0.3,
    #     dropout_rate=0
    # )

    pca_2d, pca_3d = model.get_PCA()
    model.fit(epochs)

    cm = model.get_cm()
    cmn = model.get_cmn()
    epoch_data = model.get_epoch_data()

    outputs = {
        'pca_2d': pca_2d,
        'pca_3d': pca_3d,
        'cm': cm,
        'cmn': cmn,
        'acc_and_loss': epoch_data,
    }
    for file_name, table in outputs.items():
        save_to_xlsx(project_name=run_project_name, file_name=file_name, data=table)

    run_result = _summarize_epochs(run_id, run_dir, epoch_data)
    _save_yaml(run_dir / 'run_result.yaml', run_result)

    print(f'Done. Result saved in: {run_dir}')
|
||
|
||
# Run the training pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|