diff --git a/Qfunctions/loadData.py b/Qfunctions/loadData.py index cad7294..acb146a 100644 --- a/Qfunctions/loadData.py +++ b/Qfunctions/loadData.py @@ -7,7 +7,7 @@ STATIC_PATH = './Static' # 从文件夹中读取所有xlsx文件,每个文件对应一个label # labelNames为label的名字,如果不提供则默认为文件名 -def load_data(folder, labelNames, isDir=True, fileClass='xlsx'): +def load_data(folder, labelNames, fileClass='xlsx'): # 检查folder参数 if folder is None: raise ValueError("The 'folder' parameter is required.") @@ -22,7 +22,12 @@ def load_data(folder, labelNames, isDir=True, fileClass='xlsx'): if not os.path.isdir(folder): raise ValueError(f"The folder '{folder}' does not exist.") - if not isDir: + # 自动检测数据组织方式 + is_dir_mode = _detect_data_mode(folder=folder, labelNames=labelNames, fileClass=fileClass) + mode_name = 'multi-folder mode' if is_dir_mode else 'single-file mode' + print(f"Auto detected data mode: {mode_name}") + + if not is_dir_mode: data = load_from_file(folder=folder, labelNames=labelNames, fileClass=fileClass) else: data = load_from_folder(folder=folder, labelNames=labelNames, fileClass=fileClass) @@ -184,4 +189,46 @@ def _find_matching_file(folder: str, expected_name: str): return None +def _detect_data_mode(folder: str, labelNames, fileClass: str) -> bool: + """Auto detect data organization mode. + + Returns: + True: multi-folder mode (folder/label/*.ext) + False: single-file mode (folder/label.ext) + """ + ext = f'.{fileClass}' + + # 判断是否满足多文件夹模式:每个 label 对应一个子目录,且至少有一个目标后缀文件 + has_all_label_subfolders = True + for label in labelNames: + subfolder = os.path.join(folder, str(label)) + if not (os.path.isdir(subfolder) and any(f.endswith(ext) for f in os.listdir(subfolder))): + has_all_label_subfolders = False + break + + # 判断是否满足单文件模式:每个 label 能匹配到对应文件 + has_all_label_files = True + for label in labelNames: + expected_name = f"{label}.{fileClass}" + if _find_matching_file(folder, expected_name) is None: + has_all_label_files = False + break + + if has_all_label_subfolders and not has_all_label_files: + return True + if has_all_label_files and not has_all_label_subfolders: + return False + + if has_all_label_subfolders and has_all_label_files: + raise ValueError( + "Auto detect found both valid layouts under the same folder. " + "Please keep only one layout type (either subfolders or root files) for each label." + ) + + raise ValueError( + "Auto detect failed: neither single-file nor multi-folder layout matches all labels. " + "Please verify folder structure and labelNames." + ) + + __all__ = ['load_data'] diff --git a/README.md b/README.md index 2cbd55a..2372700 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,6 @@ | 任意值 | 特征值 | 任意值 | 特征值 | |---|---|---|---| | arbitrary value | value | arbitrary value | value | -| arbitrary value | value | arbitrary value | value | ### 1.2 目录约定 训练数据放在 `Static/`,输出结果放在 `Result/`。 @@ -83,7 +82,7 @@ conda list -n Deeplearning --explicit > conda_env/environment.lock.txt ### 3.2 数据目录模板 -单文件模式(`isDir=False`): +单文件模式(每个标签一个文件): ```text Static/ @@ -95,7 +94,7 @@ Static/ Wood.xlsx ``` -多子特征模式(`isDir=True`): +多子特征模式(每个标签一个子目录,目录下可有多个文件): ```text Static/ @@ -119,7 +118,7 @@ Static/ 命名规则(重要): -- `label_names` 中每一项必须与文件名(`isDir=False`)或子文件夹名(`isDir=True`)完全一致(区分大小写)。 +- `label_names` 中每一项必须与文件名(单文件模式)或子文件夹名(多子特征模式)一致。 - `label_names` 顺序就是标签编码顺序,训练结果和混淆矩阵按该顺序展示。 示例: @@ -149,8 +148,10 @@ from Qfunctions.saveToXlsx import save_to_xlsx projet_name = '20241009MaterialDiv' label_names = ['Acrlic', 'Ecoflex', 'PDMS', 'PLA', 'Wood'] -# 读取数据 -data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx') +# 自动识别数据模式 +# - folder/label.xlsx => 单文件模式 +# - folder/label/*.xlsx => 多子特征模式 +data = load_data(projet_name, label_names, fileClass='xlsx') # 划分训练/测试集 X_train, X_test, y_train, y_test, encoder = divSet( @@ -190,9 +191,15 @@ save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data |---|---|---|---| | folder | str | 必填 | `Static/` 下的数据目录名 | | labelNames | list | 必填 | 类别名称列表,用于读取和排序标签 | -| isDir | bool | True | `False` 对应单文件模式,`True` 对应多子特征模式 | | fileClass | str | xlsx | 数据文件后缀 | +自动识别规则: + +- 若每个 `label` 都对应 `folder/label/*.xlsx`,识别为多子特征模式。 +- 若每个 `label` 都对应 `folder/label.xlsx`,识别为单文件模式。 +- 若两种都成立(同名文件和同名子目录同时存在),会报错并提示只保留一种目录结构。 +- 若两种都不成立,会报错并提示检查目录结构或 `label_names`。 + 读取路径规则: - 单文件模式:`./Static/folder/labelNames[i].xlsx` @@ -204,5 +211,4 @@ save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data 优先检查: - `label_names` 与文件/文件夹是否同名 -- `isDir` 是否与目录结构匹配 - 文件后缀是否与 `fileClass` 一致 diff --git a/main.py b/main.py index 5df075e..8386be8 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ def main(): # label_names 是一个列表里面按顺序包含了小写的‘a'到‘z’ label_names = list(range(10)) print(label_names) - data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx') + data = load_data(projet_name, label_names, fileClass='xlsx') model = Qmlp( data=data,