remove unuse conde
This commit is contained in:
parent
6e99f6caa8
commit
0dd4b05977
|
|
@ -60,7 +60,6 @@ class Qmlp(nn.Module):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __prepare_data(self):
|
def __prepare_data(self):
|
||||||
|
|
||||||
# Step 2: Prepare the data
|
# Step 2: Prepare the data
|
||||||
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
|
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
|
||||||
self.y_train = LABEL_ENCODER.fit_transform(self.y_train)
|
self.y_train = LABEL_ENCODER.fit_transform(self.y_train)
|
||||||
|
|
@ -151,14 +150,7 @@ class Qmlp(nn.Module):
|
||||||
print(f"Early stopping at epoch {epoch+1}")
|
print(f"Early stopping at epoch {epoch+1}")
|
||||||
break
|
break
|
||||||
|
|
||||||
if self.labels:
|
|
||||||
# labels_encoded = LABEL_ENCODER.fit(self.labels)
|
|
||||||
self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
|
self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
|
||||||
else:
|
|
||||||
self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
|
|
||||||
|
|
||||||
|
|
||||||
# self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
|
|
||||||
print(self.cm)
|
print(self.cm)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -169,10 +161,10 @@ class Qmlp(nn.Module):
|
||||||
return pd.DataFrame(self.epoch_data)
|
return pd.DataFrame(self.epoch_data)
|
||||||
|
|
||||||
def fit(self, epoch_times = 100):
|
def fit(self, epoch_times = 100):
|
||||||
|
|
||||||
train_loader, test_loader = self.__prepare_data()
|
train_loader, test_loader = self.__prepare_data()
|
||||||
self.__train_model(train_loader, test_loader, epochs_times=epoch_times)
|
self.__train_model(train_loader, test_loader, epochs_times=epoch_times)
|
||||||
return
|
return
|
||||||
|
|
||||||
def __init_weights(self):
|
def __init_weights(self):
|
||||||
for m in self.modules():
|
for m in self.modules():
|
||||||
if isinstance(m, nn.Linear):
|
if isinstance(m, nn.Linear):
|
||||||
|
|
|
||||||
|
|
@ -12,29 +12,14 @@ from Qfunctions.saveToxlsx import save_to_xlsx as stx
|
||||||
|
|
||||||
class Qnn(nn.Module, ABC):
|
class Qnn(nn.Module, ABC):
|
||||||
|
|
||||||
def __init__(self, data, labels=None, test_size=0.2, random_state=None):
|
def __init__(self, labels=None):
|
||||||
super(Qnn, self).__init__()
|
super(Qnn, self).__init__()
|
||||||
|
|
||||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
# 保存原始labe, 混淆矩阵使用
|
# 保存原始label, 混淆矩阵使用
|
||||||
self.original_labels = labels
|
self.original_labels = labels
|
||||||
|
|
||||||
# 划分训练集和测试集
|
|
||||||
X_train, X_test, y_train, y_test, self.labels = ds(
|
|
||||||
data=data,
|
|
||||||
labels=labels,
|
|
||||||
test_size=test_size,
|
|
||||||
random_state=random_state
|
|
||||||
)
|
|
||||||
|
|
||||||
self.train_loader, self.test_loader = self.__prepare_data(
|
|
||||||
X_train=X_train,
|
|
||||||
y_train=y_train,
|
|
||||||
X_test=X_test,
|
|
||||||
y_test=y_test
|
|
||||||
)
|
|
||||||
|
|
||||||
# 定义结果
|
# 定义结果
|
||||||
self.result = {
|
self.result = {
|
||||||
'acc_and_loss' : {
|
'acc_and_loss' : {
|
||||||
|
|
@ -53,10 +38,6 @@ class Qnn(nn.Module, ABC):
|
||||||
def hinge_loss(self, output, target):
|
def hinge_loss(self, output, target):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def train_model(self, train_loader, test_loader, epochs):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def confusion_matrix(self, test_outputs):
|
def confusion_matrix(self, test_outputs):
|
||||||
predicted = torch.argmax(test_outputs, dim=1)
|
predicted = torch.argmax(test_outputs, dim=1)
|
||||||
true_label = torch.argmax(self.y_test, dim=1)
|
true_label = torch.argmax(self.y_test, dim=1)
|
||||||
|
|
@ -64,42 +45,3 @@ class Qnn(nn.Module, ABC):
|
||||||
|
|
||||||
def fit(self, epochs = 100):
|
def fit(self, epochs = 100):
|
||||||
self.train_model(epochs)
|
self.train_model(epochs)
|
||||||
|
|
||||||
def save(self, project_name):
|
|
||||||
for filename, data in self.result.items():
|
|
||||||
if filename == 'confusion_matrix':
|
|
||||||
data = pd.DataFrame(data, columns=self.original_labels, index=self.original_labels)
|
|
||||||
stx(project_name, filename, data)
|
|
||||||
else:
|
|
||||||
data = pd.DataFrame(data)
|
|
||||||
stx(project_name, filename, data)
|
|
||||||
|
|
||||||
def __prepare_data(self, X_train, y_train, X_test, y_test):
|
|
||||||
|
|
||||||
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
|
|
||||||
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
|
|
||||||
|
|
||||||
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
|
|
||||||
y_test_tensor = torch.tensor(y_test, dtype=torch.long)
|
|
||||||
|
|
||||||
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
|
|
||||||
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
|
|
||||||
|
|
||||||
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
|
|
||||||
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
|
|
||||||
|
|
||||||
print(train_loader, test_loader)
|
|
||||||
|
|
||||||
return train_loader, test_loader
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
6
main.py
6
main.py
|
|
@ -5,9 +5,9 @@ from Qfunctions.loaData import load_data as dLoader
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
projet_name = '###########' # 输入元数据文件夹名称
|
projet_name = '20241112Numbers' # 输入元数据文件夹名称
|
||||||
label_names =[] # 请在[]内输入每一个分类的名称
|
label_names =['1', '2', '3', '4', '5', '6', '7' ,'8', '9'] # 请在[]内输入每一个分类的名称
|
||||||
data = dLoader(projet_name, label_names, isDir=False)
|
data = dLoader(projet_name, label_names, isDir=False, fileClass='xls')
|
||||||
X_train, X_test, y_train, y_test, encoder = divSet(
|
X_train, X_test, y_train, y_test, encoder = divSet(
|
||||||
data=data, labels=label_names, test_size= 0.3
|
data=data, labels=label_names, test_size= 0.3
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue