Deeplearning/Qtorch/Models/Qnn.py
2024-10-07 09:54:32 +08:00

106 lines
2.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import pandas as pd
from abc import ABC, abstractmethod
from sklearn.metrics import confusion_matrix as cm
from torch.utils.data import DataLoader, TensorDataset
from Qfunctions.divSet import divSet as ds
from Qfunctions.saveToxlsx import save_to_xlsx as stx
class Qnn(nn.Module, ABC):
    """Abstract base for models trained on a train/test split of ``data``.

    Construction splits the data with ``divSet``, wraps both splits in
    ``DataLoader`` objects and creates an empty ``result`` record.
    Subclasses must implement ``train_model``; ``fit`` is the entry point
    and ``save`` writes ``result`` out through ``save_to_xlsx``.
    """

    def __init__(self, data, labels=None, test_size=0.2, random_state=None, batch_size=64):
        """Split ``data`` and build the train/test loaders.

        Args:
            data: Feature data forwarded to ``divSet``.
            labels: Forwarded to ``divSet``; also kept unchanged in
                ``self.original_labels`` as the confusion-matrix axis labels
                used by ``save``.
            test_size: Test fraction forwarded to ``divSet``.
            random_state: Seed forwarded to ``divSet``.
            batch_size: Mini-batch size for both loaders (defaults to 64,
                the previously hard-coded value).
        """
        super().__init__()
        # Selected device; this base class does not move tensors onto it —
        # presumably subclasses do.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Keep the original labels; save() uses them to label the confusion matrix.
        self.original_labels = labels
        # Split into training and test sets. divSet also returns self.labels
        # (presumably processed/encoded labels — confirm in divSet).
        X_train, X_test, y_train, y_test, self.labels = ds(
            data=data,
            labels=labels,
            test_size=test_size,
            random_state=random_state,
        )
        self.train_loader, self.test_loader = self.__prepare_data(
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            batch_size=batch_size,
        )
        # Test targets as a long tensor. confusion_matrix() reads this; it was
        # previously never assigned, so that method always raised AttributeError.
        self.y_test = self.test_loader.dataset.tensors[1]
        # Result record written out by save(): one entry per output sheet/file.
        self.result = {
            'acc_and_loss': {
                'epoch': [],
                'loss': [],
                'train_accuracy': [],
                'test_accuracy': [],
            },
            'confusion_matrix': None,
        }

    def accuracy(self, output, target):
        """Placeholder accuracy metric; returns None unless a subclass overrides it."""
        pass

    def hinge_loss(self, output, target):
        """Placeholder loss function; returns None unless a subclass overrides it."""
        pass

    @abstractmethod
    def train_model(self, epochs):
        """Train the model for ``epochs`` epochs.

        ``fit`` calls this as ``self.train_model(epochs)``; the data is
        available through ``self.train_loader`` and ``self.test_loader``.
        """

    def confusion_matrix(self, test_outputs):
        """Return the confusion matrix of ``test_outputs`` against ``self.y_test``.

        Args:
            test_outputs: Model scores for the test set; the predicted class
                is ``argmax`` over dim 1.

        Returns:
            The ``sklearn.metrics.confusion_matrix`` array, with rows for true
            classes and columns for predicted classes.
        """
        predicted = torch.argmax(test_outputs, dim=1)
        true_label = self.y_test
        # One-hot targets are reduced to class indices; 1-D targets already are.
        if true_label.dim() > 1:
            true_label = torch.argmax(true_label, dim=1)
        # sklearn's signature is (y_true, y_pred); the arguments were swapped before.
        return cm(true_label.cpu(), predicted.cpu())

    def fit(self, epochs=100):
        """Train the model by delegating to ``train_model(epochs)``."""
        self.train_model(epochs)

    def save(self, project_name):
        """Write every entry of ``self.result`` via ``save_to_xlsx``.

        Each entry is converted to a ``pandas.DataFrame`` and saved under
        its key as the file name. The confusion matrix uses
        ``self.original_labels`` for both its index and its columns.
        """
        for filename, data in self.result.items():
            if filename == 'confusion_matrix':
                frame = pd.DataFrame(data, columns=self.original_labels, index=self.original_labels)
            else:
                frame = pd.DataFrame(data)
            stx(project_name, filename, frame)

    def __prepare_data(self, X_train, y_train, X_test, y_test, batch_size=64):
        """Wrap the splits in DataLoaders.

        Features become float32 tensors and targets become int64 (``torch.long``)
        tensors. The training loader shuffles; the test loader keeps its order.

        Returns:
            ``(train_loader, test_loader)``.
        """
        X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
        y_train_tensor = torch.tensor(y_train, dtype=torch.long)
        X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
        y_test_tensor = torch.tensor(y_test, dtype=torch.long)
        train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
        test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
        return train_loader, test_loader