CNN fist done
This commit is contained in:
parent
847bdae9f6
commit
34bba1a1c9
|
|
@ -1,43 +1,55 @@
|
||||||
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from Qtorch.Models.Qnn import Qnn
|
|
||||||
from sklearn.preprocessing import LabelEncoder
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
from Qtorch.Models.Qnn import Qnn
|
||||||
|
|
||||||
class QCNN(Qnn):
|
class QCNN(Qnn):
|
||||||
def __init__(self, X_train, y_train, X_test, y_test,
|
def __init__(self, X_train, y_train, X_test, y_test, labels=None, dropout_rate=0.3):
|
||||||
labels=None,
|
|
||||||
dropout_rate=0.3
|
|
||||||
):
|
|
||||||
super(QCNN, self).__init__()
|
super(QCNN, self).__init__()
|
||||||
|
|
||||||
self.LABEL_ENCODER = LabelEncoder()
|
self.LABEL_ENCODER = LabelEncoder()
|
||||||
|
|
||||||
self.X_train, self.y_train, self.X_test, self.y_test = X_train, y_train, X_test, y_test
|
self.X_train, self.y_train, self.X_test, self.y_test = X_train, y_train, X_test, y_test
|
||||||
|
|
||||||
self.labels = labels
|
self.labels = labels
|
||||||
|
|
||||||
input_size = X_train.shape[1]
|
input_size = X_train.shape[1] # 输入的长度
|
||||||
num_classes = len(set(y_train))
|
num_classes = len(set(y_train)) # 分类数
|
||||||
|
|
||||||
|
# 网络层:卷积层 + 池化层 + 全连接层
|
||||||
self.layers = nn.ModuleList()
|
self.layers = nn.ModuleList()
|
||||||
|
self.layers.append(nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3)) # 卷积层
|
||||||
|
self.layers.append(nn.MaxPool1d(kernel_size=2)) # 池化层
|
||||||
|
self.layers.append(nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3)) # 卷积层
|
||||||
|
self.layers.append(nn.MaxPool1d(kernel_size=2)) # 池化层
|
||||||
|
|
||||||
# Input layer to first Convolutional layer
|
# 计算展平后的大小
|
||||||
self.layers.append(nn.Conv1d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1))
|
conv_output_size = self._get_conv_output_size(input_size) # 卷积后的输出大小
|
||||||
self.layers.append(nn.ReLU())
|
print(f"Conv output size: {conv_output_size}") # 打印卷积后的输出大小
|
||||||
self.layers.append(nn.MaxPool1d(kernel_size=2, stride=2))
|
self.layers.append(nn.Linear(conv_output_size, 128)) # 全连接层
|
||||||
|
self.layers.append(nn.Linear(128, num_classes)) # 输出层
|
||||||
|
|
||||||
# Calculate the size after convolutions and pooling
|
|
||||||
conv_output_size = input_size // 4 # Assuming two pooling layers with stride 2
|
|
||||||
self.layers.append(nn.Linear(32 * conv_output_size, 128))
|
|
||||||
self.layers.append(nn.ReLU())
|
|
||||||
self.layers.append(nn.Dropout(dropout_rate))
|
|
||||||
|
|
||||||
# Output layer
|
|
||||||
self.layers.append(nn.Linear(128, num_classes))
|
|
||||||
self.__init_weights()
|
self.__init_weights()
|
||||||
|
|
||||||
def forward(self, x):
|
def _get_conv_output_size(self, input_size):
|
||||||
|
# 计算卷积后的输出尺寸
|
||||||
|
x = torch.randn(1, 1, input_size) # 创建一个假的输入张量
|
||||||
for layer in self.layers:
|
for layer in self.layers:
|
||||||
|
x = layer(x) # 通过每一层
|
||||||
|
return int(x.numel()) # 返回展平后的输出大小
|
||||||
|
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# 通过卷积和池化层
|
||||||
|
for layer in self.layers[:-2]: # 除去最后两个 Linear 层
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
|
# 展平卷积后的输出
|
||||||
|
x = x.view(x.size(0), -1) # 这样 x 会变成 (batch_size, conv_output_size)
|
||||||
|
|
||||||
|
# 通过全连接层
|
||||||
|
x = self.layers[-2](x)
|
||||||
|
x = self.layers[-1](x)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def __init_weights(self):
|
def __init_weights(self):
|
||||||
|
|
|
||||||
|
|
@ -36,11 +36,11 @@ class Qnn(nn.Module):
|
||||||
def __prepare_data(self):
|
def __prepare_data(self):
|
||||||
|
|
||||||
# 将data转换为tensor形式
|
# 将data转换为tensor形式
|
||||||
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
|
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32).unsqueeze(1)
|
||||||
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
|
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
|
||||||
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
|
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
|
||||||
|
|
||||||
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32)
|
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32).unsqueeze(1)
|
||||||
self.y_test = self.LABEL_ENCODER.transform(self.y_test)
|
self.y_test = self.LABEL_ENCODER.transform(self.y_test)
|
||||||
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)
|
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)
|
||||||
|
|
||||||
|
|
|
||||||
9
main.py
9
main.py
|
|
@ -3,11 +3,13 @@ from Qtorch.Models.Qcnn import QCNN
|
||||||
from Qfunctions.divSet import divSet
|
from Qfunctions.divSet import divSet
|
||||||
from Qfunctions.loaData import load_data
|
from Qfunctions.loaData import load_data
|
||||||
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
|
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
|
||||||
|
import string
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
projet_name = '20241112Numbers' # 输入元数据文件夹名称
|
projet_name = '20241130 EMG-write' # 输入元数据文件夹名称
|
||||||
label_names =['1', '2', '3', '4', '5', '6', '7' ,'8', '9'] # 请在[]内输入每一个分类的名称
|
label_names = list(string.ascii_uppercase) # 请在[]内输入每一个分类的名称
|
||||||
data = load_data(projet_name, label_names, isDir=False, fileClass='xls')
|
print(label_names)
|
||||||
|
data = load_data(projet_name, label_names, isDir=False, fileClass='xlsx')
|
||||||
X_train, X_test, y_train, y_test, encoder = divSet(
|
X_train, X_test, y_train, y_test, encoder = divSet(
|
||||||
data=data, labels=label_names, test_size= 0.3
|
data=data, labels=label_names, test_size= 0.3
|
||||||
)
|
)
|
||||||
|
|
@ -17,6 +19,7 @@ def main():
|
||||||
# hidden_layers = [128],
|
# hidden_layers = [128],
|
||||||
# dropout_rate=0
|
# dropout_rate=0
|
||||||
# )
|
# )
|
||||||
|
|
||||||
model = QCNN(
|
model = QCNN(
|
||||||
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
|
||||||
dropout_rate=0
|
dropout_rate=0
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue