Update Qnn model data preparation to include unsqueeze for tensor dimensions; replace Qmlp with QCNN in main.py; add diagram generation script

This commit is contained in:
newbie 2025-12-14 23:24:52 +08:00
parent 0667dd90d1
commit 4fea502e6b
3 changed files with 183 additions and 12 deletions

View File

@ -40,11 +40,11 @@ class Qnn(nn.Module):
def __prepare_data(self):
# 将data转换为tensor形式
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32)
X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32).unsqueeze(1)
self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train)
y_train_tensor = torch.tensor(self.y_train, dtype=torch.long)
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32)
X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32).unsqueeze(1)
self.y_test = self.LABEL_ENCODER.transform(self.y_test)
y_test_tensor = torch.tensor(self.y_test, dtype=torch.long)

16
main.py
View File

@ -1,4 +1,4 @@
from Qtorch.Models.Qmlp import Qmlp
from Qtorch.Models.Qcnn import QCNN
from Qfunctions.divSet import divSet
from Qfunctions.loaData import load_data
from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx
@ -14,17 +14,17 @@ def main():
data=data, labels=label_names, test_size= 0.3
)
model = Qmlp(
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
hidden_layers = [16],
dropout_rate=0
)
# model = QCNN
# model = Qmlp(
# X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
# hidden_layers = [16],
# dropout_rate=0
# )
model = QCNN(
X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test,
dropout_rate=0
)
pca_2d, pca_3d = model.get_PCA()
model.fit(300)

171
test.py Normal file
View File

