Deeplearning/Qfunctions/divSet.py

29 lines
847 B
Python
Raw Normal View History

2024-10-07 09:54:32 +08:00
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
def divSet(data, labels=None, test_size=0.2, random_state=None):
    """Split a labelled table into standardized train/test sets with encoded labels.

    The last column of ``data`` is treated as the label column; every other
    column is a feature.

    Args:
        data: Table supporting ``.iloc`` (presumably a pandas DataFrame) whose
            last column holds the labels.
        labels: Optional collection of class labels used to fit the label
            encoder (e.g. the full set of classes, so that a split missing a
            class still encodes consistently). When ``None`` or empty, the
            encoder is fitted on the label column of ``data`` instead.
        test_size: Forwarded to ``train_test_split``.
        random_state: Forwarded to ``train_test_split`` for reproducible splits.

    Returns:
        Tuple ``(X_train, X_test, y_train, y_test, encoder)``:
        ``X_train``/``X_test`` are the features standardized by a
        ``StandardScaler`` fitted on the training split only;
        ``y_train``/``y_test`` are the labels transformed by ``encoder``;
        ``encoder`` is the fitted ``LabelEncoder``.

    Raises:
        ValueError: From ``LabelEncoder.transform`` if the label column
            contains a value that is not among the labels the encoder was
            fitted on.
    """
    encoder = LabelEncoder()
    # The last column holds the labels; everything else is a feature.
    X = data.iloc[:, :-1]
    y = data.iloc[:, -1]
    # Explicit None/len check: a bare ``if labels:`` raises for numpy arrays
    # and pandas Series ("truth value ... is ambiguous").
    if labels is not None and len(labels) > 0:
        # Only the fitted encoder is needed; the encoded labels themselves
        # are not used.
        encoder.fit(labels)
    else:
        encoder.fit(y)
    # Split the dataset into training and test sets.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    # Standardize features. The scaler is fitted on the training split only,
    # then applied unchanged to the test split.
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)
    # Encode labels. LabelEncoder expects 1-D input; a column vector would
    # only trigger a DataConversionWarning before being flattened.
    y_train = encoder.transform(y_train.to_numpy())
    y_test = encoder.transform(y_test.to_numpy())
    return X_train, X_test, y_train, y_test, encoder