关闭 x
IT技术网
    技 采 号
    ITJS.cn - 技术改变世界
    • 实用工具
    • 菜鸟教程
    IT采购网 中国存储网 科技号 CIO智库

    IT技术网

    IT采购网
    • 首页
    • 行业资讯
    • 系统运维
      • 操作系统
        • Windows
        • Linux
        • Mac OS
      • 数据库
        • MySQL
        • Oracle
        • SQL Server
      • 网站建设
    • 人工智能
    • 半导体芯片
    • 笔记本电脑
    • 智能手机
    • 智能汽车
    • 编程语言
    IT技术网 - ITJS.CN
    首页 » 算法设计 »AdaBoost算法分析与实现

    AdaBoost算法分析与实现

    2014-12-04 00:00:00 出处:zhanlijun
    分享

    AdaBoost(自适应boosting,adaptive boosting)算法

    算法优缺点:

    优点:泛化错误率低,易编码,可用在绝大部分分类器上,无参数调整 缺点:对离群点敏感 适用数据类型:数值型和标称型

    元算法(meta algorithm)

    在分类问题中,我们可能不会只想用一个分类器,我们会考虑将分类器组合起来使用,这种方法称为集成方法(ensemble method)或元算法。元算法有多种形式,既可以是不同算法集成也可以是一种算法不同设置的集成。

    两种集成方式(bagging & boosting)

    bagging方法也称自举汇聚法(Bootstrap aggregating)。思路相当于是从数据集中随机抽样得到新的数据集,然后用新的数据集进行训练,最后的结果是新的数据集形成的分类器中的最多的类别。如从1000个样本组成的数据集中进行有放回的抽样5000次,得到5个新的训练集,将算法分别用到这五个训练集上从而得到五个分类器。 boosting则是一种通过串行训练得到结果的方法,在bagging中每个分类器的权重一样,而boosting中分类器的权重则与上一轮的成功度有关。

    AdaBoost

    是一种用的最多的boosting,想法就是下一次的迭代中,将上一次成功的样本的权重降低,失败的权重升高。权重变化方式:

    alpha(分类器权重)的变化: 数据权重变化: 正确分类的话: 错误分类的话

    实现思路:

    AdaBoost算法实现的是将弱分类器提升成为强分类器,所以这里我们首先要有一个弱分类器,代码中使用的是单层决策树,这也是使用的最多的弱分类器,然后我们就可以根据弱分类器构造出强分类器

    函数:

    stumpClassify(dataMatrix,dimen,threshVal,threshIneq)

    单层决策树的分类器,根据输入的值与阀值进行比较得到输出结果,因为是单层决策树,所以只能比较数据一个dimen的值

    buildStump(dataArr,classLabels,D)

    构造单层决策树,这部分的构造的思路和前面的决策树是一样的,只是这里的评价体系不是熵而是加权的错误率,这里的加权是通过数据的权重D来实现的,每一次build权重都会因上一次分类结果不同而不同。返回的单层决策树的相关信息存在字典结构中方便接下来的使用

    adaBoostTrainDS(dataArr,classLabels,numIt=40)

    AdaBoost的训练函数,用来将一堆的单层决策树组合起来形成结果。通过不断调整alpha和D来使得错误率不断趋近0,甚至最终达到0

    adaClassify(datToClass,classifierArr)

    分类函数,datToClass是要分类的数据,根据生成的一堆单层决策树的分类结果,加权得到最终结果。

    #coding=utf-8
    from numpy import *
    def loadSimpleData():
        dataMat = matrix([[1. , 2.1],
            [2. , 1.1],
            [1.3 , 1.],
            [1. , 1.],
            [2. , 1.]])
        classLabels = [1.0,1.0,-1.0,-1.0,1.0]
        return dataMat, classLabels
    
    def stumpClassify(dataMatrix,dimen,threshVal,threshIneq):
        retArry = ones((shape(dataMatrix)[0],1))
        if threshIneq == 'lt':
            retArry[dataMatrix[:,dimen] <= threshVal] = -1.0
        else:
            retArry[dataMatrix[:,dimen] > threshVal] = -1.0
        return retArry
    
    #D是权重向量
    def buildStump(dataArr,classLabels,D):
        dataMatrix = mat(dataArr)
        labelMat = mat(classLabels).T
        m,n = shape(dataMatrix)
        numSteps = 10.0#在特征所有可能值上遍历
        bestStump = {}#用于存储单层决策树的信息
        bestClasEst = mat(zeros((m,1)))
        minError = inf
        for i in range(n):#遍历所有特征
            rangeMin = dataMatrix[:,i].min()
            rangeMax = dataMatrix[:,i].max()
            stepSize = (rangeMax - rangeMin) / numSteps
            for j in range(-1,int(numSteps)+1):
                for inequal in ['lt','gt']:
                    threshVal = (rangeMin + float(j) * stepSize)#得到阀值
                    #根据阀值分类
                    predictedVals = stumpClassify(dataMatrix,i,threshVal,inequal)
                    errArr = mat(ones((m,1)))
                    errArr[predictedVals == labelMat] = 0
                    weightedError = D.T * errArr#不同样本的权重是不一样的
                    #print "split: dim %d, thresh %.2f, thresh ineqal: %s, the weighted error is %.3f" % (i, threshVal, inequal, weightedError)
                    if weightedError < minError:
                        minError = weightedError
                        bestClasEst = predictedVals.copy()
                        bestStump['dim'] = i 
                        bestStump['thresh'] = threshVal
                        bestStump['ineq'] = inequal
        return bestStump,minError,bestClasEst
    
    def adaBoostTrainDS(dataArr,classLabels,numIt=40):
        weakClassArr = []
        m =shape(dataArr)[0]
        D = mat(ones((m,1))/m)#初始化所有样本的权值一样
        aggClassEst = mat(zeros((m,1)))#每个数据点的估计值
        for i in range(numIt):
            bestStump,error,classEst = buildStump(dataArr,classLabels,D)
            #计算alpha,max(error,1e-16)保证没有错误的时候不出现除零溢出
            #alpha表示的是这个分类器的权重,错误率越低分类器权重越高
            alpha = float(0.5*log((1.0-error)/max(error,1e-16)))
            bestStump['alpha'] = alpha  
            weakClassArr.append(bestStump)
            expon = multiply(-1*alpha*mat(classLabels).T,classEst) #exponent for D calc, getting messy
            D = multiply(D,exp(expon))                              #Calc New D for next iteration
            D = D/D.sum()
            #calc training error of all classifiers, if this is 0 quit for loop early (use break)
            aggClassEst += alpha*classEst
            #print "aggClassEst: ",aggClassEst.T
            aggErrors = multiply(sign(aggClassEst) != mat(classLabels).T,ones((m,1)))
            errorRate = aggErrors.sum()/m
            print "total error: ",errorRate
            if errorRate == 0.0: 
                break
        return weakClassArr
    
    #dataToClass 表示要分类的点或点集
    def adaClassify(datToClass,classifierArr):
        dataMatrix = mat(datToClass)#do stuff similar to last aggClassEst in adaBoostTrainDS
        m = shape(dataMatrix)[0]
        aggClassEst = mat(zeros((m,1)))
        for i in range(len(classifierArr)):
            classEst = stumpClassify(dataMatrix,classifierArr[i]['dim'],
                                     classifierArr[i]['thresh'],
                                     classifierArr[i]['ineq'])#call stump classify
            aggClassEst += classifierArr[i]['alpha']*classEst
            print aggClassEst
        return sign(aggClassEst)
    
    def main():
        dataMat,classLabels = loadSimpleData()
        D = mat(ones((5,1))/5)
        classifierArr = adaBoostTrainDS(dataMat,classLabels,30)
        t = adaClassify([0,0],classifierArr)
        print t 
    
    if __name__ == '__main__':
        main()
    上一篇返回首页 下一篇

    声明: 此文观点不代表本站立场;转载务必保留本文链接;版权疑问请联系我们。

    别人在看

    帝国CMS7.5编辑器上传图片取消宽高的三种方法

    帝国cms如何自动生成缩略图的实现方法

    Windows 12即将到来,将彻底改变人机交互

    帝国CMS 7.5忘记登陆账号密码怎么办?可以phpmyadmin中重置管理员密码

    帝国CMS 7.5 后台编辑器换行,修改回车键br换行为p标签

    Windows 11 版本与 Windows 10比较,新功能一览

    Windows 11激活产品密钥收集及专业版激活方法

    如何从 Windows 11 中完全删除/卸载 OneNote?无解!

    抖音安全与信任开放日:揭秘推荐算法,告别单一标签依赖

    ultraedit编辑器打开文件时,总是提示是否转换为DOS格式,如何关闭?

    IT头条

    华为Pura80系列新机预热,余承东力赞其复杂光线下的视频拍摄实力

    01:28

    阿里千问3开源首战告捷:全球下载破千万,国产AI模型崛起新高度!

    01:22

    DeepSeek R1小版本试升级:网友实测编程能力已达到国际一线水平

    23:15

    NVIDIA 与 Dell 合作,大规模交付 Blackwell AI 系统

    20:52

    Cerebras 以最快的 Llama 4 Maverick 性能引领 LLM 推理竞赛

    20:51

    技术热点

    PHP中的随机性——你觉得自己幸运吗?

    搞定Ubuntu Linux下WPA无线上网

    Java使用内存映射实现大文件的上传

    MySQL安全性指南

    MySQL两项性能的基本测试浅谈

    教您使用UniqueIdentifier选取SQL Server主键

      友情链接:
    • IT采购网
    • 科技号
    • 中国存储网
    • 存储网
    • 半导体联盟
    • 医疗软件网
    • 软件中国
    • ITbrand
    • 采购中国
    • CIO智库
    • 考研题库
    • 法务网
    • AI工具网
    • 电子芯片网
    • 安全库
    • 隐私保护
    • 版权申明
    • 联系我们
    IT技术网 版权所有 © 2020-2025,京ICP备14047533号-20,Power by OK设计网

    在上方输入关键词后,回车键 开始搜索。Esc键 取消该搜索窗口。