import torch
import torch.nn as nn
import pandas as pd
from abc import ABC, abstractmethod
from Qtorch.Functions import dsplit
from Qtorch.Functions import save_to_xlsx as stx


class Qnn(nn.Module, ABC):
    """Abstract base for models that own their train/test split and loaders.

    On construction the data is split with ``dsplit`` and wrapped in two
    ``DataLoader`` objects (``train_loader`` / ``test_loader``). Subclasses
    implement ``forward`` and ``train_model``. Nothing in this base class
    writes to ``self.result`` after ``__init__``; subclasses are expected to
    fill it, and ``save`` writes out one table per key.
    """

    def __init__(self, data, labels=None, test_size=0.2, random_state=None,
                 batch_size=64):
        """Split ``data`` and build the train/test loaders.

        Args:
            data: Dataset forwarded to ``dsplit``.
            labels: Forwarded to ``dsplit``. Also kept as ``original_labels``
                and used as the row/column labels of the confusion matrix in
                ``save``.
            test_size: Forwarded to ``dsplit`` (presumably the held-out
                fraction).
            random_state: Forwarded to ``dsplit`` (presumably a seed for a
                reproducible split).
            batch_size: Batch size used by both loaders. Defaults to 64,
                the value this class previously hard-coded.
        """
        super().__init__()
        # Selected here but not applied anywhere in this class; subclasses
        # are presumably responsible for moving the model/batches to it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.original_labels = labels

        # Split data
        self.X_train, self.X_test, self.y_train, self.y_test, self.labels = dsplit(
            data=data, labels=labels, test_size=test_size, random_state=random_state
        )
        self.train_loader, self.test_loader = self._prepare_data(batch_size)

        # Metrics container written out by ``save`` (one table per top-level key).
        self.result = {
            'acc_and_loss': {
                'epoch': [],
                'loss': [],
                'train_accuracy': [],
                'test_accuracy': [],
            },
            # Stays None until a subclass stores a matrix here.
            'confusion_matrix': None,
        }

    @abstractmethod
    def forward(self, x):
        """Compute the model output for input batch ``x``."""

    @abstractmethod
    def train_model(self, epochs):
        """Train for ``epochs`` epochs; implemented by subclasses."""

    def fit(self, epochs=100):
        """Train the model by delegating to ``train_model(epochs)``."""
        self.train_model(epochs)

    def save(self, project_name):
        """Write every populated entry of ``self.result`` through ``save_to_xlsx``.

        Each key of ``self.result`` is passed as the ``filename`` argument
        to ``stx`` together with ``project_name``. The confusion matrix is
        labelled with ``original_labels`` on both axes. Entries that are
        still ``None`` (e.g. a confusion matrix that was never computed)
        are skipped rather than written out as an empty/NaN table.

        Args:
            project_name: Forwarded unchanged to ``stx``.
        """
        for name, content in self.result.items():
            if content is None:
                continue
            if name == 'confusion_matrix':
                frame = pd.DataFrame(
                    content, columns=self.original_labels, index=self.original_labels
                )
            else:
                frame = pd.DataFrame(content)
            stx(project_name, name, frame)

    def _prepare_data(self, batch_size=64):
        """Wrap the split features/targets in ``DataLoader`` objects.

        Features are converted to ``float32`` tensors and targets to
        ``int64`` (``torch.long``) tensors. The training loader reshuffles
        every epoch; the test loader keeps the original order.

        Args:
            batch_size: Batch size for both loaders.

        Returns:
            A ``(train_loader, test_loader)`` tuple.
        """
        def make_loader(features, targets, shuffle):
            # One TensorDataset + DataLoader pair; identical for both splits
            # apart from shuffling.
            dataset = torch.utils.data.TensorDataset(
                torch.tensor(features, dtype=torch.float32),
                torch.tensor(targets, dtype=torch.long),
            )
            return torch.utils.data.DataLoader(
                dataset, batch_size=batch_size, shuffle=shuffle
            )

        train_loader = make_loader(self.X_train, self.y_train, shuffle=True)
        test_loader = make_loader(self.X_test, self.y_test, shuffle=False)
        return train_loader, test_loader