72 lines
2.5 KiB
Python
72 lines
2.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import pandas as pd
|
|
from abc import ABC, abstractmethod
|
|
from Qtorch.Functions import dsplit
|
|
from Qtorch.Functions import save_to_xlsx as stx
|
|
# from sklearn.metrics import confusion_matrix
|
|
|
|
class Qnn(nn.Module, ABC):
    """Abstract base for models that train and evaluate on a train/test split.

    On construction the input is handed to ``dsplit`` to obtain the train and
    test partitions, which are then wrapped in ``DataLoader`` objects.
    Subclasses must implement :meth:`forward` and :meth:`train_model`.

    Attributes set by ``__init__``:
        device: ``cuda`` when available, otherwise ``cpu``.
        original_labels: the ``labels`` argument exactly as passed in.
        batch_size: mini-batch size used when building the loaders.
        X_train, X_test, y_train, y_test, labels: values returned by
            ``dsplit`` (their exact types depend on ``dsplit`` -- they must
            be accepted by ``torch.tensor``).
        train_loader, test_loader: loaders built by :meth:`_prepare_data`.
        result: dict collecting training history (``'acc_and_loss'``) and the
            confusion matrix (``'confusion_matrix'``, ``None`` until set).
    """

    # Class-level default so `_prepare_data` also works on instances whose
    # `__init__` did not set a per-instance batch size.
    batch_size = 64

    def __init__(self, data, labels=None, test_size=0.2, random_state=None,
                 batch_size=64):
        """Split ``data`` and prepare the loaders and the result container.

        Args:
            data: dataset forwarded to ``dsplit``.
            labels: forwarded to ``dsplit``; also used by :meth:`save` as the
                row/column labels of the confusion matrix.
            test_size: forwarded to ``dsplit``.
            random_state: forwarded to ``dsplit``.
            batch_size: mini-batch size for both data loaders (default 64).
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.original_labels = labels
        self.batch_size = batch_size

        # Split data
        self.X_train, self.X_test, self.y_train, self.y_test, self.labels = dsplit(
            data=data,
            labels=labels,
            test_size=test_size,
            random_state=random_state,
        )

        # Called without arguments so subclasses that override
        # `_prepare_data(self)` keep working; the batch size is read from
        # `self.batch_size`.
        self.train_loader, self.test_loader = self._prepare_data()

        self.result = {
            'acc_and_loss': {
                'epoch': [],
                'loss': [],
                'train_accuracy': [],
                'test_accuracy': [],
            },
            'confusion_matrix': None,
        }

    @abstractmethod
    def forward(self, x):
        """Compute the model output for input ``x`` (implemented by subclasses)."""
        pass

    @abstractmethod
    def train_model(self, epochs):
        """Run the training procedure for ``epochs`` epochs (implemented by subclasses)."""
        pass

    def fit(self, epochs=100):
        """Train the model by delegating to :meth:`train_model`.

        Args:
            epochs: number of epochs passed to ``train_model`` (default 100).
        """
        self.train_model(epochs)

    def save(self, project_name):
        """Export every populated entry of ``self.result`` through ``stx``.

        Each entry is converted to a ``pandas.DataFrame`` and passed to
        ``stx(project_name, entry_name, frame)``. The confusion matrix is
        labelled on both axes with ``self.original_labels``.

        Entries whose value is ``None`` (e.g. a confusion matrix that was
        never computed) are skipped instead of being exported as an empty or
        all-NaN table.

        Args:
            project_name: forwarded to ``stx`` as its first argument.
        """
        for entry_name, payload in self.result.items():
            if payload is None:
                # Nothing has been recorded for this entry yet.
                continue
            if entry_name == 'confusion_matrix':
                frame = pd.DataFrame(
                    payload,
                    columns=self.original_labels,
                    index=self.original_labels,
                )
            else:
                frame = pd.DataFrame(payload)
            stx(project_name, entry_name, frame)

    def _prepare_data(self, batch_size=None):
        """Build the train and test ``DataLoader`` objects from the split data.

        Features are converted to ``float32`` tensors and targets to ``long``
        tensors. The training loader shuffles its samples; the test loader
        keeps them in order.

        Args:
            batch_size: mini-batch size for both loaders. When ``None``
                (default) ``self.batch_size`` is used, which is 64 unless
                overridden.

        Returns:
            tuple: ``(train_loader, test_loader)``.
        """
        if batch_size is None:
            batch_size = self.batch_size

        X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
        y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)

        X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32)
        y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)

        train_dataset = torch.utils.data.TensorDataset(X_train_tensor, y_train_tensor)
        test_dataset = torch.utils.data.TensorDataset(X_test_tensor, y_test_tensor)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True
        )
        test_loader = torch.utils.data.DataLoader(
            test_dataset, batch_size=batch_size, shuffle=False
        )

        return train_loader, test_loader
|
|
|
|
# def confusion_matrix(self, test_outputs):
#     predicted = torch.argmax(test_outputs, dim=1)
#     true_label = torch.argmax(self.y_test, dim=1)
#     return confusion_matrix(predicted.cpu(), true_label.cpu())