# Deeplearning/main.py
#
# Entry script: loads the material dataset, exports 2-D/3-D PCA projections of
# the training split, trains a Qmlp classifier, and saves its results to xlsx.
# from Qtorch.Functions import dLoader  # legacy loader; dLoader now comes from Qfunctions.loaData below
from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.divSet import divSet
from Qfunctions.loaData import load_data as dLoader
from sklearn.decomposition import PCA
def _export_pca_projections(X_train, y_train, project_name, result_root="./Result"):
    """Save 2-D and 3-D PCA projections of ``X_train`` as labelled xlsx files.

    Writes ``pca_2d_points_with_labels.xlsx`` and
    ``pca_3d_points_with_labels.xlsx`` into ``<result_root>/<project_name>``,
    creating the folder if needed. Each sheet holds columns ``PC1``..``PCk``
    followed by a ``labels`` column containing ``y_train``.

    Each projection gets its own PCA fit rather than slicing the 3-D result,
    so the 2-D sheet is computed exactly as a standalone 2-component PCA.
    """
    import os

    import pandas as pd

    folder = os.path.join(result_root, project_name)
    # exist_ok=True replaces the racy "if not exists: makedirs" pattern.
    os.makedirs(folder, exist_ok=True)

    for n_components in (2, 3):
        projected = PCA(n_components=n_components).fit_transform(X_train)
        columns = [f'PC{i}' for i in range(1, n_components + 1)]
        df = pd.DataFrame(data=projected, columns=columns)
        df['labels'] = y_train
        df.to_excel(
            os.path.join(folder, f'pca_{n_components}d_points_with_labels.xlsx'),
            index=False,
        )


def main():
    """Run the material-classification experiment end to end.

    1. Load the ``20241009MaterialDiv`` project data for the five material
       labels and split it via ``divSet`` with ``test_size=0.3``.
    2. Export 2-D and 3-D PCA projections of the training split to
       ``./Result/<project_name>/``.
    3. Train a ``Qmlp`` (``hidden_layers=[128, 128]``, ``dropout_rate=0``) and
       save the results of ``get_cm()`` and ``get_epoch_data()`` with
       ``save_to_xlsx`` under the file names ``cm`` and ``acc_and_loss``.
    """
    # Imported before training so that a broken import fails fast instead of
    # after the (expensive) model fit.
    from Qfunctions.saveToxlsx import save_to_xlsx as stx

    project_name = '20241009MaterialDiv'
    label_names = ["Acrylic", "Ecoflex", "PDMS", "PLA", "Wood"]

    data = dLoader(project_name, label_names, isDir=True)
    # The encoder returned by divSet is not used by this script.
    X_train, X_test, y_train, y_test, _encoder = divSet(
        data=data, labels=label_names, test_size=0.3
    )
    print(y_train)

    _export_pca_projections(X_train, y_train, project_name)

    model = Qmlp(
        X_train=X_train, X_test=X_test, y_train=y_train, y_test=y_test,
        hidden_layers=[128, 128],
        dropout_rate=0,
    )
    # NOTE(review): 300 is presumably the number of training epochs — confirm in Qmlp.fit.
    model.fit(300)

    cm = model.get_cm()  # presumably the confusion matrix on the test split
    epoch_data = model.get_epoch_data()

    stx(project_name=project_name, file_name="cm", data=cm)
    stx(project_name=project_name, file_name="acc_and_loss", data=epoch_data)

    print("Done")


if __name__ == '__main__':
    main()