47 lines
1.1 KiB
Python
47 lines
1.1 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import pandas as pd
|
||
from abc import ABC, abstractmethod
|
||
|
||
from sklearn.metrics import confusion_matrix as cm
|
||
|
||
from torch.utils.data import DataLoader, TensorDataset
|
||
from Qfunctions.divSet import divSet as ds
|
||
from Qfunctions.saveToxlsx import save_to_xlsx as stx
|
||
|
||
|
||
class Qnn(nn.Module, ABC):
    """Base ``nn.Module`` for models whose outputs are per-class scores.

    Provides device selection, a container for training/evaluation results,
    and a confusion-matrix helper. :meth:`fit` delegates the actual training
    loop to :meth:`train_model`, which subclasses must provide.

    NOTE(review): ``self.y_test`` is read by :meth:`confusion_matrix` but is
    never assigned in this class; presumably a subclass sets it — confirm.
    """

    def __init__(self, labels=None):
        """Initialise the module.

        Args:
            labels: The original class labels, kept for use alongside the
                confusion matrix. Stored as-is; not otherwise used here.
        """
        super().__init__()

        # Prefer the GPU when available. This class only records the device;
        # nothing here moves parameters or tensors onto it.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Keep the original labels for use with the confusion matrix.
        self.original_labels = labels

        # Containers for results collected during training/evaluation.
        self.result = {
            'acc_and_loss': {
                'epoch': [],
                'loss': [],
                'train_accuracy': [],
                'test_accuracy': [],
            },
            'confusion_matrix': None,
        }

    def accuracy(self, output, target):
        """Accuracy-metric hook; subclasses may override.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def hinge_loss(self, output, target):
        """Loss-function hook (named for hinge loss); subclasses may override.

        The base implementation does nothing and returns ``None``.
        """
        return None

    def train_model(self, epochs):
        """Run the training loop for ``epochs`` epochs.

        Subclasses must override this; :meth:`fit` delegates here.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement train_model(epochs)"
        )

    def _confusion_inputs(self, test_outputs):
        """Convert ``test_outputs`` and ``self.y_test`` into sklearn arguments.

        Both tensors are reduced to class indices with ``argmax`` over dim 1.

        Returns:
            ``(y_true, y_pred, labels)``: true and predicted class indices as
            NumPy arrays, and the full list of class indices ``0 .. C-1``
            where ``C = test_outputs.shape[1]``.
        """
        y_pred = torch.argmax(test_outputs, dim=1).cpu().numpy()
        y_true = torch.argmax(self.y_test, dim=1).cpu().numpy()
        # List every class explicitly so the matrix is always C x C, even
        # when some class never appears in either y_true or y_pred.
        labels = list(range(test_outputs.shape[1]))
        return y_true, y_pred, labels

    def confusion_matrix(self, test_outputs):
        """Return the confusion matrix of ``test_outputs`` against ``self.y_test``.

        Rows are true classes and columns are predicted classes (the
        scikit-learn convention). Every class index ``0 .. C-1`` gets a row
        and a column.

        Args:
            test_outputs: Model outputs for the test set; ``argmax`` over
                dim 1 gives the predicted class.
        """
        y_true, y_pred, labels = self._confusion_inputs(test_outputs)
        # Argument order matters: sklearn expects (y_true, y_pred).
        return cm(y_true, y_pred, labels=labels)

    def fit(self, epochs=100):
        """Train the model by delegating to :meth:`train_model`.

        Args:
            epochs: Number of training epochs passed to :meth:`train_model`.
        """
        self.train_model(epochs)