47 lines
1.5 KiB
Python
47 lines
1.5 KiB
Python
from Qtorch.Models.Qmlp import Qmlp
|
||
from Qfunctions.divSet import divSet
|
||
from Qfunctions.loaData import load_data
|
||
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
|
||
|
||
def main(project_name='20260318Letters', label_names=None, test_size=0.3,
         hidden_layers=None, dropout_rate=0, epochs=300):
    """Train a Qmlp classifier on a project's data and export results to xlsx.

    Every parameter defaults to the value that used to be hard-coded, so a
    bare ``main()`` call runs exactly the same pipeline as before.

    Args:
        project_name: Name of the input data folder. It is forwarded to
            ``load_data`` and, as ``project_name``, to every ``save_to_xlsx``
            call.
        label_names: Names of the classes, in order. Defaults to
            ``['a', 'b']``. To use every lowercase letter, pass e.g.
            ``list('abcdefghijklmnopqrstuvwxyz')``.
        test_size: Value forwarded to ``divSet`` as ``test_size``.
        hidden_layers: Value forwarded to ``Qmlp`` as ``hidden_layers``.
            Defaults to ``[64]``.
        dropout_rate: Value forwarded to ``Qmlp`` as ``dropout_rate``.
        epochs: Argument passed to ``model.fit``. This is presumably the
            number of training epochs (TODO confirm against Qmlp).
    """
    # Resolve list defaults here rather than in the signature so that no
    # mutable default object is shared between calls.
    if label_names is None:
        label_names = ['a', 'b']
    if hidden_layers is None:
        hidden_layers = [64]

    print(label_names)

    # Load the project's data and split it into train/test sets.
    data = load_data(project_name, label_names, isDir=False, fileClass='xlsx')
    # divSet's fifth return value (the encoder) is not needed here.
    X_train, X_test, y_train, y_test, _encoder = divSet(
        data=data, labels=label_names, test_size=test_size
    )

    model = Qmlp(
        X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
        hidden_layers=hidden_layers,
        dropout_rate=dropout_rate,
    )

    # PCA projections are requested before training; this ordering is kept
    # from the original script.
    pca_2d, pca_3d = model.get_PCA()

    model.fit(epochs)

    # Gather all results first, then write one xlsx file per entry.
    # Dict insertion order keeps the original write order.
    results = {
        "pca_2d": pca_2d,
        "pca_3d": pca_3d,
        "cm": model.get_cm(),
        "cmn": model.get_cmn(),
        "acc_and_loss": model.get_epoch_data(),
    }
    for file_name, table in results.items():
        save_to_xlsx(project_name=project_name, file_name=file_name, data=table)

    print("Done")
|
||
|
||
if __name__ == "__main__":
    # Entry point: run the pipeline only when executed directly, not on import.
    main()
|