感知机学习算法

本文是读李航博士《统计学习方法》第二章的笔记总结。
感知机是用于二类分类的线性分类器,如果数据线性不可分,我觉得可以采用提升数据维度的方法来使得数据在更高的维度上线性可分。如果要用于多类别分类,可以循环使用感知机,每次分出一个类。
本文主要探讨两个问题,1.感知机是如何用于数据线性分类的? 2.怎么训练一个感知机?
先说问题1,感知机,对于每一个输入向量x(x为一条数据,内有多个维度的信息如x1,x2,…xn),会反馈输出Y={+1,-1},借此来对数据进行分类。由输入X到输出Y中间发生了什么呢?
x通过sign函数转换为了y输出:
gzj-1

W为权重向量,b为偏置,w和b统称为感知机参数,wx+b是一个线性方程,可以理解为空间中的一个平面(当数据只有二维的时候就是二维平面上的一条线,如下图),那么落在这条线下方的数据就是y为-1 落在这条线上方的数据就是y为+1。
gzj-2

 

接下来说问题2 ,我们通过什么样的方式来得到这个平面呢?这里要引入一个概念,cost function 损失函数。举个比较简单的栗子来解释下什么是损失函数,假设小胖和小黑面前有一大堆西瓜,要挑出比较甜的那些瓜,两个人每次挑一个给路人甲尝路人甲判定甜不甜,那么两个人开始挑啊挑啊,他们制定了一个规则,如果挑出的西瓜不甜,就要给对方脸上贴一个白条。最后小黑脸上有3个白条,小胖脸上有10个白条。那么就说明小黑挑西瓜挑的好,为啥?因为白条少啊!这个白条就是损失函数,小黑和小胖就是两个分类器。由此可见,在挑西瓜的分类中,小黑的效果比小胖好。
在感知分类器中,cost function是在训练过程中不断累加的。我们将误分类的数据集喂给cost function,拿二维的数据来说,我们误分类点到线的距离作为cost function的值,高维度的数据就是点到平面的距离d。

gzj-3

L为损失函数,可以进一步简化,把分母去掉,M为误分类点集合,这就是感知学习的经验风险函数。
如果误分类点离得越远,那么“惩罚”的就应该越厉害,因为这个点分类的结果非常不好。
那么为什么要引入损失函数的概念呢?损失函数主要用于评估模型的优劣,损失函数越小,说明模型对训练集分类的效果越好(训练集的质量不在本章讨论范围)
因此,我们要最小化L,求得最优的w和b。这里采用的是梯度下降的算法,对每个误分类点进行更新,直到损失函数为0。
学习率一般设定为0至1之间,
算法流程如下:

  1. 选取w,b初值
  2. 在训练集中选取数据(xi,yi)
  3. 如果yi(w*xi+b)<=0:      gzj-4
  4. 转到2 直至训练集中没有误分类点

Java代码如下(用惯Python,一下子转写Java突然有点不习惯,写的很烂)

public static void main(String args[]){
		double [][]datax = {{1,2,3,4},{1,2,3,5},{2,2,3,4},{3,2,3,4}};
		int []datay = {1,1,-1,-1};
		double []w = {0,0,0,0};
		double []result = perceptron(0,w,0.5,datax,datay);
		for (int i = 0; i<result.length; i++){
			System.out.println(result[i]);
		}
		System.out.println(predict(result, new double[]{1,2,3,4}));
	}
	
/**
 * 
   * @Name: perceptron 感知机 训练函数
   * @Description: @param b 初始偏置
   * @Description: @param w 初始权重
   * @Description: @param ls 学习速率
   * @Description: @param datax 训练集X
   * @Description: @param datay X对应的Y
   * @Description: @return 返回一个参数列表result  result为(w0,w1,...wn,b)
 */
public static double[] perceptron (double b , double w[], double ls, double [][]datax, int []datay){
	double []parameter = new double[w.length+1];
	double temp = 0;
	int flag = 0;	//循环结束标志
	int number = 0;
	while (flag != 2) {
		System.out.println("第"+ (number++) +"次循环");
		//遍历训练集
		for (int i = 0; i<datax.length; i++){
			//数据与权重相乘
			for (int j = 0 ; j < datax[i].length ; j++){
				 temp += datax[i][j]*w[j];
			}
		    temp = temp + b;
		    //判定是否为误分类
			if (temp*datay[i] < 0 || temp*datay[i] == 0){
				flag = 1;
		    	for (int k = 0 ; k <w.length ; k++){
		    		w[k] = w[k]+ls*datay[i]*datax[i][k];
		    	}
		    	b = b +ls*datay[i];
			}else if (i == datax.length-1) {
				flag = 2;
			}
		}
	}
	
	for (int i = 0; i < parameter.length; i++) {
		if (i < parameter.length-1){
			parameter[i] = w[i];
		}else {
			parameter[i] = b;
		}
	}
	return parameter;
}

/**
 * 
   * @Name: predict 预测函数
   * @Description: @param parameter 参数w,b
   * @Description: @param x 输入测试
   * @Description: @return 返回测试结果 0在平面上,1 正例,-1 负例
 */
public static int predict(double []parameter,double []x){
	int result = 0;
	int temp = 0;
	for (int i = 0; i < x.length; i++) {
		temp += x[i]*parameter[i];
	}
	temp += parameter[parameter.length-1];
	if (temp > 0 ){
		result = 1;
	}
	else if (temp < 0){
		result = -1;
	}else if (temp == 0) {
		result = 0;
	}
	return result;
}

 

——Snake

snake

作者: snake

我们需要为这个社会做一点贡献,失去了才懂得去珍惜。

《感知机学习算法》有4个想法

发表评论

电子邮件地址不会被公开。 必填项已用*标注