Kmean算法

发布时间:2023-12-19 19:30

Kmeans算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一。Kmeans算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。
假设要把样本集分为K个类别,算法描述如下:
(1)适当选择K个聚类的初始中心;
(2)在第K次迭代中,对任意一个样本,求其到K个中心的距离,将该样本归到距离最短的中心所在的簇;
(3)重新计算各个簇(聚类)的中心;
(4)对于所有的K个聚类中心,如果利用(2)(3)反复迭代,重新计算新旧中心的距离,若距离不变或小于某个阀值,则迭代结束。
该算法的最大优势在于简洁和快速。算法的关键在于初始中心的选择和距离公式。
这里最关键的地方就是初始中心K的选择,这里的选择好坏会很大程度上影响最终聚类的结果。
算法实现:下面简单利用距离计算方式演示聚类效果:
package com.wxshi.kmean;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;

/**
 * Kmeans算法演示
 * @author wxshi
 *
 */
public class Kmeans {

	private List> centers;
	private List> newCenters;
	private List>> clusterList;
	private int clusterNum = 5; //默认聚类的个数

	/**
	 * 默认构造不对外抛
	 */
	private Kmeans(){

	}

	public Kmeans(int clusterNum){
		if(clusterNum<=0){
			clusterNum = 5;
		}
		this.clusterNum = clusterNum;
		centers = new ArrayList>();
		newCenters = new ArrayList>();
		clusterList = new ArrayList>>();
	}

	/**
	 * 初始化簇,开始为空
	 * @param args
	 * @throws IOException
	 */
	public List>> initclusterList() {
		clusterList = new ArrayList>>();
		for (int i = 0; i < clusterNum; i++) {
			clusterList.add(new ArrayList>());
		}
		return clusterList;
	}

	/**
	 * 初始化聚类中心节点,随机选择
	 * 这里随便选择几个
	 * @param dataList
	 */
	private void initCenters(List> dataList){
		for (int i = 0; i < clusterNum; i++) {
			centers.add(dataList.get(i + 2));
			clusterList.add(new ArrayList>());
		}
	}

	/**
	 * 新旧中心切换
	 * 清空原来的簇中数据,重新放置数据
	 */
	private void replaceCenters() {
		centers = new ArrayList>(newCenters);
		newCenters = new ArrayList>();
		initclusterList();
	}



	/**
	 * 欧式距离计算
	 * @param element1
	 * @param element2
	 * @return
	 */
	private double distance(double element1,double element2){
		double distance = 0;
		distance = ((element1 - element2) / (element1 + element2)) * ((element1 - element2) / (element1 + element2));
		return distance;
	}

	/**
	 * 新旧聚类中心距离计算
	 * @return
	 */
	private double distanceOfCenters() {
		// 计算新旧中心之间的距离,当距离小于阈值时,聚类算法结束
		double distance = 0;
		for (int i = 0; i < clusterNum; i++) {
			for (int j = 0; j < centers.get(i).size(); j++) {// 计算两点之间的距离
				distance += distance(centers.get(i).get(j) , newCenters.get(i).get(j));
			}
		}
		return distance;
	}

	/**
	 * 重新计算聚类中心
	 */
	private void newCenters() {
		for (int i = 0; i < clusterNum; i++) {
			int len = clusterList.get(i).size();
			ArrayList tmpList = new ArrayList();
			for (int j = 0; j < centers.get(0).size(); j++) {
				double sum = 0;
				for (int t = 0; t < len; t++) {
					sum += clusterList.get(i).get(t).get(j);
				}
				tmpList.add(sum / len);
			}
			newCenters.add(tmpList);
		}
	}

	/**
	 * 核心方法
	 * 迭代簇,将距离最近的节点加入簇
	 * @param dataList
	 */
	private void intoCuster(List> dataList){
		for (int i = 0; i < dataList.size(); i++) {
			double minDistance = 99999999;
			int centerIndex = -1;
			for (int j = 0; j < clusterNum; j++) {// 计算最近距离
				double currentDistance = 0;
				for (int t = 0; t < centers.get(j).size(); t++) {// 计算两点之间的距离
					currentDistance += distance(centers.get(j).get(t) , dataList.get(i).get(t)) ;
				}
				if (minDistance > currentDistance) {
					minDistance = currentDistance;
					centerIndex = j;
				}
			}
			clusterList.get(centerIndex).add(dataList.get(i));
		}
	}