@ -0,0 +1,171 @@
import matplotlib.pyplot as plt
import matplotlib.patches as patches
def draw_diagram():
# Setup the figure
fig, ax = plt.subplots(figsize=(10, 8))
ax.set_xlim(0, 10)
ax.set_ylim(0, 8)
ax.axis('off') # Turn off axis
# Font settings
font_formula = {'family': 'sans-serif', 'weight': 'bold', 'size': 14}
font_label = {'family': 'sans-serif', 'weight': 'bold', 'size': 16}
# --- PART 1: TOP LEFT (The "Acrylate" Base) ---
# Replacing the SiO2 Block with an Acrylate Group structure
# We will draw a stylized Acrylic Acid molecule acting as the "Anchor"
# Label "A"
ax.text(1.0, 7.5, "A", **font_label)
# Label "Acrylate" (replaces SiO2)
ax.text(1.0, 4.5, "Acrylate\n(丙烯酸酯)", ha='center', va='center', fontsize=14, weight='bold')
# Draw Acrylate structure (Vertical orientation to mimic the wall)
# CH2=CH-C(=O)-OH
# C=C
ax.text(1.5, 6.5, "CH", **font_formula, ha='right')
ax.text(1.6, 6.4, "2", fontsize=10, weight='bold', ha='left') # subscript
ax.plot([1.5, 1.5], [6.3, 5.8], 'k-', lw=2) # Double bond line 1
ax.plot([1.4, 1.4], [6.3, 5.8], 'k-', lw=2) # Double bond line 2
ax.text(1.5, 5.6, "CH", **font_formula, ha='center')
ax.plot([1.5, 1.5], [5.4, 4.9], 'k-', lw=2) # Single bond
ax.text(1.5, 4.7, "C", **font_formula, ha='center')
# Carbonyl O
ax.plot([1.4, 1.1], [4.7, 4.7], 'k-', lw=2) # Double bond to O (sideways)
ax.plot([1.4, 1.1], [4.8, 4.8], 'k-', lw=2)
ax.text(0.9, 4.7, "O", **font_formula, ha='center', va='center')
# Hydroxyl OH (The reactive site)
ax.plot([1.6, 1.9], [4.7, 4.7], 'k-', lw=2)
ax.text(2.1, 4.7, "OH", **font_formula, ha='left', va='center')
# --- PART 2: TOP RIGHT (The Fluorinated Alcohol) ---
# Replacing FDTES with Perfluoroethylethanol (C2F5-CH2-CH2-OH)
# Structure: HO - CH2 - CH2 - CF2 - CF3
# The Plus Sign
ax.text(3.5, 5.5, "+", fontsize=40, weight='bold', color='#0070C0', ha='center')
# Start coordinates for alcohol
start_x = 4.5
y_level = 5.5
# HO-
ax.text(start_x, y_level, "HO", **font_formula, ha='right')
ax.plot([start_x + 0.1, start_x + 0.5], [y_level, y_level], 'k-', lw=2)
# -CH2-
ax.text(start_x + 0.8, y_level, "CH", **font_formula, ha='center')
ax.text(start_x + 1.05, y_level-0.1, "2", fontsize=10, weight='bold')
ax.plot([start_x + 1.2, start_x + 1.6], [y_level, y_level], 'k-', lw=2)
# -CH2-
ax.text(start_x + 1.9, y_level, "CH", **font_formula, ha='center')
ax.text(start_x + 2.15, y_level-0.1, "2", fontsize=10, weight='bold')
ax.plot([start_x + 2.3, start_x + 2.7], [y_level, y_level], 'k-', lw=2)
# -CF2- (Perfluoro group starts)
ax.text(start_x + 3.0, y_level, "C", **font_formula, ha='center')
# F on top
ax.plot([start_x + 3.0, start_x + 3.0], [y_level + 0.2, y_level + 0.5], 'k-', lw=2)
ax.text(start_x + 3.0, y_level + 0.6, "F", **font_formula, ha='center')
# F on bottom
ax.plot([start_x + 3.0, start_x + 3.0], [y_level - 0.2, y_level - 0.5], 'k-', lw=2)
ax.text(start_x + 3.0, y_level - 0.8, "F", **font_formula, ha='center')
ax.plot([start_x + 3.3, start_x + 3.7], [y_level, y_level], 'k-', lw=2)
# -CF3 (End of chain)
ax.text(start_x + 4.0, y_level, "C", **font_formula, ha='center')
# F on top
ax.plot([start_x + 4.0, start_x + 4.0], [y_level + 0.2, y_level + 0.5], 'k-', lw=2)
ax.text(start_x + 4.0, y_level + 0.6, "F", **font_formula, ha='center')
# F on bottom
ax.plot([start_x + 4.0, start_x + 4.0], [y_level - 0.2, y_level - 0.5], 'k-', lw=2)
ax.text(start_x + 4.0, y_level - 0.8, "F", **font_formula, ha='center')
# F on right
ax.plot([start_x + 4.2, start_x + 4.5], [y_level, y_level], 'k-', lw=2)
ax.text(start_x + 4.7, y_level, "F", **font_formula, ha='center')
# --- PART 3: THE ARROW ---
ax.arrow(5.0, 4.0, 0, -1.0, head_width=0.3, head_length=0.3, fc='#0070C0', ec='#0070C0', lw=3)
# --- PART 4: THE PRODUCT (Bottom) ---
# Fluorinated Acrylate Monomer
prod_y = 1.5
# Acrylate part (Left side of product)
ax.text(1.5, prod_y + 1.0, "CH", **font_formula, ha='right')
ax.text(1.6, prod_y + 0.9, "2", fontsize=10, weight='bold', ha='left')
ax.plot([1.5, 1.5], [prod_y + 0.8, prod_y + 0.3], 'k-', lw=2)
ax.plot([1.4, 1.4], [prod_y + 0.8, prod_y + 0.3], 'k-', lw=2)
ax.text(1.5, prod_y + 0.1, "CH", **font_formula, ha='center')
ax.plot([1.5, 1.5], [prod_y - 0.1, prod_y - 0.6], 'k-', lw=2)
ax.text(1.5, prod_y - 0.8, "C", **font_formula, ha='center')
# Carbonyl O
ax.plot([1.4, 1.1], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2)
ax.plot([1.4, 1.1], [prod_y - 0.7, prod_y - 0.7], 'k-', lw=2)
ax.text(0.9, prod_y - 0.8, "O", **font_formula, ha='center', va='center')
# Ester Oxygen (Replacing the OH group interaction)
ax.plot([1.7, 2.0], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2)
ax.text(2.2, prod_y - 0.8, "O", **font_formula, ha='center', va='center')
# Link to Spacer
ax.plot([2.4, 3.5], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) # Long bond to accommodate layout
# Fluorinated Chain (Right side of product)
# -CH2-
ax.text(3.8, prod_y - 0.8, "CH", **font_formula, ha='center', va='center')
ax.text(4.05, prod_y - 0.9, "2", fontsize=10, weight='bold')
ax.plot([4.2, 4.6], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2)
# -CH2-
ax.text(4.9, prod_y - 0.8, "CH", **font_formula, ha='center', va='center')
ax.text(5.15, prod_y - 0.9, "2", fontsize=10, weight='bold')
ax.plot([5.3, 5.7], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2)
# -CF2-
ax.text(6.0, prod_y - 0.8, "C", **font_formula, ha='center', va='center')
# F top/bottom
ax.plot([6.0, 6.0], [prod_y - 0.6, prod_y - 0.3], 'k-', lw=2)
ax.text(6.0, prod_y - 0.1, "F", **font_formula, ha='center')
ax.plot([6.0, 6.0], [prod_y - 1.0, prod_y - 1.3], 'k-', lw=2)
ax.text(6.0, prod_y - 1.6, "F", **font_formula, ha='center')
ax.plot([6.3, 6.7], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2)
# -CF3
ax.text(7.0, prod_y - 0.8, "C", **font_formula, ha='center', va='center')
# F top/bottom
ax.plot([7.0, 7.0], [prod_y - 0.6, prod_y - 0.3], 'k-', lw=2)
ax.text(7.0, prod_y - 0.1, "F", **font_formula, ha='center')
ax.plot([7.0, 7.0], [prod_y - 1.0, prod_y - 1.3], 'k-', lw=2)
ax.text(7.0, prod_y - 1.6, "F", **font_formula, ha='center')
# F right
ax.plot([7.2, 7.5], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2)
ax.text(7.7, prod_y - 0.8, "F", **font_formula, ha='center', va='center')
# Save the figure
plt.savefig("reaction_diagram.png", bbox_inches='tight', dpi=300)
plt.close()
if __name__ == "__main__":
draw_diagram()
print("Diagram generated as reaction_diagram.png")