Deeplearning/remake/Qtorch/models/Qnn.py

72 lines
2.5 KiB
Python
Raw Normal View History

2024-10-07 09:54:32 +08:00
import torch
import torch.nn as nn
import pandas as pd
from abc import ABC, abstractmethod
from Qtorch.Functions import dsplit
from Qtorch.Functions import save_to_xlsx as stx
# from sklearn.metrics import confusion_matrix
class Qnn(nn.Module, ABC):
    """Abstract base class for Qtorch models.

    On construction the raw ``data``/``labels`` are split with ``dsplit`` and
    wrapped in train/test ``DataLoader`` objects. Subclasses must implement
    :meth:`forward` and :meth:`train_model`; :meth:`save` exports whatever
    has been stored in ``self.result``.

    Attributes:
        device: ``torch.device('cuda')`` when CUDA is available, otherwise
            ``cpu``. Nothing in this class moves the model or batches onto
            it; presumably subclasses do that in ``train_model``.
        original_labels: The ``labels`` argument exactly as passed in; used
            as row/column headers when exporting the confusion matrix.
        X_train, X_test, y_train, y_test, labels: The five values returned
            by ``dsplit`` (NOTE(review): the meaning of the returned
            ``labels`` depends on ``dsplit`` — confirm).
        batch_size: Mini-batch size used by both data loaders.
        train_loader, test_loader: Loaders over the split data; only the
            training loader shuffles.
        result: Metrics container exported by :meth:`save`.
    """

    # Batch size used when none is supplied (the historical hard-coded value).
    DEFAULT_BATCH_SIZE = 64

    def __init__(self, data, labels=None, test_size=0.2, random_state=None,
                 batch_size=DEFAULT_BATCH_SIZE):
        """Split the data and build the train/test data loaders.

        Args:
            data: Input samples, forwarded unchanged to ``dsplit``.
            labels: Targets, forwarded to ``dsplit`` and also kept verbatim
                as ``original_labels``.
            test_size: Forwarded to ``dsplit`` (presumably the test-set
                fraction — confirm against ``dsplit``).
            random_state: Forwarded to ``dsplit``.
            batch_size: Mini-batch size for both loaders (default 64).
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.original_labels = labels
        self.batch_size = batch_size
        # Split data
        self.X_train, self.X_test, self.y_train, self.y_test, self.labels = dsplit(
            data=data,
            labels=labels,
            test_size=test_size,
            random_state=random_state,
        )
        self.train_loader, self.test_loader = self._prepare_data()
        # Empty metric containers; nothing here fills them, so presumably the
        # subclass's train_model() appends per-epoch values and sets the
        # confusion matrix.
        self.result = {
            'acc_and_loss': {
                'epoch': [],
                'loss': [],
                'train_accuracy': [],
                'test_accuracy': [],
            },
            'confusion_matrix': None,
        }

    @abstractmethod
    def forward(self, x):
        """Compute the model output for input batch ``x``."""

    @abstractmethod
    def train_model(self, epochs):
        """Train the model for ``epochs`` epochs (subclass-defined)."""

    def fit(self, epochs=100):
        """Train the model; thin wrapper around :meth:`train_model`."""
        self.train_model(epochs)

    def save(self, project_name):
        """Export every populated ``self.result`` entry through ``stx``.

        Each key of ``self.result`` is passed as the ``filename`` argument of
        ``stx(project_name, filename, frame)``. Entries that are still
        ``None`` (e.g. a confusion matrix the subclass never computed) are
        skipped rather than written out as an empty / all-NaN table.

        Args:
            project_name: Forwarded as the first argument of ``stx``.
        """
        for filename, content in self.result.items():
            if content is None:
                # Nothing was recorded for this entry; don't emit a bogus table.
                continue
            if filename == 'confusion_matrix':
                # Label rows and columns with the labels the caller supplied.
                frame = pd.DataFrame(
                    content,
                    columns=self.original_labels,
                    index=self.original_labels,
                )
            else:
                frame = pd.DataFrame(content)
            stx(project_name, filename, frame)

    def _prepare_data(self, batch_size=None):
        """Wrap the split arrays in train/test ``DataLoader`` objects.

        Features become ``float32`` tensors and targets ``torch.long``
        tensors (NOTE(review): this assumes targets are integer class
        indices, as expected by e.g. ``CrossEntropyLoss`` — confirm against
        ``dsplit``'s output).

        Args:
            batch_size: Optional override for the loader batch size. When
                omitted, ``self.batch_size`` is used, falling back to
                ``DEFAULT_BATCH_SIZE`` if that attribute is not set.

        Returns:
            ``(train_loader, test_loader)``; only the training loader shuffles.
        """
        if batch_size is None:
            batch_size = getattr(self, 'batch_size', self.DEFAULT_BATCH_SIZE)
        train_dataset = torch.utils.data.TensorDataset(
            torch.tensor(self.X_train, dtype=torch.float32),
            torch.tensor(self.y_train, dtype=torch.long),
        )
        test_dataset = torch.utils.data.TensorDataset(
            torch.tensor(self.X_test, dtype=torch.float32),
            torch.tensor(self.y_test, dtype=torch.long),
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )
        return train_loader, test_loader
# def confusion_matrix(self, test_outputs):
# predicted = torch.argmax(test_outputs, dim=1)
# true_label = torch.argmax(self.y_test, dim=1)
# return confusion_matrix(predicted.cpu(), true_label.cpu())