数据挖掘-聚类-K-means算法Java实现

jopen 6年前

K-Means算法是最古老也是应用最广泛的聚类算法,它使用质心定义原型,质心是一组点的均值,通常该算法用于n维连续空间中的对象。


K-Means算法流程
step1:选择K个点作为初始质心
step2:repeat
               将每个点指派到最近的质心,形成K个簇
               重新计算每个簇的质心
            until 质心不在变化 

例如下图的样本集,初始选择是三个质心比较集中,但是迭代3次之后,质心趋于稳定,并将样本集分为3部分
数据挖掘-聚类-K-means算法Java实现
我们对每一个步骤都进行分析
step1:选择K个点作为初始质心
这一步首先要知道K的值,也就是说K是手动设置的,而不是像EM算法那样自动聚类成n个簇
其次,如何选择初始质心
     最简单的方式无异于,随机选取质心了,然后多次运行,取效果最好的那个结果。这个方法,简单但不见得有效,有很大的可能是得到局部最优。
     另一种复杂的方式是,随机选取一个质心,然后计算离这个质心最远的样本点,对于每个后继质心都选取已经选取过的质心的最远点。使用这种方式,可以确保质心是随机的,并且是散开的。

step2:repeat
               将每个点指派到最近的质心,形成K个簇
               重新计算每个簇的质心
            until 质心不在变化 
如何定义最近的概念,对于欧式空间中的点,可以使用欧式空间,对于文档可以用余弦相似性等等。对于给定的数据,可能适应与多种合适的邻近性度量。

其他问题
离群点的处理
离群点可能过度影响簇的发现,导致簇的最终发布会与我们的预想有较大出入,所以提前发现并剔除离群点是有必要的。
在我的工作中,是利用方差来剔除离群点,结果显示效果非常好。

簇分裂和簇合并
使用较大的K,往往会使得聚类的结果看上去更加合理,但很多情况下,我们并不想增加簇的个数。
这时可以交替采用簇分裂和簇合并。这种方式可以避开局部极小,并且能够得到具有期望个数簇的结果。

