diff --git a/Qfunctions/loadData.py b/Qfunctions/loadData.py index acb146a..1a83b9b 100644 --- a/Qfunctions/loadData.py +++ b/Qfunctions/loadData.py @@ -3,11 +3,12 @@ import unicodedata import pandas as pd STATIC_PATH = './Static' +DEFAULT_FILE_CLASSES = ('xlsx', 'xls', 'csv') -# 从文件夹中读取所有xlsx文件,每个文件对应一个label +# 从文件夹中读取所有数据文件,支持 xls/xlsx/csv # labelNames为label的名字,如果不提供则默认为文件名 -def load_data(folder, labelNames, fileClass='xlsx'): +def load_data(folder, labelNames): # 检查folder参数 if folder is None: raise ValueError("The 'folder' parameter is required.") @@ -22,27 +23,28 @@ def load_data(folder, labelNames, fileClass='xlsx'): if not os.path.isdir(folder): raise ValueError(f"The folder '{folder}' does not exist.") + file_classes = DEFAULT_FILE_CLASSES + # 自动检测数据组织方式 - is_dir_mode = _detect_data_mode(folder=folder, labelNames=labelNames, fileClass=fileClass) + is_dir_mode = _detect_data_mode(folder=folder, labelNames=labelNames, fileClasses=file_classes) mode_name = 'multi-folder mode' if is_dir_mode else 'single-file mode' print(f"Auto detected data mode: {mode_name}") if not is_dir_mode: - data = load_from_file(folder=folder, labelNames=labelNames, fileClass=fileClass) + data = load_from_file(folder=folder, labelNames=labelNames, fileClasses=file_classes) else: - data = load_from_folder(folder=folder, labelNames=labelNames, fileClass=fileClass) + data = load_from_folder(folder=folder, labelNames=labelNames, fileClasses=file_classes) print(data) return data -def load_from_folder(folder, labelNames, fileClass): +def load_from_folder(folder, labelNames, fileClasses): all_features = [] - fileClass = '.' 
+ fileClass for labelName in labelNames: - subfolder = os.path.join(folder, labelName) + subfolder = os.path.join(folder, str(labelName)) if os.path.exists(subfolder) and os.path.isdir(subfolder): - fileNames = [f for f in os.listdir(subfolder) if f.endswith(fileClass)] + fileNames = [f for f in os.listdir(subfolder) if _has_supported_extension(f, fileClasses)] max_row_length = get_max_row_len(subfolder, fileNames) features = [] for fileName in fileNames: @@ -55,17 +57,15 @@ def load_from_folder(folder, labelNames, fileClass): return pd.concat(all_features, ignore_index=True) -def load_from_file(folder, labelNames, fileClass): +def load_from_file(folder, labelNames, fileClasses): # 构建期望的文件名(label + .扩展名),并在目录中进行健壮匹配 # (去除零宽字符、Unicode 规范化、大小写不敏感) - expected_names = [f"{labelName}.{fileClass}" for labelName in labelNames] - actual_file_names = [] missing = [] - for expected in expected_names: - match = _find_matching_file(folder, expected) + for labelName in labelNames: + match = _find_matching_file_by_label(folder, labelName, fileClasses) if match is None: - missing.append(expected) + missing.append(f"{labelName}.<{'/'.join(fileClasses)}>") else: actual_file_names.append(match) @@ -89,7 +89,7 @@ def load_from_file(folder, labelNames, fileClass): def load_xlsx(fileName, labelName, max_row_length=1000, fill_rule=None): - df = pd.read_excel(fileName) + df = _read_data_file(fileName) # 提取偶数列 features = df.iloc[0:, 1::2] @@ -128,11 +128,31 @@ def fill_to_len(row, length=1000, rule=None): def get_max_row_len(folder, filenames): max_len = 0 for filename in filenames: - df = pd.read_excel(os.path.join(folder, filename)) + df = _read_data_file(os.path.join(folder, filename)) max_len = max(max_len, df.shape[0]) return max_len +def _read_data_file(file_path: str): + ext = os.path.splitext(file_path)[1].lower() + if ext == '.csv': + return pd.read_csv(file_path) + if ext in ('.xls', '.xlsx'): + return pd.read_excel(file_path) + raise ValueError( + f"Unsupported file format: 
{ext}. Only .xls, .xlsx, and .csv are supported. " + f"File: {file_path}" + ) + + + + + +def _has_supported_extension(filename: str, fileClasses) -> bool: + ext = os.path.splitext(filename)[1].lower().lstrip('.') + return ext in fileClasses + + # ---------- 内部工具函数:处理包含零宽字符或不同 Unicode 形式的文件名匹配 ---------- def _strip_zero_width(s: str) -> str: @@ -189,28 +209,34 @@ def _find_matching_file(folder: str, expected_name: str): return None -def _detect_data_mode(folder: str, labelNames, fileClass: str) -> bool: +def _find_matching_file_by_label(folder: str, label_name, fileClasses): + for ext in fileClasses: + expected_name = f"{label_name}.{ext}" + match = _find_matching_file(folder, expected_name) + if match is not None: + return match + return None + + +def _detect_data_mode(folder: str, labelNames, fileClasses) -> bool: """Auto detect data organization mode. Returns: True: multi-folder mode (folder/label/*.ext) False: single-file mode (folder/label.ext) """ - ext = f'.{fileClass}' - # 判断是否满足多文件夹模式:每个 label 对应一个子目录,且至少有一个目标后缀文件 has_all_label_subfolders = True for label in labelNames: subfolder = os.path.join(folder, str(label)) - if not (os.path.isdir(subfolder) and any(f.endswith(ext) for f in os.listdir(subfolder))): + if not (os.path.isdir(subfolder) and any(_has_supported_extension(f, fileClasses) for f in os.listdir(subfolder))): has_all_label_subfolders = False break # 判断是否满足单文件模式:每个 label 能匹配到对应文件 has_all_label_files = True for label in labelNames: - expected_name = f"{label}.{fileClass}" - if _find_matching_file(folder, expected_name) is None: + if _find_matching_file_by_label(folder, label, fileClasses) is None: has_all_label_files = False break diff --git a/README.md b/README.md index 2372700..dfec60c 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ ## 1. 
项目约定 ### 1.1 输入数据格式 -每一类数据建议保存为 `xlsx/xls`。读取时默认取偶数列(索引 1,3,5...)作为特征,奇数列内容可忽略。 +每一类数据支持 `xls/xlsx/csv`。读取时默认取偶数列(索引 1,3,5...)作为特征,奇数列内容可忽略。 示意: @@ -148,10 +148,11 @@ from Qfunctions.saveToXlsx import save_to_xlsx projet_name = '20241009MaterialDiv' label_names = ['Acrlic', 'Ecoflex', 'PDMS', 'PLA', 'Wood'] -# 自动识别数据模式 -# - folder/label.xlsx => 单文件模式 -# - folder/label/*.xlsx => 多子特征模式 -data = load_data(projet_name, label_names, fileClass='xlsx') +# 自动识别数据模式 +# 支持 .xls、.xlsx、.csv 三种格式(可混合使用) +# - folder/label.xlsx 或 folder/label.xls 或 folder/label.csv => 单文件模式 +# - folder/label/*.(xlsx|xls|csv) => 多子特征模式 +data = load_data(projet_name, label_names) # 划分训练/测试集 X_train, X_test, y_train, y_test, encoder = divSet( @@ -191,19 +192,17 @@ save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data |---|---|---|---| | folder | str | 必填 | `Static/` 下的数据目录名 | | labelNames | list | 必填 | 类别名称列表,用于读取和排序标签 | -| fileClass | str | xlsx | 数据文件后缀 | 自动识别规则: -- 若每个 `label` 都对应 `folder/label/*.xlsx`,识别为多子特征模式。 -- 若每个 `label` 都对应 `folder/label.xlsx`,识别为单文件模式。 -- 若两种都成立(同名文件和同名子目录同时存在),会报错并提示只保留一种目录结构。 +- 若每个 `label` 都对应 `folder/label/*.(xlsx|xls|csv)`,识别为多子特征模式。 +- 若每个 `label` 都对应 `folder/label.(xlsx|xls|csv)`,识别为单文件模式。 - 不支持的文件格式(仅支持 xls/xlsx/csv)会在读取时报错。 - 若两种都成立(同名文件和同名子目录同时存在),会报错并提示只保留一种目录结构。 - 若两种都不成立,会报错并提示检查目录结构或 `label_names`。 读取路径规则: -- 单文件模式:`./Static/folder/labelNames[i].xlsx` -- 多子特征模式:`./Static/folder/labelNames[i]/*.xlsx` +- 单文件模式:`./Static/folder/labelNames[i].(xlsx|xls|csv)` +- 多子特征模式:`./Static/folder/labelNames[i]/*.(xlsx|xls|csv)` ## 5. 
常见问题 @@ -211,4 +210,4 @@ save_to_xlsx(project_name=projet_name, file_name='acc_and_loss', data=epoch_data 优先检查: - `label_names` 与文件/文件夹是否同名 -- 文件后缀是否与 `fileClass` 一致 +- 文件后缀是否为 `.xls`、`.xlsx` 或 `.csv`(其他格式将报错) diff --git a/main.py b/main.py index 8386be8..007ca4b 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ def main(): # label_names 是一个列表里面按顺序包含了小写的‘a'到‘z’ label_names = list(range(10)) print(label_names) - data = load_data(projet_name, label_names, fileClass='xlsx') + data = load_data(projet_name, label_names) model = Qmlp( data=data,