import torch
import torch.nn as nn
import pandas as pd
from abc import ABC, abstractmethod
from sklearn.metrics import confusion_matrix as cm
from torch.utils.data import DataLoader, TensorDataset
from Qfunctions.divSet import divSet as ds
from Qfunctions.saveToxlsx import save_to_xlsx as stx


class Qnn(nn.Module, ABC):
    """Base ``nn.Module`` bundling device selection, a result store and evaluation helpers.

    Subclasses are expected to supply the training pieces this class refers to
    but does not define here (``train_model`` and ``y_test``) — TODO confirm
    against the concrete subclasses.
    """

    def __init__(self, labels=None):
        """Initialise the module.

        Args:
            labels: Original class labels, kept for use when presenting the
                confusion matrix. Stored as-is; not validated here.
        """
        super().__init__()
        # Use the GPU when one is available, otherwise fall back to the CPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Keep the original labels; intended for use with the confusion matrix.
        self.original_labels = labels
        # Container for training/evaluation results.
        self.result = {
            'acc_and_loss': {
                'epoch': [],
                'loss': [],
                'train_accuracy': [],
                'test_accuracy': [],
            },
            'confusion_matrix': None,
        }

    def accuracy(self, output, target):
        """Accuracy hook; a no-op returning ``None`` here.

        NOTE(review): presumably overridden by subclasses — confirm.
        """
        pass

    # Loss function definition.
    def hinge_loss(self, output, target):
        """Hinge-loss hook; a no-op returning ``None`` here.

        NOTE(review): presumably overridden by subclasses — confirm.
        """
        pass

    def confusion_matrix(self, test_outputs):
        """Return the confusion matrix of ``test_outputs`` against ``self.y_test``.

        Both tensors are reduced to class indices with ``argmax`` over dim 1, so
        they must have at least two dimensions (assumed ``(N, num_classes)``
        scores / one-hot rows — TODO confirm). ``self.y_test`` is not set in
        this class; it is assumed to be assigned elsewhere.

        Args:
            test_outputs: Model outputs for the test set.

        Returns:
            The ``sklearn.metrics.confusion_matrix`` result, where entry
            ``[i, j]`` counts samples whose true class is ``i`` and predicted
            class is ``j``.
        """
        predicted = torch.argmax(test_outputs, dim=1)
        true_label = torch.argmax(self.y_test, dim=1)
        # sklearn's signature is confusion_matrix(y_true, y_pred); passing the
        # predictions first (as before) silently produced the transposed matrix.
        return cm(true_label.cpu().numpy(), predicted.cpu().numpy())

    def fit(self, epochs=100):
        """Train the model by delegating to ``self.train_model(epochs)``.

        ``train_model`` is not defined in this class; it is assumed to be
        provided by a subclass or elsewhere in the file.
        """
        self.train_model(epochs)