贴上代码java版,以后有时间写个python版的
抽象了点,簇,和距离
Point.class
public class Point {            private double x;            private double y;            private int id;            private boolean beyond;//标识是否属于样本                    public Point(int id, double x, double y) {                this.id = id;                this.x = x;                this.y = y;                this.beyond = true;            }                    public Point(int id, double x, double y, boolean beyond) {                this.id = id;                this.x = x;                this.y = y;                this.beyond = beyond;            }                    public double getX() {                return x;            }                    public double getY() {                return y;            }                    public int getId() {                return id;            }                    public boolean isBeyond() {                return beyond;            }                    @Override            public String toString() {                return "Point{" +                        "id=" + id +                        ", x=" + x +                        ", y=" + y +                        '}';            }                    @Override            public boolean equals(Object o) {                if (this == o) return true;                if (o == null || getClass() != o.getClass()) return false;                        Point point = (Point) o;                        if (Double.compare(point.x, x) != 0) return false;                if (Double.compare(point.y, y) != 0) return false;                        return true;            }                    @Override            public int hashCode() {                int result;                long temp;                temp = x != +0.0d ? Double.doubleToLongBits(x) : 0L;                result = (int) (temp ^ (temp >>> 32));                temp = y != +0.0d ? Double.doubleToLongBits(y) : 0L;                result = 31 * result + (int) (temp ^ (temp >>> 32));                return result;            }        }
Cluster.class
public class Cluster {            private int id;//标识            private Point center;//中心            private List members = new ArrayList();//成员                    public Cluster(int id, Point center) {                this.id = id;                this.center = center;            }                    public Cluster(int id, Point center, List members) {                this.id = id;                this.center = center;                this.members = members;            }                    public void addPoint(Point newPoint) {                if (!members.contains(newPoint))                    members.add(newPoint);                else                    throw new IllegalStateException("试图处理同一个样本数据!");            }                    public int getId() {                return id;            }                    public Point getCenter() {                return center;            }                    public void setCenter(Point center) {                this.center = center;            }                    public List getMembers() {                return members;            }                    @Override            public String toString() {                return "Cluster{" +                        "id=" + id +                        ", center=" + center +                        ", members=" + members +                        "}";            }        }
抽象的距离,可以具体实现为欧式,曼式或其他距离公式
public abstract class AbstractDistance {            abstract public double getDis(Point p1, Point p2);        }
点对
public class Distence implements Comparable {            private Point source;            private Point dest;            private double dis;            private AbstractDistance distance;                    public Distence(Point source, Point dest, AbstractDistance distance) {                this.source = source;                this.dest = dest;                this.distance = distance;                dis = distance.getDis(source, dest);            }                    public Point getSource() {                return source;            }                    public Point getDest() {                return dest;            }                    public double getDis() {                return dis;            }                    @Override            public int compareTo(Distence o) {                if (o.getDis() > dis)                    return -1;                else                    return 1;            }        }

核心实现类

public class KMeansCluster {            private int k;//簇的个数            private int num = 100000;//迭代次数            private List datas;//原始样本集            private String address;//样本集路径            private List data = new ArrayList();            private AbstractDistance distance = new AbstractDistance() {                @Override                public double getDis(Point p1, Point p2) {                    //欧几里德距离                    return Math.sqrt(Math.pow(p1.getX() - p2.getX(), 2) + Math.pow(p1.getY() - p2.getY(), 2));                }            };                    public KMeansCluster(int k, int num, String address) {                this.k = k;                this.num = num;                this.address = address;            }                    public KMeansCluster(int k, String address) {                this.k = k;                this.address = address;            }                    public KMeansCluster(int k, List datas) {                this.k = k;                this.datas = datas;            }                    public KMeansCluster(int k, int num, List datas) {                this.k = k;                this.num = num;                this.datas = datas;            }                    private void check() {                if (k == 0)                    throw new IllegalArgumentException("k must be the number > 0");                        if (address == null && datas == null)                    throw new IllegalArgumentException("program can't get real data");            }                    /**            * 初始化数据            *            * @throws java.io.FileNotFoundException            */            public void init() throws FileNotFoundException {                check();                //读取文件,init data                //处理原始数据                for (int i = 0, j = datas.size(); i < j; i++)                    data.add(new Point(i, datas.get(i), 0));            }                    /**            * 第一次随机选取中心点            *            * @return            */            public Set chooseCenter() {                Set center = new HashSet();                Random ran = new Random();                int roll = 0;                while (center.size() < k) {                    roll = ran.nextInt(data.size());                    center.add(data.get(roll));                }                return center;            }                    /**            * @param center            * @return            */            public List prepare(Set center) {                List cluster = new ArrayList();                Iterator it = center.iterator();                int id = 0;                while (it.hasNext()) {                    Point p = it.next();                    if (p.isBeyond()) {                        Cluster c = new Cluster(id++, p);                        c.addPoint(p);                        cluster.add(c);                    } else                        cluster.add(new Cluster(id++, p));                }                return cluster;            }                    /**            * 第一次运算,中心点为样本值            *            * @param center            * @param cluster            * @return            */            public List clustering(Set center, List cluster) {                Point[] p = center.toArray(new Point[0]);                TreeSet distence = new TreeSet();//存放距离信息                Point source;                Point dest;                boolean flag = false;                for (int i = 0, n = data.size(); i < n; i++) {                    distence.clear();                    for (int j = 0; j < center.size(); j++) {                        if (center.contains(data.get(i)))                            break;                                flag = true;                        // 计算距离                        source = data.get(i);                        dest = p[j];                        distence.add(new Distence(source, dest, distance));                    }                    if (flag == true) {                        Distence min = distence.first();                        for (int m = 0, k = cluster.size(); m < k; m++) {                            if (cluster.get(m).getCenter().equals(min.getDest()))                                cluster.get(m).addPoint(min.getSource());                                }                    }                    flag = false;                }                        return cluster;            }                    /**            * 迭代运算,中心点为簇内样本均值            *            * @param cluster            * @return            */            public List cluster(List cluster) {        //        double error;                Set lastCenter = new HashSet();                for (int m = 0; m < num; m++) {        //            error = 0;                    Set center = new HashSet();                    // 重新计算聚类中心                    for (int j = 0; j < k; j++) {                        List ps = cluster.get(j).getMembers();                        int size = ps.size();                        if (size < 3) {                            center.add(cluster.get(j).getCenter());                            continue;                        }                        // 计算距离                        double x = 0.0, y = 0.0;                        for (int k1 = 0; k1 < size; k1++) {                            x += ps.get(k1).getX();                            y += ps.get(k1).getY();                        }                        //得到新的中心点                        Point nc = new Point(-1, x / size, y / size, false);                        center.add(nc);                    }                    if (lastCenter.containsAll(center))//中心点不在变化,退出迭代                        break;                    lastCenter = center;                    // 迭代运算                    cluster = clustering(center, prepare(center));        //            for (int nz = 0; nz < k; nz++) {        //                error += cluster.get(nz).getError();//计算误差        //            }                }                return cluster;            }                    /**            * 输出聚类信息到控制台            *            * @param cs            */            public void out2console(List cs) {                for (int i = 0; i < cs.size(); i++) {                    System.out.println("No." + (i + 1) + " cluster:");                    Cluster c = cs.get(i);                    List p = c.getMembers();                    for (int j = 0; j < p.size(); j++) {                        System.out.println("\t" + p.get(j).getX() + " ");                    }                    System.out.println();                }            }        }
代码还没有仔细优化,执行的效率可能还存在一定的问题