	/**
	 * 读取文件,获取数据
	 * @param dir
	 * @return
	 */
	public List> readFile(String dir) {
		List> dataList = new ArrayList>();
		try {
			BufferedReader br = new BufferedReader(new InputStreamReader(new FileInputStream(\"wine.txt\")));
			String data = null;
			while ((data = br.readLine()) != null) {
				String[] fields = data.split(\",\");
				List tmpList = new ArrayList();
				for (int i = 0; i < fields.length; i++) {
					tmpList.add(Double.parseDouble(fields[i]));
				}
				dataList.add((ArrayList) tmpList);
			}
			br.close();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return dataList;
	}

	/**
	 * 打印结果
	 */
	private void print() {
		for (int i = 0; i < clusterNum; i++) {
			System.out.println(\"\\nCluster: \" + (i + 1) + \"   size: \" + clusterList.get(i).size() + \" :\\n\");
			for (int j = 0; j < clusterList.get(i).size(); j++) {
				System.out.println(clusterList.get(i).get(j));
			}
		}
	}

	/**
	 * @param args
	 * @throws IOException
	 */
	public static void main(String[] args) throws IOException {

		Kmeans kmeans = new Kmeans(5);

		// 读入原始数据
		List> dataList = kmeans.readFile(\"wine.txt\");

		// 随机确定K个初始聚类中心
		kmeans.initCenters(dataList);

		// 进行若干次迭代,直到聚类中心稳定
		while (true) {
			kmeans.intoCuster(dataList);
			kmeans.newCenters();
			double distance = kmeans.distanceOfCenters();

			// 小于阈值时,结束循环
			if (distance == 0) {
				break;
			}
			// 否则,新的中心来代替旧的中心,进行下一轮迭代
			else {
				kmeans.replaceCenters();
			}
		}

		kmeans.print();
	}
}
下面代码摘自网页中,这里做个比较:
import java.util.ArrayList;
import java.util.Random;

/**
 * K均值聚类算法
 */
public class Kmeans2 {

    private int k;// 分成多少簇
    private int m;// 迭代次数
    private int dataSetLength;// 数据集元素个数,即数据集的长度
    private ArrayList dataSet;// 数据集链表
    private ArrayList center;// 中心链表
    private ArrayList> cluster; // 簇
    private ArrayList jc;// 误差平方和,k越接近dataSetLength,误差越小
    private Random random;

    /**
     * 设置需分组的原始数据集
     * @param dataSet
     */
    public void setDataSet(ArrayList dataSet) {
        this.dataSet = dataSet;
    }

    /**
     * 获取结果分组
     * @return 结果集
     */
    public ArrayList> getCluster() {
        return cluster;
    }

    /**
     * 构造函数,传入需要分成的簇数量
     * @param k
     *            簇数量,若k<=0时,设置为1,若k大于数据源的长度时,置为数据源的长度
     */
    public Kmeans2(int k) {
        if (k <= 0) {
            k = 1;
        }
        this.k = k;
    }

    /**
     * 初始化
     */
    private void init() {
        m = 0;
        random = new Random();
        if (dataSet == null || dataSet.size() == 0) {
            initDataSet();
        }
        dataSetLength = dataSet.size();
        if (k > dataSetLength) {
            k = dataSetLength;
        }
        center = initCenters();
        cluster = initCluster();
        jc = new ArrayList();
    }

    /**
     * 如果调用者未初始化数据集,则采用内部测试数据集
     */
    private void initDataSet() {
        dataSet = new ArrayList();
        // 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
        float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
                { 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
                { 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } };

        for (int i = 0; i < dataSetArray.length; i++) {
            dataSet.add(dataSetArray[i]);
        }
    }

    /**
     * 初始化中心数据链表,分成多少簇就有多少个中心点
     *
     * @return 中心点集
     */
    private ArrayList initCenters() {
        ArrayList center = new ArrayList();
        int[] randoms = new int[k];
        boolean flag;
        int temp = random.nextInt(dataSetLength);
        randoms[0] = temp;
        for (int i = 1; i < k; i++) {
            flag = true;
            while (flag) {
                temp = random.nextInt(dataSetLength);
                int j = 0;
                while (j < i) {
                    if (temp == randoms[j]) {
                        break;
                    }
                    j++;
                }
                if (j == i) {
                    flag = false;
                }
            }
            randoms[i] = temp;
        }
        for (int i = 0; i < k; i++) {
            center.add(dataSet.get(randoms[i]));// 生成初始化中心链表
        }
        return center;
    }

    /**
     * 初始化簇集合
     *
     * @return 一个分为k簇的空数据的簇集合
     */
    private ArrayList> initCluster() {
        ArrayList> cluster = new ArrayList>();
        for (int i = 0; i < k; i++) {
            cluster.add(new ArrayList());
        }
        return cluster;
    }

    /**
     * 计算两个点之间的距离
     *
     * @param element
     *            点1
     * @param center
     *            点2
     * @return 距离
     */
    private float distance(float[] element, float[] center) {
        float distance = 0.0f;
        float x = element[0] - center[0];
        float y = element[1] - center[1];
        float z = x * x + y * y;
        distance = (float) Math.sqrt(z);
        return distance;
    }

    /**
     * 获取距离集合中最小距离的位置
     *
     * @param distance
     *            距离数组
     * @return 最小距离在距离数组中的位置
     */
    private int minDistance(float[] distance) {
        float minDistance = distance[0];
        int minLocation = 0;
        for (int i = 1; i < distance.length; i++) {
            if (distance[i] < minDistance) {
                minDistance = distance[i];
                minLocation = i;
            } else if (distance[i] == minDistance) // 如果相等,随机返回一个位置
            {
                if (random.nextInt(10) < 5) {
                    minLocation = i;
                }
            }
        }
        return minLocation;
    }

    /**
     * 核心,将当前元素放到最小距离中心相关的簇中
     */
    private void clusterSet() {
        float[] distance = new float[k];
        for (int i = 0; i < dataSetLength; i++) {
            for (int j = 0; j < k; j++) {
                distance[j] = distance(dataSet.get(i), center.get(j));
            }
            int minLocation = minDistance(distance);
            cluster.get(minLocation).add(dataSet.get(i));// 核心,将当前元素放到最小距离中心相关的簇中

        }
    }

    /**
     * 求两点误差平方的方法
     *
     * @param element
     *            点1
     * @param center
     *            点2
     * @return 误差平方
     */
    private float errorSquare(float[] element, float[] center) {
        float x = element[0] - center[0];
        float y = element[1] - center[1];
        float errSquare = x * x + y * y;
        return errSquare;
    }

    /**
     * 计算误差平方和准则函数方法
     */
    private void countRule() {
        float jcF = 0;
        for (int i = 0; i < cluster.size(); i++) {
            for (int j = 0; j < cluster.get(i).size(); j++) {
                jcF += errorSquare(cluster.get(i).get(j), center.get(i));
            }
        }
        jc.add(jcF);
    }

    /**
     * 设置新的簇中心方法
     */
    private void setNewCenter() {
        for (int i = 0; i < k; i++) {
            int n = cluster.get(i).size();
            if (n != 0) {
                float[] newCenter = { 0, 0 };
                for (int j = 0; j < n; j++) {
                    newCenter[0] += cluster.get(i).get(j)[0];
                    newCenter[1] += cluster.get(i).get(j)[1];
                }
                // 设置一个平均值
                newCenter[0] = newCenter[0] / n;
                newCenter[1] = newCenter[1] / n;
                center.set(i, newCenter);
            }
        }
    }

    /**
     * 打印数据,测试用
     * @param dataArray
     *            数据集
     * @param dataArrayName
     *            数据集名称
     */
    public void printDataArray(ArrayList dataArray, String dataArrayName) {
        for (int i = 0; i < dataArray.size(); i++) {
            System.out.println(\"print:\" + dataArrayName + \"[\" + i + \"]={\" + dataArray.get(i)[0] + \",\" + dataArray.get(i)[1] + \"}\");
        }
        System.out.println(\"===================================\");
    }

    /**
     * Kmeans算法核心过程方法
     */
    private void kmeans() {
        init();
        // 循环分组,直到误差不变为止
        while (true) {
            clusterSet();
            countRule();
            // 误差不变了,分组完成
            if (m != 0) {
                if (jc.get(m) - jc.get(m - 1) == 0) {
                    break;
                }
            }
            setNewCenter();
            m++;
            cluster.clear();
            cluster = initCluster();
        }
    }

    /**
     * 执行算法
     */
    public void execute() {
        long startTime = System.currentTimeMillis();
        System.out.println(\"kmeans begins\");
        kmeans();
        long endTime = System.currentTimeMillis();
        System.out.println(\"kmeans running time=\" + (endTime - startTime) + \"ms\");
        System.out.println(\"kmeans ends\");
        System.out.println();
    }


    public  static void main(String[] args){
        //初始化一个Kmean对象,将k置为10
        Kmeans2 k=new Kmeans2(4);
        ArrayList dataSet=new ArrayList();

        dataSet.add(new float[]{1,22});
        dataSet.add(new float[]{3,333});
        dataSet.add(new float[]{3,4});
        dataSet.add(new float[]{5,6});
        dataSet.add(new float[]{8,9999});
        dataSet.add(new float[]{4,5});
        dataSet.add(new float[]{6,4});
        dataSet.add(new float[]{3,95});
        dataSet.add(new float[]{5,9});
        dataSet.add(new float[]{4,7777});
        dataSet.add(new float[]{1,9});
        dataSet.add(new float[]{7,844});
        //设置原始数据集
        k.setDataSet(dataSet);
        //执行算法
        k.execute();
        //得到聚类结果
        ArrayList> cluster=k.getCluster();
        //查看结果
        for(int i=0;i

大致思想不变:就是不断选择聚类中心,根据距离选择加入簇的节点,不断迭代,直到距离小于某个阀值或不变则聚类结束。

ItVuer - 免责声明 - 关于我们 - 联系我们

本网站信息来源于互联网,如有侵权请联系:561261067@qq.com

桂ICP备16001015号