Deeplearning/main.py

49 lines
1.5 KiB
Python

from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.divSet import divSet
from Qfunctions.loaData import load_data
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
def main(*, project_name='20260109WZSX', label_names=None, epochs=300,
         test_size=0.3):
    """Train a Qmlp classifier on one project's data and export the results.

    Loads the project's xlsx data with ``load_data`` and splits it with
    ``divSet``. It then builds and trains a ``Qmlp`` model and writes the
    model's outputs to xlsx files via ``save_to_xlsx``:

    - 2-D and 3-D PCA,
    - ``cm`` and ``cmn`` (presumably the raw and normalized confusion
      matrices — confirm against Qmlp),
    - per-epoch data, saved as ``acc_and_loss``.

    All arguments are keyword-only. Their defaults are the values this
    script was originally hard-coded with, so a bare ``main()`` behaves
    exactly as before.

    Args:
        project_name: Name of the folder holding the input data. It is also
            passed to ``save_to_xlsx`` as the output project.
        label_names: Name of each class. Defaults to the six classes listed
            below.
        epochs: Value passed to ``model.fit`` (presumably the number of
            training epochs).
        test_size: Value passed to ``divSet`` as ``test_size`` (presumably
            the fraction of samples held out for testing).
    """
    if label_names is None:
        # Enter the name of every class here.
        label_names = [
            'Crocodile grain', 'Litchi grain', 'Pin grain',
            'Mohair tweed', 'Polar fleece', 'Berber fleece',
        ]
    else:
        # Normalize to a list, the same type the default uses.
        label_names = list(label_names)
    print(label_names)

    data = load_data(project_name, label_names, isDir=False, fileClass='xlsx')
    # The encoder returned by divSet is not needed by this script.
    X_train, X_test, y_train, y_test, _encoder = divSet(
        data=data, labels=label_names, test_size=test_size
    )

    model = Qmlp(
        X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
        hidden_layers=[256, 256, 256],
        dropout_rate=0,
    )
    # Alternative: a QCNN model can be swapped in with the same
    # X_train/X_test/y_train/y_test and dropout_rate arguments; QCNN is not
    # imported in this file, so add its import first.

    # PCA is requested before training, as in the original workflow.
    pca_2d, pca_3d = model.get_PCA()
    model.fit(epochs)

    # Map each output file name to its data. All getters run before any file
    # is written, and dict order fixes the order in which files are saved.
    results = {
        'pca_2d': pca_2d,
        'pca_3d': pca_3d,
        'cm': model.get_cm(),
        'cmn': model.get_cmn(),
        'acc_and_loss': model.get_epoch_data(),
    }
    for file_name, payload in results.items():
        save_to_xlsx(project_name=project_name, file_name=file_name, data=payload)
    print("Done")
if __name__ == "__main__":
    # Run the training pipeline only when executed as a script, not on import.
    main()