Hadoop k-means Algorithm Implementation


With yesterday's preparation done, today I was basically able to write the whole k-means program. I hit one problem while coding, in the combine step; apart from that, everything followed the original plan. First, my approach.

Preparation: before uploading the data file to HDFS, a centers file has to be generated first. For example, my input file is as follows:

0.0 0.2 0.4
0.3 0.2 0.4
0.4 0.2 0.4
0.5 0.2 0.4
5.0 5.2 5.4
6.0 5.2 6.4
4.0 5.2 4.4
10.3 10.4 10.5
10.3 10.4 10.5
10.3 10.4 10.5
To generate the centers file, the following commands can be used:

(1) Get the total number of lines in the file: wc -l data.txt. This shows the file has 10 lines.

(2) Since I want three clusters, 10/3 ≈ 3, so I take lines 1, 3 and 6 (the choice of lines is up to you; you could also simply take the first three lines with head -n 3 data.txt > centers.txt). Then run: awk 'NR==1||NR==3||NR==6' data.txt > centers.txt, and upload centers.txt to HDFS.

(Below I actually use the first three lines as the centers file.)

With this approach the program does not need the number of clusters or the dimensionality of the data file hard-coded. While writing this post and the previous one I referred to http://www.cnblogs.com/zhangchaoyang/articles/2634365.html, where both the number of clusters and the dimensionality have to be set in the code.

The map-combine-reduce steps are as follows:

map: the mapper's setup() mainly reads the centers file and loads the center points into a double[][]; then map() transforms each record:

Text (the data record string) -> [index, DataPro(Text(the data record string), IntWritable(1))]

combine:

[index, DataPro(Text(the data record string), IntWritable(1))] -> [index, DataPro(Text(the element-wise sum of all records with the same index), IntWritable(sum of the 1s))]

reduce: the reducer's setup() also reads the centers file, but only to extract the data dimensionality (the reduce step needs it to allocate the sum array):

[index, DataPro(Text(the element-wise sum of all records with the same index), IntWritable(count))] -> [index, Text(the element-wise average of all records with that index, i.e. the new center)]

The steps above are what the loop iterates over; a final job then outputs the classification result.
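The map output value type DataPro is a small custom Writable that I do not list below; it just bundles a Text (the tab-separated record, or the partial sums) with an IntWritable count. A minimal sketch consistent with how the code calls it (set(), getCenter(), getCount()) could look like this; note this is a reconstruction, not the exact class used:

DataPro (sketch):

package org.fansy.date928;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;

// Sketch of the DataPro value type, reconstructed from how it is used below.
public class DataPro implements Writable {
    private Text center;       // the tab-separated record, or the partial element-wise sums
    private IntWritable count; // how many records are folded into this value

    public DataPro() {
        this.center = new Text();
        this.count = new IntWritable(0);
    }

    public void set(Text center, IntWritable count) {
        this.center = center;
        this.count = count;
    }

    public Text getCenter() { return center; }

    public IntWritable getCount() { return count; }

    @Override
    public void write(DataOutput out) throws IOException {
        center.write(out);
        count.write(out);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        center.readFields(in);
        count.readFields(in);
    }
}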

Here is the code:

KmeansDriver:

package org.fansy.date928;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FSDataInputStream;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.util.GenericOptionsParser;

public class KmeansDriver {
    /**
     * k-means algorithm program
     */
    private static final String temp_path = "hdfs://fansyPC:9000/user/fansy/date928/kmeans/temp_center/";
    private static final String dataPath = "hdfs://fansyPC:9000/user/fansy/input/smallkmeansdata";
    private static final int iterTime = 300;
    private static int iterNum = 1;
    private static final double threadHold = 0.01;

    private static Log log = LogFactory.getLog(KmeansDriver.class);

    public static void main(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
        Configuration conf = new Configuration();
        // set the centers data file
        Path centersFile = new Path("hdfs://fansyPC:9000/user/fansy/input/centers");
        DistributedCache.addCacheFile(centersFile.toUri(), conf);

        String[] otherArgs = new GenericOptionsParser(conf, args).getRemainingArgs();
        if (otherArgs.length != 1) {
            System.err.println("Usage: KmeansDriver <indatafile> ");
            System.exit(2);
        }
        // job 0: the first iteration, using the initial centers file
        Job job = new Job(conf, "kmeans job 0");
        job.setJarByClass(KmeansDriver.class);
        job.setMapperClass(KmeansM.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(DataPro.class);
        job.setNumReduceTasks(1);
        job.setCombinerClass(KmeansC.class);
        job.setReducerClass(KmeansR.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.addInputPath(job, new Path(dataPath));
        FileOutputFormat.setOutputPath(job, new Path(temp_path + 0 + "/"));
        if (!job.waitForCompletion(true)) {
            System.exit(1); // run error then exit
        }

        // do iteration
        boolean flag = true;
        while (flag && iterNum < iterTime) {
            Configuration conf1 = new Configuration();
            // set the centers data file to the output of the previous iteration
            Path centersFile1 = new Path(temp_path + (iterNum - 1) + "/part-r-00000");
            DistributedCache.addCacheFile(centersFile1.toUri(), conf1);
            boolean iterflag = doIteration(conf1, iterNum);
            if (!iterflag) {
                log.error("job fails");
                System.exit(1);
            }
            // set the flag based on the old centers and the new centers
            Path oldCentersFile = new Path(temp_path + (iterNum - 1) + "/part-r-00000");
            Path newCentersFile = new Path(temp_path + iterNum + "/part-r-00000");
            FileSystem fs1 = FileSystem.get(oldCentersFile.toUri(), conf1);
            FileSystem fs2 = FileSystem.get(newCentersFile.toUri(), conf1);
            if (!(fs1.exists(oldCentersFile) && fs2.exists(newCentersFile))) {
                log.info("the old centers and new centers should exist at the same time");
                System.exit(1);
            }
            String line1, line2;
            FSDataInputStream in1 = fs1.open(oldCentersFile);
            FSDataInputStream in2 = fs2.open(newCentersFile);
            InputStreamReader istr1 = new InputStreamReader(in1);
            InputStreamReader istr2 = new InputStreamReader(in2);
            BufferedReader br1 = new BufferedReader(istr1);
            BufferedReader br2 = new BufferedReader(istr2);
            // sum of squared differences between the old and the new centers
            double error = 0.0;
            while ((line1 = br1.readLine()) != null && ((line2 = br2.readLine()) != null)) {
                String[] str1 = line1.split("\t");
                String[] str2 = line2.split("\t");
                for (int i = 0; i < str1.length; i++) {
                    error += (Double.parseDouble(str1[i]) - Double.parseDouble(str2[i]))
                            * (Double.parseDouble(str1[i]) - Double.parseDouble(str2[i]));
                }
            }
            if (error < threadHold) {
                flag = false;
            }
            iterNum++;
        }

        // the last job: classify the data using the final centers
        Configuration conf2 = new Configuration();
        Path centersFile2 = new Path(temp_path + (iterNum - 1) + "/part-r-00000");
        DistributedCache.addCacheFile(centersFile2.toUri(), conf2);
        lastJob(conf2, iterNum);
        System.out.println(iterNum);
    }

    public static boolean doIteration(Configuration conf, int iterNum) throws IOException, ClassNotFoundException, InterruptedException {
        boolean flag = false;
        Job job = new Job(conf, "kmeans job" + " " + iterNum);
        job.setJarByClass(KmeansDriver.class);
        job.setMapperClass(KmeansM.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(DataPro.class);
        job.setNumReduceTasks(1);
        job.setCombinerClass(KmeansC.class);
        job.setReducerClass(KmeansR.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.addInputPath(job, new Path(dataPath));
        FileOutputFormat.setOutputPath(job, new Path(temp_path + iterNum + "/"));
        flag = job.waitForCompletion(true);
        return flag;
    }

    public static void lastJob(Configuration conf, int iterNum) throws IOException, ClassNotFoundException, InterruptedException {
        Job job = new Job(conf, "kmeans job" + " " + iterNum);
        job.setJarByClass(KmeansDriver.class);
        job.setMapperClass(KmeansLastM.class);
        job.setMapOutputKeyClass(IntWritable.class);
        job.setMapOutputValueClass(Text.class);
        job.setNumReduceTasks(4);
        job.setReducerClass(KmeansLastR.class);
        job.setOutputKeyClass(IntWritable.class);
        job.setOutputValueClass(Text.class);
        FileInputFormat.addInputPath(job, new Path(dataPath));
        FileOutputFormat.setOutputPath(job, new Path(temp_path + iterNum + "/"));
        job.waitForCompletion(true);
    }
}

Mapper:
package org.fansy.date928;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

public class KmeansM extends Mapper<LongWritable, Text, IntWritable, DataPro> {
    private static Log log = LogFactory.getLog(KmeansM.class);

    private double[][] centers;
    private int dimention_m;  // this is the k
    private int dimention_n;  // this is the features

    static enum Counter { Fansy_Miss_Records };

    @Override
    public void setup(Context context) throws IOException, InterruptedException {
        Path[] caches = DistributedCache.getLocalCacheFiles(context.getConfiguration());
        if (caches == null || caches.length <= 0) {
            log.error("center file does not exist");
            System.exit(1);
        }
        BufferedReader br = new BufferedReader(new FileReader(caches[0].toString()));
        String line;
        List<ArrayList<Double>> temp_centers = new ArrayList<ArrayList<Double>>();
        ArrayList<Double> center = null;
        // get the file data
        while ((line = br.readLine()) != null) {
            center = new ArrayList<Double>();
            String[] str = line.split("\t");
            for (int i = 0; i < str.length; i++) {
                center.add(Double.parseDouble(str[i]));
            }
            temp_centers.add(center);
        }
        try {
            br.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        // fill the centers
        @SuppressWarnings("unchecked")
        ArrayList<Double>[] newcenters = temp_centers.toArray(new ArrayList[]{});
        dimention_m = temp_centers.size();
        dimention_n = newcenters[0].size();
        centers = new double[dimention_m][dimention_n];
        for (int i = 0; i < dimention_m; i++) {
            Double[] temp_double = newcenters[i].toArray(new Double[]{});
            for (int j = 0; j < dimention_n; j++) {
                centers[i][j] = temp_double[j];
            }
        }
    }

    public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        String[] values = value.toString().split("\t");
        if (values.length != dimention_n) {
            context.getCounter(Counter.Fansy_Miss_Records).increment(1);
            return;
        }
        double[] temp_double = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            temp_double[i] = Double.parseDouble(values[i]);
        }
        // set the index of the nearest center
        double distance = Double.MAX_VALUE;
        double temp_distance = 0.0;
        int index = 0;
        for (int i = 0; i < dimention_m; i++) {
            double[] temp_center = centers[i];
            temp_distance = getEnumDistance(temp_double, temp_center);
            if (temp_distance < distance) {
                index = i;
                distance = temp_distance;
            }
        }
        DataPro newvalue = new DataPro();
        newvalue.set(value, new IntWritable(1));
        context.write(new IntWritable(index), newvalue);
    }

    public static double getEnumDistance(double[] source, double[] other) { // get the distance
        double distance = 0.0;
        if (source.length != other.length) {
            return Double.MAX_VALUE;
        }
        for (int i = 0; i < source.length; i++) {
            distance += (source[i] - other[i]) * (source[i] - other[i]);
        }
        distance = Math.sqrt(distance);
        return distance;
    }
}

Combiner:
package org.fansy.date928;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class KmeansC extends Reducer<IntWritable, DataPro, IntWritable, DataPro> {
    private static int dimension = 0;

    private static Log log = LogFactory.getLog(KmeansC.class);

    // the main purpose of the setup() function is to get the dimension of the original data
    public void setup(Context context) throws IOException {
        Path[] caches = DistributedCache.getLocalCacheFiles(context.getConfiguration());
        if (caches == null || caches.length <= 0) {
            log.error("center file does not exist");
            System.exit(1);
        }
        BufferedReader br = new BufferedReader(new FileReader(caches[0].toString()));
        String line;
        while ((line = br.readLine()) != null) {
            String[] str = line.split("\t");
            dimension = str.length;
            break;
        }
        try {
            br.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void reduce(IntWritable key, Iterable<DataPro> values, Context context) throws InterruptedException, IOException {
        double[] sum = new double[dimension];
        int sumCount = 0;
        // sum the records and the counts for this index in a single pass
        for (DataPro val : values) {
            String[] datastr = val.getCenter().toString().split("\t");
            sumCount += val.getCount().get();
            for (int i = 0; i < dimension; i++) {
                sum[i] += Double.parseDouble(datastr[i]);
            }
        }
        // emit the partial sums; the reducer computes the averages
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < dimension; i++) {
            sb.append(sum[i] + "\t");
        }
        DataPro newvalue = new DataPro();
        newvalue.set(new Text(sb.toString()), new IntWritable(sumCount));
        context.write(key, newvalue);
    }
}

Reducer:
package org.fansy.date928;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class KmeansR extends Reducer<IntWritable, DataPro, NullWritable, Text> {
    private static int dimension = 0;

    private static Log log = LogFactory.getLog(KmeansR.class);

    // the main purpose of the setup() function is to get the dimension of the original data
    public void setup(Context context) throws IOException {
        Path[] caches = DistributedCache.getLocalCacheFiles(context.getConfiguration());
        if (caches == null || caches.length <= 0) {
            log.error("center file does not exist");
            System.exit(1);
        }
        BufferedReader br = new BufferedReader(new FileReader(caches[0].toString()));
        String line;
        while ((line = br.readLine()) != null) {
            String[] str = line.split("\t");
            dimension = str.length;
            break;
        }
        try {
            br.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    public void reduce(IntWritable key, Iterable<DataPro> values, Context context) throws InterruptedException, IOException {
        double[] sum = new double[dimension];
        int sumCount = 0;
        for (DataPro val : values) {
            String[] datastr = val.getCenter().toString().split("\t");
            sumCount += val.getCount().get();
            for (int i = 0; i < dimension; i++) {
                sum[i] += Double.parseDouble(datastr[i]);
            }
        }
        // calculate the new center: the average of all records assigned to this index
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < dimension; i++) {
            sb.append(sum[i] / sumCount + "\t");
        }
        context.write(null, new Text(sb.toString()));
    }
}

LastMapper:
package org.fansy.date928;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;

public class KmeansLastM extends Mapper<LongWritable, Text, IntWritable, Text> {
    private static Log log = LogFactory.getLog(KmeansLastM.class);

    private double[][] centers;
    private int dimention_m;  // this is the k
    private int dimention_n;  // this is the features

    static enum Counter { Fansy_Miss_Records };

    @Override
    public void setup(Context context) throws IOException, InterruptedException {
        Path[] caches = DistributedCache.getLocalCacheFiles(context.getConfiguration());
        if (caches == null || caches.length <= 0) {
            log.error("center file does not exist");
            System.exit(1);
        }
        @SuppressWarnings("resource")
        BufferedReader br = new BufferedReader(new FileReader(caches[0].toString()));
        String line;
        List<ArrayList<Double>> temp_centers = new ArrayList<ArrayList<Double>>();
        ArrayList<Double> center = null;
        // get the file data
        while ((line = br.readLine()) != null) {
            center = new ArrayList<Double>();
            String[] str = line.split("\t");
            for (int i = 0; i < str.length; i++) {
                center.add(Double.parseDouble(str[i]));
            }
            temp_centers.add(center);
        }
        // fill the centers
        @SuppressWarnings("unchecked")
        ArrayList<Double>[] newcenters = temp_centers.toArray(new ArrayList[]{});
        dimention_m = temp_centers.size();
        dimention_n = newcenters[0].size();
        centers = new double[dimention_m][dimention_n];
        for (int i = 0; i < dimention_m; i++) {
            Double[] temp_double = newcenters[i].toArray(new Double[]{});
            for (int j = 0; j < dimention_n; j++) {
                centers[i][j] = temp_double[j];
            }
        }
    }

    public void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException {
        String[] values = value.toString().split("\t");
        if (values.length != dimention_n) {
            context.getCounter(Counter.Fansy_Miss_Records).increment(1);
            return;
        }
        double[] temp_double = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            temp_double[i] = Double.parseDouble(values[i]);
        }
        // set the index of the nearest center
        double distance = Double.MAX_VALUE;
        double temp_distance = 0.0;
        int index = 0;
        for (int i = 0; i < dimention_m; i++) {
            double[] temp_center = centers[i];
            temp_distance = getEnumDistance(temp_double, temp_center);
            if (temp_distance < distance) {
                index = i;
                distance = temp_distance;
            }
        }
        context.write(new IntWritable(index), value);
    }

    public static double getEnumDistance(double[] source, double[] other) { // get the distance
        double distance = 0.0;
        if (source.length != other.length) {
            return Double.MAX_VALUE;
        }
        for (int i = 0; i < source.length; i++) {
            distance += (source[i] - other[i]) * (source[i] - other[i]);
        }
        distance = Math.sqrt(distance);
        return distance;
    }
}

LastReducer:
package org.fansy.date928;

import java.io.IOException;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class KmeansLastR extends Reducer<IntWritable, Text, IntWritable, Text> {

    public void reduce(IntWritable key, Iterable<Text> values, Context context) throws InterruptedException, IOException {
        // output the data directly
        for (Text val : values) {
            context.write(key, val);
        }
    }
}

That is all of the code. The final output is:
0	0.0 0.2 0.4
1	0.3 0.2 0.4
1	0.4 0.2 0.4
1	0.5 0.2 0.4
2	5.0 5.2 5.4
2	6.0 5.2 6.4
2	4.0 5.2 4.4
2	10.3 10.4 10.5
2	10.3 10.4 10.5
2	10.3 10.4 10.5

The final result shows that the clustering is not very good, which goes to show that the initial centers really have to be chosen carefully.

Now for the problem I ran into. At first I did not get the data dimension by reading the centers file; instead I used the approach below:

WrongCombine:

package org.fansy.date927;

import java.io.IOException;
import java.util.Iterator;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Reducer;

public class KmeansC extends Reducer<IntWritable, DataPro, IntWritable, DataPro> {

    public void reduce(IntWritable key, Iterable<DataPro> values, Context context) throws InterruptedException, IOException {
        // get dimension first
        Iterator<DataPro> iter = values.iterator();
        int dimension = 0;
        for (DataPro val : values) {
            String[] datastr = val.getCenter().toString().split("\t");
            dimension = datastr.length;
            break;
        }

        double[] sum = new double[dimension];
        int sumCount = 0;
        // operation one
        while (iter.hasNext()) {
            DataPro val = iter.next();
            String[] datastr = val.getCenter().toString().split("\t");
            sumCount += val.getCount().get();
            for (int i = 0; i < dimension; i++) {
                sum[i] += Double.parseDouble(datastr[i]);
            }
        }

        // operation two
        /*for (DataPro val : values) {
            String[] datastr = val.getCenter().toString().split("\t");
            sumCount += val.getCount().get();
            for (int i = 0; i < dimension; i++) {
                sum[i] += Double.parseDouble(datastr[i]);
            }
        }*/

        // calculate the new centers
        StringBuffer sb = new StringBuffer();
        for (int i = 0; i < dimension; i++) {
            sb.append(sum[i] + "\t");
        }
        System.out.println("combine text:" + sb.toString());
        System.out.println("combine sumCount:" + sumCount);
        DataPro newvalue = new DataPro();
        newvalue.set(new Text(sb.toString()), new IntWritable(sumCount));
        context.write(key, newvalue);
    }
}

The first for-each loop at the top of reduce() (the one that reads a single value and then breaks) is how I obtained the data dimension. The dimension did come out right, but the operations after it went wrong, as the debug messages I added show:
12/09/28 14:40:40 INFO mapred.Task:  Using ResourceCalculatorPlugin : org.apache.hadoop.util.LinuxResourceCalculatorPlugin@435e331b
12/09/28 14:40:40 INFO mapred.MapTask: io.sort.mb = 100
12/09/28 14:40:40 INFO mapred.MapTask: data buffer = 79691776/99614720
12/09/28 14:40:40 INFO mapred.MapTask: record buffer = 262144/327680
0.0,0.0,0.0,
5.0,5.0,5.0,
10.0,10.0,10.0,
the map out:0,0.0 0.2 0.4
the map out:0,0.3 0.2 0.4
the map out:0,0.4 0.2 0.4
the map out:0,0.5 0.2 0.4
the map out:1,5.0 5.2 5.4
the map out:1,6.0 5.2 6.4
the map out:1,4.0 5.2 4.4
the map out:2,10.3 10.4 10.5
the map out:2,10.3 10.4 10.5
the map out:2,10.3 10.4 10.5
12/09/28 14:40:40 INFO mapred.MapTask: Starting flush of map output
combine text:1.2 0.6000000000000001 1.2000000000000002
combine sumCount:3
combine text:10.0 10.4 10.8
combine sumCount:2
combine text:20.6 20.8 21.0
combine sumCount:2
12/09/28 14:40:40 INFO mapred.MapTask: Finished spill 0
12/09/28 14:40:40 INFO mapred.Task: Task:attempt_local_0001_m_000000_0 is done. And is in the process of commiting

From the output above, the map step is clearly fine, but in the combine step the per-index counts are wrong: every group is short by one record, and the sums are missing that record as well. So my guess was: does an operation like this
for (DataPro val : values) {
    String[] datastr = val.getCenter().toString().split("\t");
    dimension = datastr.length;
    break;
}
somehow advance the internal pointer of values, so that my later pass starts reading from the second element? That is in fact what happens: the Iterable that Hadoop hands to reduce() is backed by a single one-shot iterator, so calling values.iterator() again does not rewind it, and the record consumed while probing the dimension is never seen by the while loop.
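Looking back, if the dimension really has to come from the values themselves rather than from the centers file, the single-iteration limitation can be worked around by sizing the sum array lazily inside the one and only pass. A rough sketch of such a reduce body (using the same DataPro accessors as above; not the code I actually ended up using):

public void reduce(IntWritable key, Iterable<DataPro> values, Context context)
        throws InterruptedException, IOException {
    double[] sum = null;          // sized lazily from the first record seen
    int sumCount = 0;
    for (DataPro val : values) {  // traverse the values exactly once
        String[] datastr = val.getCenter().toString().split("\t");
        if (sum == null) {
            sum = new double[datastr.length];
        }
        sumCount += val.getCount().get();
        for (int i = 0; i < datastr.length; i++) {
            sum[i] += Double.parseDouble(datastr[i]);
        }
    }
    if (sum == null) {
        return;                   // no values for this key
    }
    StringBuffer sb = new StringBuffer();
    for (int i = 0; i < sum.length; i++) {
        sb.append(sum[i] + "\t");
    }
    DataPro newvalue = new DataPro();
    newvalue.set(new Text(sb.toString()), new IntWritable(sumCount));
    context.write(key, newvalue);
}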