feat: add YAML archiving for run parameters and results with timestamped directories
This commit is contained in:
parent
9f241757c6
commit
5f58d7fb56
13
README.md
13
README.md
|
|
@ -252,6 +252,19 @@ save_to_xlsx(project_name=projet_name, file_name='cmn', data=cmn)
|
|||
save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data)
|
||||
```
|
||||
|
||||
### 3.6 每次运行参数与结果 YAML 存档
|
||||
|
||||
当前 `main.py` 会为每次运行创建时间戳目录:
|
||||
|
||||
- `Result/<project_name>/<YYYYMMDD_HHMMSS>/`
|
||||
|
||||
并在目录下自动保存:
|
||||
|
||||
- `run_params.yaml`:本次运行参数快照。
|
||||
- `run_result.yaml`:本次运行结果摘要(最佳轮次与最后轮次指标)。
|
||||
|
||||
这样可以直接对比不同运行的参数与结果变化。
|
||||
|
||||
## 4. load_data 参数说明
|
||||
|
||||
| 参数 | 类型 | 默认值 | 说明 |
|
||||
|
|
|
|||
151
main.py
151
main.py
|
|
@ -1,22 +1,97 @@
|
|||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from Qtorch.Models.Qmlp import Qmlp
|
||||
from Qfunctions.loadData import load_data
|
||||
from Qfunctions.saveToXlsx import save_to_xlsx as save_to_xlsx
|
||||
|
||||
|
||||
def _to_builtin(value):
|
||||
if isinstance(value, dict):
|
||||
return {str(k): _to_builtin(v) for k, v in value.items()}
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [_to_builtin(v) for v in value]
|
||||
if hasattr(value, 'item'):
|
||||
try:
|
||||
return value.item()
|
||||
except Exception:
|
||||
return str(value)
|
||||
return value
|
||||
|
||||
|
||||
def _yaml_scalar(value):
|
||||
if value is None:
|
||||
return 'null'
|
||||
if isinstance(value, bool):
|
||||
return 'true' if value else 'false'
|
||||
if isinstance(value, (int, float)):
|
||||
return str(value)
|
||||
text = str(value).replace('"', '\\"')
|
||||
return f'"{text}"'
|
||||
|
||||
|
||||
def _yaml_lines(value, indent=0):
|
||||
space = ' ' * indent
|
||||
|
||||
if isinstance(value, dict):
|
||||
if not value:
|
||||
return [space + '{}']
|
||||
lines = []
|
||||
for k, v in value.items():
|
||||
key = str(k)
|
||||
if isinstance(v, (dict, list, tuple)):
|
||||
lines.append(f'{space}{key}:')
|
||||
lines.extend(_yaml_lines(v, indent + 2))
|
||||
else:
|
||||
lines.append(f'{space}{key}: {_yaml_scalar(v)}')
|
||||
return lines
|
||||
|
||||
if isinstance(value, (list, tuple)):
|
||||
if not value:
|
||||
return [space + '[]']
|
||||
lines = []
|
||||
for item in value:
|
||||
if isinstance(item, (dict, list, tuple)):
|
||||
lines.append(f'{space}-')
|
||||
lines.extend(_yaml_lines(item, indent + 2))
|
||||
else:
|
||||
lines.append(f'{space}- {_yaml_scalar(item)}')
|
||||
return lines
|
||||
|
||||
return [space + _yaml_scalar(value)]
|
||||
|
||||
|
||||
def _save_yaml(file_path, data):
|
||||
built_data = _to_builtin(data)
|
||||
text = '\n'.join(_yaml_lines(built_data)) + '\n'
|
||||
with open(file_path, 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
|
||||
def main():
|
||||
# 输入元数据文件夹名称
|
||||
projet_name = '20260319Numbers'
|
||||
projet_name = '20260409 grap'
|
||||
# 请在[]内输入每一个分类的名称
|
||||
# label_names 是一个列表里面按顺序包含了小写的‘a'到‘z’
|
||||
label_names = list(range(10))
|
||||
label_names = list(range(1, 10))
|
||||
hidden_layers = [256, 128, 128, 128]
|
||||
test_size = 0.5
|
||||
dropout_rate = 0
|
||||
epochs = 300
|
||||
|
||||
run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
|
||||
run_project_name = f'{projet_name}/{run_id}'
|
||||
run_dir = Path('Result') / projet_name / run_id
|
||||
run_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
print(label_names)
|
||||
data = load_data(projet_name, label_names)
|
||||
|
||||
model = Qmlp(
|
||||
data=data,
|
||||
labels=label_names,
|
||||
hidden_layers = [128, 256, 128],
|
||||
test_size=0.3,
|
||||
dropout_rate=0
|
||||
hidden_layers=hidden_layers,
|
||||
test_size=test_size,
|
||||
dropout_rate=dropout_rate
|
||||
)
|
||||
# model = QCNN(
|
||||
# data=data,
|
||||
|
|
@ -30,19 +105,71 @@ def main():
|
|||
|
||||
pca_2d, pca_3d = model.get_PCA()
|
||||
|
||||
model.fit(300)
|
||||
model.fit(epochs)
|
||||
|
||||
cm = model.get_cm()
|
||||
cmn = model.get_cmn()
|
||||
epoch_data = model.get_epoch_data()
|
||||
|
||||
save_to_xlsx(project_name=projet_name, file_name="pca_2d", data=pca_2d)
|
||||
save_to_xlsx(project_name=projet_name, file_name="pca_3d", data=pca_3d)
|
||||
save_to_xlsx(project_name=projet_name, file_name="cm", data=cm)
|
||||
save_to_xlsx(project_name=projet_name, file_name="cmn", data=cmn)
|
||||
save_to_xlsx(project_name=projet_name, file_name="acc_and_loss", data=epoch_data)
|
||||
save_to_xlsx(project_name=run_project_name, file_name='pca_2d', data=pca_2d)
|
||||
save_to_xlsx(project_name=run_project_name, file_name='pca_3d', data=pca_3d)
|
||||
save_to_xlsx(project_name=run_project_name, file_name='cm', data=cm)
|
||||
save_to_xlsx(project_name=run_project_name, file_name='cmn', data=cmn)
|
||||
save_to_xlsx(project_name=run_project_name, file_name='acc_and_loss', data=epoch_data)
|
||||
|
||||
print("Done")
|
||||
run_params = {
|
||||
'run_id': run_id,
|
||||
'project_name': projet_name,
|
||||
'label_names': label_names,
|
||||
'model': {
|
||||
'type': 'Qmlp',
|
||||
'hidden_layers': hidden_layers,
|
||||
'dropout_rate': dropout_rate,
|
||||
},
|
||||
'train': {
|
||||
'epochs': epochs,
|
||||
'test_size': test_size,
|
||||
},
|
||||
'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
|
||||
}
|
||||
|
||||
if not epoch_data.empty:
|
||||
best_idx = int(epoch_data['test_accuracy'].idxmax())
|
||||
best_row = epoch_data.loc[best_idx].to_dict()
|
||||
last_row = epoch_data.iloc[-1].to_dict()
|
||||
run_result = {
|
||||
'run_id': run_id,
|
||||
'result_dir': str(run_dir),
|
||||
'best_epoch': int(best_row.get('epoch', 0)),
|
||||
'best_metrics': best_row,
|
||||
'last_epoch_metrics': last_row,
|
||||
'generated_files': [
|
||||
'pca_2d.xlsx',
|
||||
'pca_2d.png',
|
||||
'pca_3d.xlsx',
|
||||
'pca_3d.png',
|
||||
'cm.xlsx',
|
||||
'cm.png',
|
||||
'cmn.xlsx',
|
||||
'cmn.png',
|
||||
'acc_and_loss.xlsx',
|
||||
'acc_and_loss_epoch.png',
|
||||
'acc_and_loss_last_epoch_bar.png',
|
||||
],
|
||||
'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
|
||||
}
|
||||
else:
|
||||
run_result = {
|
||||
'run_id': run_id,
|
||||
'result_dir': str(run_dir),
|
||||
'message': 'No epoch data generated.',
|
||||
'created_at': datetime.datetime.now().isoformat(timespec='seconds'),
|
||||
}
|
||||
|
||||
_save_yaml(run_dir / 'run_params.yaml', run_params)
|
||||
_save_yaml(run_dir / 'run_result.yaml', run_result)
|
||||
|
||||
print(f'Done. Result saved in: {run_dir}')
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
|
|||
Loading…
Reference in New Issue