Update main.py to change project name and simplify label_names; adjust hidden_layers in Qmlp model
This commit is contained in:
parent
5da38b49b0
commit
e2dc549be9
|
|
@ -83,9 +83,9 @@ def load_xlsx(fileName, labelName, max_row_length = 1000, fill_rule = None):
|
|||
|
||||
# 提取偶数列
|
||||
features = df.iloc[0:, 1::2]
|
||||
# 复制 features DataFrame
|
||||
# ## 复制 features DataFrame
|
||||
# features_copy = features.copy()
|
||||
# 使用 pd.concat 来追加副本到原始 DataFrame
|
||||
# ## 使用 pd.concat 来追加副本到原始 DataFrame
|
||||
# features = pd.concat([features, features_copy], ignore_index=True, axis=1)
|
||||
|
||||
# 计算变化率
|
||||
|
|
|
|||
10
main.py
10
main.py
|
|
@ -5,12 +5,10 @@ from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
|
|||
|
||||
def main():
|
||||
# 输入元数据文件夹名称
|
||||
projet_name = '20260109WZSX'
|
||||
projet_name = '20260318Letters'
|
||||
# 请在[]内输入每一个分类的名称
|
||||
label_names = [
|
||||
'Crocodile grain', 'Litchi grain','Pin grain',
|
||||
'Mohair tweed', 'Polar fleece', 'Berber fleece'
|
||||
]
|
||||
# label_names 是一个列表里面按顺序包含了小写的‘a'到‘z’
|
||||
label_names = ['a','b']
|
||||
print(label_names)
|
||||
data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx')
|
||||
X_train, X_test, y_train, y_test, encoder = divSet(
|
||||
|
|
@ -19,7 +17,7 @@ def main():
|
|||
|
||||
model = Qmlp(
|
||||
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
||||
hidden_layers = [256, 256, 256],
|
||||
hidden_layers = [64],
|
||||
dropout_rate=0
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue