在模型训练中,使用交叉验证(Cross-Validation)是确保参数选择合理、模型泛化能力强的有效方式。具体来说,通过 网格搜索(Grid Search) 或 随机搜索(Randomized Search) 与交叉验证结合,可以系统地尝试不同的参数组合并评估模型性能。
下面具体讲解如何使用交叉验证来选择最优的 penalty
和 solver
参数,从而提升模型的准确率。
步骤 1:准备数据
假设你已经有了自变量(X)和因变量(y),并且数据已经被预处理好(如归一化、缺失值处理等)。使用 train_test_split
来划分训练集和测试集:
from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
步骤 2:定义参数网格
为了选择最佳的 penalty
和 solver
,我们可以先定义一个参数网格,包含不同的参数组合。比如针对 penalty
和 solver
参数,你可以定义如下网格:
param_grid = {'penalty': ['l1', 'l2'], # L1 和 L2 正则化'solver': ['liblinear', 'saga', 'lbfgs'], # 选择支持 L1 和 L2 正则化的优化器'max_iter': [1000, 2000] # 迭代次数
}
步骤 3:使用 GridSearchCV 进行交叉验证
GridSearchCV
会自动遍历你定义的参数组合,并在每个组合上进行交叉验证,最终返回表现最好的参数组合。
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression# 定义逻辑回归模型
model = LogisticRegression(multi_class='ovr')# 定义 GridSearchCV,使用5折交叉验证
grid_search = GridSearchCV(model, param_grid, cv=5, scoring='accuracy')# 在训练集上进行网格搜索
grid_search.fit(X_train, y_train)# 输出最好的参数组合
print("Best parameters found: ", grid_search.best_params_)
print("Best cross-validation score: {:.2f}".format(grid_search.best_score_))
步骤 4:在测试集上评估最优模型
在找到最优的参数组合后,我们可以在测试集上评估这个模型的性能,看看交叉验证选出的最优参数是否能提升模型的准确率。
# 使用最优参数的模型在测试集上评估
best_model = grid_search.best_estimator_
test_score = best_model.score(X_test, y_test)print("Test accuracy with best parameters: {:.2f}".format(test_score))
步骤 5:调整或扩展参数网格
如果发现模型的准确率还有提升空间,可以考虑:
- 增加更多参数组合:例如,调整
C
参数(正则化强度)、multi_class
参数(选择'multinomial'
)。 - 更换优化器或其他超参数:例如,增加
'newton-cg'
或'sag'
作为候选优化器。
param_grid_extended = {'penalty': ['l1', 'l2'],'solver': ['liblinear', 'saga', 'lbfgs', 'newton-cg'],'C': [0.01, 0.1, 1, 10], # 正则化强度'max_iter': [1000, 2000],'multi_class': ['ovr', 'multinomial'] # 多分类方式
}
然后重复上述的 GridSearchCV
步骤来进一步优化模型。
总结
- 定义参数网格:指定你要调优的参数(如
penalty
和solver
)。 - 使用交叉验证进行搜索:通过
GridSearchCV
遍历每个参数组合,并进行多次训练评估,找到最优参数组合。 - 在测试集上评估模型:用最优参数组合的模型预测测试集,评估准确率。
- 根据需要扩展网格或调整其他参数:进一步优化模型。
这样做的好处是可以确保参数的选择不是基于特定的训练集,而是经过交叉验证后具有较好的泛化能力,从而提高模型的准确率。