This commit is contained in:
qyh1510@gmail.com 2026-01-09 16:09:18 +08:00
parent 4fea502e6b
commit 3e07b4258f
2 changed files with 18 additions and 14 deletions

View File

@ -40,11 +40,11 @@ class Qnn(nn.Module):
def __prepare_data(self): def __prepare_data(self):
# 将data转换为tensor形式 # 将data转换为tensor形式
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32).unsqueeze(1) X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train) self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long) y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32).unsqueeze(1) X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32)
self.y_test = self.LABEL_ENCODER.transform(self.y_test) self.y_test = self.LABEL_ENCODER.transform(self.y_test)
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long) y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)

28
main.py
View File

@ -1,29 +1,33 @@
from Qtorch.Models.Qcnn import QCNN from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.divSet import divSet from Qfunctions.divSet import divSet
from Qfunctions.loaData import load_data from Qfunctions.loaData import load_data
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
def main(): def main():
# 输入元数据文件夹名称 # 输入元数据文件夹名称
projet_name = '20250623 FHH-write' projet_name = '20251214 WZSX'
# 请在[]内输入每一个分类的名称 # 请在[]内输入每一个分类的名称
label_names = ['5', '2', '0', 'M', 'J', 'U'] label_names = ['canvas', 'lambswool',
'lychee_grain', 'non-woven_fabric', 'nylon',
'PDMS', 'PET', 'PTFE', 'pure_cotton', 'ramie',
'silk_cotton', 'suede'
]
print(label_names) print(label_names)
data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx') data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx')
X_train, X_test, y_train, y_test, encoder = divSet( X_train, X_test, y_train, y_test, encoder = divSet(
data=data, labels=label_names, test_size= 0.3 data=data, labels=label_names, test_size= 0.3
) )
# model = Qmlp( model = Qmlp(
# X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test, X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
# hidden_layers = [16], hidden_layers = [1024, 512, 256],
# dropout_rate=0 dropout_rate=0
# ) )
model = QCNN( # model = QCNN(
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test, # X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
dropout_rate=0 # dropout_rate=0
) # )
pca_2d, pca_3d = model.get_PCA() pca_2d, pca_3d = model.get_PCA()