问题导读
1.什么是网格搜索?
2.网格搜索本文举了什么例子?
上一篇
机器学习教程 九-二元分类效果的评估方法
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19107
任何一种机器学习模型都附带很多参数,不同场景对应不同的最佳参数,手工尝试各种参数无疑浪费很多时间,scikit-learn帮我们实现了自动化,那就是网格搜索
网格搜索 这里的网格指的是不同参数不同取值交叉后形成的一个多维网格空间。比如参数a可以取1、2,参数b可以取3、4,参数c可以取5、6,那么形成的多维网格空间就是:
[mw_shl_code=bash,true]
1、3、5
1、3、6
1、4、5
1、4、6
2、3、5
2、3、6
2、4、5
2、4、6
[/mw_shl_code]
一共2*2*2=8种情况
网格搜索就是遍历这8种情况进行模型训练和验证,最终选择出效果最优的参数组合
用法举例
[mw_shl_code=bash,true]# coding:utf-8
import sys
reload(sys)
sys.setdefaultencoding( "utf-8" )
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model.logistic import LogisticRegression
from sklearn.grid_search import GridSearchCV
from sklearn.pipeline import Pipeline
# 构造样本,这块得多构造点,不然会报class不足的错误,因为gridsearch会拆分成小组
X = []
X.append("fuck you")
X.append("fuck you all")
X.append("hello everyone")
X.append("fuck me")
X.append("hello boy")
X.append("fuck you")
X.append("fuck you all")
X.append("hello everyone")
X.append("fuck me")
X.append("hello boy")
X.append("fuck you")
X.append("fuck you all")
X.append("hello everyone")
X.append("fuck me")
X.append("hello boy")
X.append("fuck you")
X.append("fuck you all")
X.append("hello everyone")
X.append("fuck me")
X.append("hello boy")
X.append("fuck you")
X.append("fuck you all")
X.append("hello everyone")
X.append("fuck me")
X.append("hello boy")
y = [1,0,1,0,1,1,0,1,0,1,1,0,1,0,1,1,0,1,0,1,1,0,1,0,1]
# 这是执行的序列,gridsearch是构造多进程顺序执行序列并比较结果
# 这里的vect和clf名字自己随便起,但是要和parameters中的前缀对应
pipeline = Pipeline([
('vect', TfidfVectorizer(stop_words='english')),
('clf', LogisticRegression())
])
# 这里面的max_features必须是TfidfVectorizer的参数, 里面的取值就是子进程分别执行所用
parameters = {
'vect__max_features': (3, 5),
}
# accuracy表示按精确度判断最优值
grid_search = GridSearchCV(pipeline, parameters, n_jobs = -1, verbose = 1, scoring = 'accuracy', cv = 3)
grid_search.fit(X, y)
print '最佳效果: %0.3f' % grid_search.best_score_
print '最优参数组合: '
best_parameters = grid_search.best_estimator_.get_params()
for param_name in sorted(parameters.keys()):
print('\t%s: %r' % (param_name, best_parameters[param_name]))[/mw_shl_code]
执行结果如下:
[mw_shl_code=bash,true]Fitting 3 folds for each of 2 candidates, totalling 6 fits
[Parallel(n_jobs=-1)]: Done 7 out of 6 | elapsed: 0.0s remaining: -0.0s
[Parallel(n_jobs=-1)]: Done 7 out of 6 | elapsed: 0.1s remaining: -0.0s
[Parallel(n_jobs=-1)]: Done 7 out of 6 | elapsed: 0.1s remaining: -0.0s
[Parallel(n_jobs=-1)]: Done 7 out of 6 | elapsed: 0.1s remaining: -0.0s
[Parallel(n_jobs=-1)]: Done 7 out of 6 | elapsed: 0.1s remaining: -0.0s
[Parallel(n_jobs=-1)]: Done 6 out of 6 | elapsed: 0.1s finished
最佳效果: 0.800
最优参数组合:
vect__max_features: 3[/mw_shl_code]
这里面并行启动了6个任务,最终判断出max_features的最优解值是3
来源网站shareditor
相关文章
机器学习教程 一-不懂这些线性代数知识 别说你是搞机器学习的
http://www.aboutyun.com/forum.php?mod=viewthread&tid=18997
机器学习教程 二-安装octave绘制3D函数图像
http://www.aboutyun.com/thread-19006-1-1.html
机器学习教程 三-用scikit-learn求解一元线性回归问题
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19020
机器学习教程 四-用scikit-learn求解多元线性回归问题
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19042
机器学习教程 五-用matplotlib绘制精美的图表
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19060
机器学习教程 六-用scikit-learn求解多项式回归问题
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19073
机器学习教程 七-用随机梯度下降法(SGD)做线性拟合
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19086
机器学习教程 八-用scikit-learn做特征提取
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19095
机器学习教程 九-二元分类效果的评估方法
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19107
机器学习教程十-用scikit-learn的网格搜索快速找到最优模型参数
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19120
机器学习教程 十一-用scikit-learn做聚类分析大数据
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19129
机器学习教程 十二-神经网络模型的原理 大数据
http://www.aboutyun.com/forum.php?mod=viewthread&tid=19339
链接http://www.shareditor.com/blogshow/?blogId=60
|