This commit is contained in:
qyh1510@gmail.com 2026-01-09 16:09:18 +08:00
parent 4fea502e6b
commit 3e07b4258f
2 changed files with 18 additions and 14 deletions

View File

@ -40,11 +40,11 @@ class Qnn(nn.Module):
def __prepare_data(self):
# 将data转换为tensor形式
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32).unsqueeze(1)
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32).unsqueeze(1)
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32)
self.y_test = self.LABEL_ENCODER.transform(self.y_test)
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)

28
main.py
View File

@ -1,29 +1,33 @@
from Qtorch.Models.Qcnn import QCNN
from Qtorch.Models.Qmlp import Qmlp
from Qfunctions.divSet import divSet
from Qfunctions.loaData import load_data
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
def main():
# 输入元数据文件夹名称
projet_name = '20250623 FHH-write'
projet_name = '20251214 WZSX'
# 请在[]内输入每一个分类的名称
label_names = ['5', '2', '0', 'M', 'J', 'U']
label_names = ['canvas', 'lambswool',
'lychee_grain', 'non-woven_fabric', 'nylon',
'PDMS', 'PET', 'PTFE', 'pure_cotton', 'ramie',
'silk_cotton', 'suede'
]
print(label_names)
data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx')
X_train, X_test, y_train, y_test, encoder = divSet(
data=data, labels=label_names, test_size= 0.3
)
# model = Qmlp(
# X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
# hidden_layers = [16],
# dropout_rate=0
# )
model = Qmlp(
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
hidden_layers = [1024, 512, 256],
dropout_rate=0
)
model = QCNN(
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
dropout_rate=0
)
# model = QCNN(
# X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
# dropout_rate=0
# )
pca_2d, pca_3d = model.get_PCA()