From 4fea502e6b46bda09d16a87b50fdc68cb2ae1aee Mon Sep 17 00:00:00 2001 From: newbie Date: Sun, 14 Dec 2025 23:24:52 +0800 Subject: [PATCH] Update Qnn model data preparation to include unsqueeze for tensor dimensions; replace Qmlp with QCNN in main.py; add diagram generation script --- Qtorch/Models/Qnn.py | 4 +- main.py | 20 ++--- test.py | 171 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 183 insertions(+), 12 deletions(-) create mode 100644 test.py diff --git a/Qtorch/Models/Qnn.py b/Qtorch/Models/Qnn.py index ca646be..adcf79f 100644 --- a/Qtorch/Models/Qnn.py +++ b/Qtorch/Models/Qnn.py @@ -40,11 +40,11 @@ class Qnn(nn.Module): def __prepare_data(self): # 将data转换为tensor形式 - X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32) + X_train_tensor = torch.tensor(self.X_train, dtype=torch.float32).unsqueeze(1) self.y_train = self.LABEL_ENCODER.fit_transform(self.y_train) y_train_tensor = torch.tensor(self.y_train, dtype=torch.long) - X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32) + X_test_tensor = torch.tensor(self.X_test, dtype=torch.float32).unsqueeze(1) self.y_test = self.LABEL_ENCODER.transform(self.y_test) y_test_tensor = torch.tensor(self.y_test, dtype=torch.long) diff --git a/main.py b/main.py index f426171..fb27ed9 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,4 @@ -from Qtorch.Models.Qmlp import Qmlp +from Qtorch.Models.Qcnn import QCNN from Qfunctions.divSet import divSet from Qfunctions.loaData import load_data from Qfunctions.saveToxlsx import save_to_xlsx as save_to_xlsx @@ -14,16 +14,16 @@ def main(): data=data, labels=label_names, test_size= 0.3 ) - model = Qmlp( - X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test, - hidden_layers = [16], - dropout_rate=0 - ) + # model = Qmlp( + # X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test, + # hidden_layers = [16], + # dropout_rate=0 + # ) - # model = QCNN - # X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test, - # dropout_rate=0 - # ) + model = QCNN( + X_train=X_train, X_test=X_test, y_train=y_train, y_test= y_test, + dropout_rate=0 + ) pca_2d, pca_3d = model.get_PCA() diff --git a/test.py b/test.py new file mode 100644 index 0000000..3f64a74 --- /dev/null +++ b/test.py @@ -0,0 +1,171 @@ +import matplotlib.pyplot as plt +import matplotlib.patches as patches + +def draw_diagram(): + # Setup the figure + fig, ax = plt.subplots(figsize=(10, 8)) + ax.set_xlim(0, 10) + ax.set_ylim(0, 8) + ax.axis('off') # Turn off axis + + # Font settings + font_formula = {'family': 'sans-serif', 'weight': 'bold', 'size': 14} + font_label = {'family': 'sans-serif', 'weight': 'bold', 'size': 16} + + # --- PART 1: TOP LEFT (The "Acrylate" Base) --- + # Replacing the SiO2 Block with an Acrylate Group structure + # We will draw a stylized Acrylic Acid molecule acting as the "Anchor" + + # Label "A" + ax.text(1.0, 7.5, "A", **font_label) + + # Label "Acrylate" (replaces SiO2) + ax.text(1.0, 4.5, "Acrylate\n(丙烯酸酯)", ha='center', va='center', fontsize=14, weight='bold') + + # Draw Acrylate structure (Vertical orientation to mimic the wall) + # CH2=CH-C(=O)-OH + # C=C + ax.text(1.5, 6.5, "CH", **font_formula, ha='right') + ax.text(1.6, 6.4, "2", fontsize=10, weight='bold', ha='left') # subscript + + ax.plot([1.5, 1.5], [6.3, 5.8], 'k-', lw=2) # Double bond line 1 + ax.plot([1.4, 1.4], [6.3, 5.8], 'k-', lw=2) # Double bond line 2 + + ax.text(1.5, 5.6, "CH", **font_formula, ha='center') + + ax.plot([1.5, 1.5], [5.4, 4.9], 'k-', lw=2) # Single bond + + ax.text(1.5, 4.7, "C", **font_formula, ha='center') + + # Carbonyl O + ax.plot([1.4, 1.1], [4.7, 4.7], 'k-', lw=2) # Double bond to O (sideways) + ax.plot([1.4, 1.1], [4.8, 4.8], 'k-', lw=2) + ax.text(0.9, 4.7, "O", **font_formula, ha='center', va='center') + + # Hydroxyl OH (The reactive site) + ax.plot([1.6, 1.9], [4.7, 4.7], 'k-', lw=2) + ax.text(2.1, 4.7, "OH", **font_formula, ha='left', va='center') + + + # --- PART 2: TOP RIGHT (The Fluorinated Alcohol) --- + # Replacing FDTES with Perfluoroethylethanol (C2F5-CH2-CH2-OH) + # Structure: HO - CH2 - CH2 - CF2 - CF3 + + # The Plus Sign + ax.text(3.5, 5.5, "+", fontsize=40, weight='bold', color='#0070C0', ha='center') + + # Start coordinates for alcohol + start_x = 4.5 + y_level = 5.5 + + # HO- + ax.text(start_x, y_level, "HO", **font_formula, ha='right') + ax.plot([start_x + 0.1, start_x + 0.5], [y_level, y_level], 'k-', lw=2) + + # -CH2- + ax.text(start_x + 0.8, y_level, "CH", **font_formula, ha='center') + ax.text(start_x + 1.05, y_level-0.1, "2", fontsize=10, weight='bold') + ax.plot([start_x + 1.2, start_x + 1.6], [y_level, y_level], 'k-', lw=2) + + # -CH2- + ax.text(start_x + 1.9, y_level, "CH", **font_formula, ha='center') + ax.text(start_x + 2.15, y_level-0.1, "2", fontsize=10, weight='bold') + ax.plot([start_x + 2.3, start_x + 2.7], [y_level, y_level], 'k-', lw=2) + + # -CF2- (Perfluoro group starts) + ax.text(start_x + 3.0, y_level, "C", **font_formula, ha='center') + # F on top + ax.plot([start_x + 3.0, start_x + 3.0], [y_level + 0.2, y_level + 0.5], 'k-', lw=2) + ax.text(start_x + 3.0, y_level + 0.6, "F", **font_formula, ha='center') + # F on bottom + ax.plot([start_x + 3.0, start_x + 3.0], [y_level - 0.2, y_level - 0.5], 'k-', lw=2) + ax.text(start_x + 3.0, y_level - 0.8, "F", **font_formula, ha='center') + + ax.plot([start_x + 3.3, start_x + 3.7], [y_level, y_level], 'k-', lw=2) + + # -CF3 (End of chain) + ax.text(start_x + 4.0, y_level, "C", **font_formula, ha='center') + # F on top + ax.plot([start_x + 4.0, start_x + 4.0], [y_level + 0.2, y_level + 0.5], 'k-', lw=2) + ax.text(start_x + 4.0, y_level + 0.6, "F", **font_formula, ha='center') + # F on bottom + ax.plot([start_x + 4.0, start_x + 4.0], [y_level - 0.2, y_level - 0.5], 'k-', lw=2) + ax.text(start_x + 4.0, y_level - 0.8, "F", **font_formula, ha='center') + # F on right + ax.plot([start_x + 4.2, start_x + 4.5], [y_level, y_level], 'k-', lw=2) + ax.text(start_x + 4.7, y_level, "F", **font_formula, ha='center') + + + # --- PART 3: THE ARROW --- + ax.arrow(5.0, 4.0, 0, -1.0, head_width=0.3, head_length=0.3, fc='#0070C0', ec='#0070C0', lw=3) + + + # --- PART 4: THE PRODUCT (Bottom) --- + # Fluorinated Acrylate Monomer + + prod_y = 1.5 + + # Acrylate part (Left side of product) + ax.text(1.5, prod_y + 1.0, "CH", **font_formula, ha='right') + ax.text(1.6, prod_y + 0.9, "2", fontsize=10, weight='bold', ha='left') + + ax.plot([1.5, 1.5], [prod_y + 0.8, prod_y + 0.3], 'k-', lw=2) + ax.plot([1.4, 1.4], [prod_y + 0.8, prod_y + 0.3], 'k-', lw=2) + + ax.text(1.5, prod_y + 0.1, "CH", **font_formula, ha='center') + + ax.plot([1.5, 1.5], [prod_y - 0.1, prod_y - 0.6], 'k-', lw=2) + + ax.text(1.5, prod_y - 0.8, "C", **font_formula, ha='center') + + # Carbonyl O + ax.plot([1.4, 1.1], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) + ax.plot([1.4, 1.1], [prod_y - 0.7, prod_y - 0.7], 'k-', lw=2) + ax.text(0.9, prod_y - 0.8, "O", **font_formula, ha='center', va='center') + + # Ester Oxygen (Replacing the OH group interaction) + ax.plot([1.7, 2.0], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) + ax.text(2.2, prod_y - 0.8, "O", **font_formula, ha='center', va='center') + + # Link to Spacer + ax.plot([2.4, 3.5], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) # Long bond to accommodate layout + + # Fluorinated Chain (Right side of product) + # -CH2- + ax.text(3.8, prod_y - 0.8, "CH", **font_formula, ha='center', va='center') + ax.text(4.05, prod_y - 0.9, "2", fontsize=10, weight='bold') + ax.plot([4.2, 4.6], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) + + # -CH2- + ax.text(4.9, prod_y - 0.8, "CH", **font_formula, ha='center', va='center') + ax.text(5.15, prod_y - 0.9, "2", fontsize=10, weight='bold') + ax.plot([5.3, 5.7], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) + + # -CF2- + ax.text(6.0, prod_y - 0.8, "C", **font_formula, ha='center', va='center') + # F top/bottom + ax.plot([6.0, 6.0], [prod_y - 0.6, prod_y - 0.3], 'k-', lw=2) + ax.text(6.0, prod_y - 0.1, "F", **font_formula, ha='center') + ax.plot([6.0, 6.0], [prod_y - 1.0, prod_y - 1.3], 'k-', lw=2) + ax.text(6.0, prod_y - 1.6, "F", **font_formula, ha='center') + + ax.plot([6.3, 6.7], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) + + # -CF3 + ax.text(7.0, prod_y - 0.8, "C", **font_formula, ha='center', va='center') + # F top/bottom + ax.plot([7.0, 7.0], [prod_y - 0.6, prod_y - 0.3], 'k-', lw=2) + ax.text(7.0, prod_y - 0.1, "F", **font_formula, ha='center') + ax.plot([7.0, 7.0], [prod_y - 1.0, prod_y - 1.3], 'k-', lw=2) + ax.text(7.0, prod_y - 1.6, "F", **font_formula, ha='center') + # F right + ax.plot([7.2, 7.5], [prod_y - 0.8, prod_y - 0.8], 'k-', lw=2) + ax.text(7.7, prod_y - 0.8, "F", **font_formula, ha='center', va='center') + + # Save the figure + plt.savefig("reaction_diagram.png", bbox_inches='tight', dpi=300) + plt.close() + +if __name__ == "__main__": + draw_diagram() + print("Diagram generated as reaction_diagram.png") \ No newline at end of file