remove unuse conde

This commit is contained in:
newbie 2024-11-28 13:42:32 +08:00
parent 6e99f6caa8
commit 0dd4b05977
3 changed files with 8 additions and 74 deletions

View File

@ -60,7 +60,6 @@ class Qmlp(nn.Module):
return x
def __prepare_data(self):
# Step 2: Prepare the data
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
self.y_train = LABEL_ENCODER.fit_transform(self.y_train)
@ -151,14 +150,7 @@ class Qmlp(nn.Module):
print(f"Early stopping at epoch {epoch+1}")
break
if self.labels:
# labels_encoded = LABEL_ENCODER.fit(self.labels)
self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
else:
self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
# self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
self.cm = confusion_matrix(all_labels, all_predicted, normalize='true')
print(self.cm)
return
@ -169,10 +161,10 @@ class Qmlp(nn.Module):
return pd.DataFrame(self.epoch_data)
def fit(self, epoch_times = 100):
train_loader, test_loader = self.__prepare_data()
self.__train_model(train_loader, test_loader, epochs_times=epoch_times)
return
def __init_weights(self):
for m in self.modules():
if isinstance(m, nn.Linear):

View File

@ -12,29 +12,14 @@ from Qfunctions.saveToxlsx import save_to_xlsx as stx
class Qnn(nn.Module, ABC):
def __init__(self, data, labels=None, test_size=0.2, random_state=None):
def __init__(self, labels=None):
super(Qnn, self).__init__()
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 保存原始labe 混淆矩阵使用
# 保存原始label 混淆矩阵使用
self.original_labels = labels
# 划分训练集和测试集
X_train, X_test, y_train, y_test, self.labels = ds(
data=data,
labels=labels,
test_size=test_size,
random_state=random_state
)
self.train_loader, self.test_loader = self.__prepare_data(
X_train=X_train,
y_train=y_train,
X_test=X_test,
y_test=y_test
)
# 定义结果
self.result = {
'acc_and_loss' : {
@ -53,10 +38,6 @@ class Qnn(nn.Module, ABC):
def hinge_loss(self, output, target):
pass
@abstractmethod
def train_model(self, train_loader, test_loader, epochs):
pass
def confusion_matrix(self, test_outputs):
predicted = torch.argmax(test_outputs, dim=1)
true_label = torch.argmax(self.y_test, dim=1)
@ -64,42 +45,3 @@ class Qnn(nn.Module, ABC):
def fit(self, epochs = 100):
self.train_model(epochs)
def save(self, project_name):
for filename, data in self.result.items():
if filename == 'confusion_matrix':
data = pd.DataFrame(data, columns=self.original_labels, index=self.original_labels)
stx(project_name, filename, data)
else:
data = pd.DataFrame(data)
stx(project_name, filename, data)
def __prepare_data(self, X_train, y_train, X_test, y_test):
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test, dtype=torch.long)
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
print(train_loader, test_loader)
return train_loader, test_loader

View File

@ -5,9 +5,9 @@ from Qfunctions.loaData import load_data as dLoader
from sklearn.decomposition import PCA
def main():
projet_name = '###########' # 输入元数据文件夹名称
label_names =[] # 请在[]内输入每一个分类的名称
data = dLoader(projet_name, label_names, isDir=False)
projet_name = '20241112Numbers' # 输入元数据文件夹名称
label_names =['1', '2', '3', '4', '5', '6', '7' ,'8', '9'] # 请在[]内输入每一个分类的名称
data = dLoader(projet_name, label_names, isDir=False, fileClass='xls')
X_train, X_test, y_train, y_test, encoder = divSet(
data=data, labels=label_names, test_size= 0.3
)