diff --git a/README.md b/README.md index 175f1f8..f44a087 100644 --- a/README.md +++ b/README.md @@ -252,6 +252,19 @@ save_to_xlsx(project_name=projet_name, file_name='cmn', data=cmn) save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data) ``` +### 3.6 每次运行参数与结果 YAML 存档 + +当前 `main.py` 会为每次运行创建时间戳目录: + +- `Result///` + +并在目录下自动保存: + +- `run_params.yaml`:本次运行参数快照。 +- `run_result.yaml`:本次运行结果摘要(最佳轮次与最后轮次指标)。 + +这样可以直接对比不同运行的参数与结果变化。 + ## 4. load_data 参数说明 | 参数 | 类型 | 默认值 | 说明 | diff --git a/main.py b/main.py index 36eecba..3dd0c0e 100644 --- a/main.py +++ b/main.py @@ -1,22 +1,97 @@ +import datetime +from pathlib import Path + from Qtorch.Models.Qmlp import Qmlp from Qfunctions.loadData import load_data from Qfunctions.saveToXlsx import save_to_xlsx as save_to_xlsx + +def _to_builtin(value): + if isinstance(value, dict): + return {str(k): _to_builtin(v) for k, v in value.items()} + if isinstance(value, (list, tuple)): + return [_to_builtin(v) for v in value] + if hasattr(value, 'item'): + try: + return value.item() + except Exception: + return str(value) + return value + + +def _yaml_scalar(value): + if value is None: + return 'null' + if isinstance(value, bool): + return 'true' if value else 'false' + if isinstance(value, (int, float)): + return str(value) + text = str(value).replace('"', '\\"') + return f'"{text}"' + + +def _yaml_lines(value, indent=0): + space = ' ' * indent + + if isinstance(value, dict): + if not value: + return [space + '{}'] + lines = [] + for k, v in value.items(): + key = str(k) + if isinstance(v, (dict, list, tuple)): + lines.append(f'{space}{key}:') + lines.extend(_yaml_lines(v, indent + 2)) + else: + lines.append(f'{space}{key}: {_yaml_scalar(v)}') + return lines + + if isinstance(value, (list, tuple)): + if not value: + return [space + '[]'] + lines = [] + for item in value: + if isinstance(item, (dict, list, tuple)): + lines.append(f'{space}-') + lines.extend(_yaml_lines(item, indent + 2)) + else: + lines.append(f'{space}- {_yaml_scalar(item)}') + return lines + + return [space + _yaml_scalar(value)] + + +def _save_yaml(file_path, data): + built_data = _to_builtin(data) + text = '\n'.join(_yaml_lines(built_data)) + '\n' + with open(file_path, 'w', encoding='utf-8') as f: + f.write(text) + def main(): # 输入元数据文件夹名称 - projet_name = '20260319Numbers' + projet_name = '20260409 grap' # 请在[]内输入每一个分类的名称 # label_names 是一个列表里面按顺序包含了小写的‘a'到‘z’ - label_names = list(range(10)) + label_names = list(range(1, 10)) + hidden_layers = [256, 128, 128, 128] + test_size = 0.5 + dropout_rate = 0 + epochs = 300 + + run_id = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + run_project_name = f'{projet_name}/{run_id}' + run_dir = Path('Result') / projet_name / run_id + run_dir.mkdir(parents=True, exist_ok=True) + print(label_names) data = load_data(projet_name, label_names) model = Qmlp( data=data, labels=label_names, - hidden_layers = [128, 256, 128], - test_size=0.3, - dropout_rate=0 + hidden_layers=hidden_layers, + test_size=test_size, + dropout_rate=dropout_rate ) # model = QCNN( # data=data, @@ -30,19 +105,71 @@ def main(): pca_2d, pca_3d = model.get_PCA() - model.fit(300) + model.fit(epochs) cm = model.get_cm() cmn = model.get_cmn() epoch_data = model.get_epoch_data() - save_to_xlsx(project_name=projet_name, file_name="pca_2d", data=pca_2d) - save_to_xlsx(project_name=projet_name, file_name="pca_3d", data=pca_3d) - save_to_xlsx(project_name=projet_name, file_name="cm", data=cm) - save_to_xlsx(project_name=projet_name, file_name="cmn", data=cmn) - save_to_xlsx(project_name=projet_name, file_name="acc_and_loss", data=epoch_data) + save_to_xlsx(project_name=run_project_name, file_name='pca_2d', data=pca_2d) + save_to_xlsx(project_name=run_project_name, file_name='pca_3d', data=pca_3d) + save_to_xlsx(project_name=run_project_name, file_name='cm', data=cm) + save_to_xlsx(project_name=run_project_name, file_name='cmn', data=cmn) + save_to_xlsx(project_name=run_project_name, file_name='acc_and_loss', data=epoch_data) - print("Done") + run_params = { + 'run_id': run_id, + 'project_name': projet_name, + 'label_names': label_names, + 'model': { + 'type': 'Qmlp', + 'hidden_layers': hidden_layers, + 'dropout_rate': dropout_rate, + }, + 'train': { + 'epochs': epochs, + 'test_size': test_size, + }, + 'created_at': datetime.datetime.now().isoformat(timespec='seconds'), + } + + if not epoch_data.empty: + best_idx = int(epoch_data['test_accuracy'].idxmax()) + best_row = epoch_data.loc[best_idx].to_dict() + last_row = epoch_data.iloc[-1].to_dict() + run_result = { + 'run_id': run_id, + 'result_dir': str(run_dir), + 'best_epoch': int(best_row.get('epoch', 0)), + 'best_metrics': best_row, + 'last_epoch_metrics': last_row, + 'generated_files': [ + 'pca_2d.xlsx', + 'pca_2d.png', + 'pca_3d.xlsx', + 'pca_3d.png', + 'cm.xlsx', + 'cm.png', + 'cmn.xlsx', + 'cmn.png', + 'acc_and_loss.xlsx', + 'acc_and_loss_epoch.png', + 'acc_and_loss_last_epoch_bar.png', + ], + 'created_at': datetime.datetime.now().isoformat(timespec='seconds'), + } + else: + run_result = { + 'run_id': run_id, + 'result_dir': str(run_dir), + 'message': 'No epoch data generated.', + 'created_at': datetime.datetime.now().isoformat(timespec='seconds'), + } + + _save_yaml(run_dir / 'run_params.yaml', run_params) + _save_yaml(run_dir / 'run_result.yaml', run_result) + + print(f'Done. Result saved in: {run_dir}') if __name__ == '__main__': main()