update
This commit is contained in:
parent
4fea502e6b
commit
3e07b4258f
|
|
@ -40,11 +40,11 @@ class Qnn(nn.Module):
|
||||||
def __prepare_data(self):
|
def __prepare_data(self):
|
||||||
|
|
||||||
# 将data转换为tensor形式
|
# 将data转换为tensor形式
|
||||||
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32).unsqueeze(1)
|
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
|
||||||
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
|
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
|
||||||
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
|
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
|
||||||
|
|
||||||
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32).unsqueeze(1)
|
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32)
|
||||||
self.y_test = self.LABEL_ENCODER.transform(self.y_test)
|
self.y_test = self.LABEL_ENCODER.transform(self.y_test)
|
||||||
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)
|
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)
|
||||||
|
|
||||||
|
|
|
||||||
24
main.py
24
main.py
|
|
@ -1,30 +1,34 @@
|
||||||
from Qtorch.Models.Qcnn import QCNN
|
from Qtorch.Models.Qmlp import Qmlp
|
||||||
from Qfunctions.divSet import divSet
|
from Qfunctions.divSet import divSet
|
||||||
from Qfunctions.loaData import load_data
|
from Qfunctions.loaData import load_data
|
||||||
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
|
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 输入元数据文件夹名称
|
# 输入元数据文件夹名称
|
||||||
projet_name = '20250623 FHH-write'
|
projet_name = '20251214 WZSX'
|
||||||
# 请在[]内输入每一个分类的名称
|
# 请在[]内输入每一个分类的名称
|
||||||
label_names = ['5', '2', '0', 'M', 'J', 'U']
|
label_names = ['canvas', 'lambswool',
|
||||||
|
'lychee_grain', 'non-woven_fabric', 'nylon',
|
||||||
|
'PDMS', 'PET', 'PTFE', 'pure_cotton', 'ramie',
|
||||||
|
'silk_cotton', 'suede'
|
||||||
|
]
|
||||||
print(label_names)
|
print(label_names)
|
||||||
data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx')
|
data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx')
|
||||||
X_train, X_test, y_train, y_test, encoder = divSet(
|
X_train, X_test, y_train, y_test, encoder = divSet(
|
||||||
data=data, labels=label_names, test_size= 0.3
|
data=data, labels=label_names, test_size= 0.3
|
||||||
)
|
)
|
||||||
|
|
||||||
# model = Qmlp(
|
model = Qmlp(
|
||||||
# X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
|
||||||
# hidden_layers = [16],
|
|
||||||
# dropout_rate=0
|
|
||||||
# )
|
|
||||||
|
|
||||||
model = QCNN(
|
|
||||||
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
||||||
|
hidden_layers = [1024, 512, 256],
|
||||||
dropout_rate=0
|
dropout_rate=0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# model = QCNN(
|
||||||
|
# X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
||||||
|
# dropout_rate=0
|
||||||
|
# )
|
||||||
|
|
||||||
pca_2d, pca_3d = model.get_PCA()
|
pca_2d, pca_3d = model.get_PCA()
|
||||||
|
|
||||||
model.fit(300)
|
model.fit(300)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue