Deeplearning/main.py

46 lines
1.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.divSet import divSet
from Qfunctions.loaData import load_data
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
def main(
    project_name='20260319Numbers',
    label_names=None,
    *,
    test_size=0.3,
    hidden_layers=None,
    dropout_rate=0,
    epochs=300,
):
    """Train a ``Qmlp`` classifier on a project's data and export the results.

    Loads the project's data, splits it with ``divSet``, builds a ``Qmlp``
    model, obtains its 2-D/3-D PCA output, trains it, and writes the PCA
    output, the ``get_cm``/``get_cmn`` results and the per-epoch data
    (exported as "acc_and_loss") to xlsx via ``save_to_xlsx``.

    Called with no arguments it performs the same run as before: labels
    0-9, ``test_size=0.3``, hidden layers ``[128, 256, 128]``, dropout 0
    and 300 training epochs.

    Args:
        project_name: Name of the source-data folder passed to
            ``load_data``; also the project the xlsx outputs are saved under.
        label_names: Class names, in order. Defaults to the ten integers
            0-9 (``list(range(10))``).
        test_size: Forwarded to ``divSet`` (presumably the held-out test
            fraction; confirm against ``divSet``).
        hidden_layers: Hidden-layer widths forwarded to ``Qmlp``. Defaults
            to ``[128, 256, 128]``.
        dropout_rate: Forwarded to ``Qmlp``.
        epochs: Forwarded to ``model.fit``.
    """
    # None defaults avoid sharing one mutable list across calls.
    if label_names is None:
        label_names = list(range(10))
    if hidden_layers is None:
        hidden_layers = [128, 256, 128]
    print(label_names)

    data = load_data(project_name, label_names, isDir=False, fileClass='xlsx')
    # The fifth value returned by divSet (presumably the fitted label
    # encoder) is not used by this script.
    X_train, X_test, y_train, y_test, _encoder = divSet(
        data=data, labels=label_names, test_size=test_size
    )
    model = Qmlp(
        X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
        hidden_layers=hidden_layers,
        dropout_rate=dropout_rate,
    )

    # get_PCA() is deliberately called before fit(); keep this order unless
    # Qmlp's PCA is confirmed not to depend on training state.
    pca_2d, pca_3d = model.get_PCA()
    model.fit(epochs)

    # Keys are the output file names. The dict preserves insertion order, so
    # the getters run, and the files are written, in the same order as before.
    results = {
        "pca_2d": pca_2d,
        "pca_3d": pca_3d,
        "cm": model.get_cm(),
        "cmn": model.get_cmn(),
        "acc_and_loss": model.get_epoch_data(),
    }
    for file_name, table in results.items():
        save_to_xlsx(project_name=project_name, file_name=file_name, data=table)
    print("Done")
if __name__ == "__main__":
    # Only run the training pipeline when executed as a script, not on import.
    main()