MapReduce 实现全局排序

jopen 10年前

下面直接附上代码，说明都写在源码注释里了。

package com.hadoop.totalsort;    import java.io.IOException;  import java.util.ArrayList;    import org.apache.hadoop.conf.Configuration;  import org.apache.hadoop.fs.FileSystem;  import org.apache.hadoop.fs.Path;  import org.apache.hadoop.io.LongWritable;  import org.apache.hadoop.io.NullWritable;  import org.apache.hadoop.io.SequenceFile;  import org.apache.hadoop.io.Text;  import org.apache.hadoop.mapred.FileInputFormat;  import org.apache.hadoop.mapred.FileSplit;  import org.apache.hadoop.mapred.InputSplit;  import org.apache.hadoop.mapred.JobConf;  import org.apache.hadoop.mapred.LineRecordReader;  import org.apache.hadoop.mapred.RecordReader;  import org.apache.hadoop.mapred.Reporter;  import org.apache.hadoop.util.IndexedSortable;  import org.apache.hadoop.util.QuickSort;      public class SamplerInputFormat extends FileInputFormat<Text, Text> {             static final String PARTITION_FILENAME = "_partition.lst";        static final String SAMPLE_SIZE = "terasort.partitions.sample";        private static JobConf lastConf = null;        private static InputSplit[] lastResult = null;            static class TextSampler implements IndexedSortable {                public ArrayList<Text> records = new ArrayList<Text>();                public int compare(int arg0, int arg1) {                Text right = records.get(arg0);                Text left = records.get(arg1);                    return right.compareTo(left);            }                public void swap(int arg0, int arg1) {                Text right = records.get(arg0);                Text left = records.get(arg1);                    records.set(arg0, left);                records.set(arg1, right);            }                public void addKey(Text key) {                records.add(new Text(key));            }                          //将采集出来的key数据排序          public Text[] createPartitions(int numPartitions) {                int numRecords = records.size();                if (numPartitions > numRecords) 
{                    throw new IllegalArgumentException("Requested more partitions than input keys (" + numPartitions +                            " > " + numRecords + ")");                }                new QuickSort().sort(this, 0, records.size());                float stepSize = numRecords / (float) numPartitions;  //采集的时候应该是采了100条记录,从10个分片查找的,此处再取numPartitions-1条              Text[] result = new Text[numPartitions - 1];                for (int i = 1; i < numPartitions; ++i) {                    result[i - 1] = records.get(Math.round(stepSize * i));                }                return result;            }            }            public static void writePartitionFile(JobConf conf, Path partFile) throws IOException {            //前段代码从分片中采集数据,通过sampler.addKey存入TextSampler中的records数组       SamplerInputFormat inputFormat = new SamplerInputFormat();            TextSampler sampler = new TextSampler();            Text key = new Text();            Text value = new Text();                int partitions = conf.getNumReduceTasks(); // Reducer任务的个数             long sampleSize = conf.getLong(SAMPLE_SIZE, 100); // 采集数据-键值对的个数             InputSplit[] splits = inputFormat.getSplits(conf, conf.getNumMapTasks());// 获得数据分片             int samples = Math.min(10, splits.length);// 采集分片的个数   ,采集10个分片          long recordsPerSample = sampleSize / samples;// 每个分片采集的键值对个数             int sampleStep = splits.length / samples; // 采集分片的步长   ,总的分片个数/要采集的分片个数          long records = 0;                for (int i = 0; i < samples; i++) {  //1...10分片数              RecordReader<Text, Text> reader = inputFormat.getRecordReader(splits[sampleStep * i], conf, null);                while (reader.next(key, value)) {                    sampler.addKey(key);   //将key值增加到sampler的records数组                  records += 1;                    if ((i + 1) * recordsPerSample <= records) {  //目的是均匀采集各分片的条数,比如采集到第5个分片,那么记录条数应该小于5个分片应该的条数                      break;                    }                }       
     }            FileSystem outFs = partFile.getFileSystem(conf);            if (outFs.exists(partFile)) {                outFs.delete(partFile, false);            }            SequenceFile.Writer writer = SequenceFile.createWriter(outFs, conf, partFile, Text.class, NullWritable.class);            NullWritable nullValue = NullWritable.get();            for (Text split : sampler.createPartitions(partitions)) {  //调用createPartitions方法,排序采集出来的数据,并取partitions条              writer.append(split, nullValue);            }            writer.close();            }            static class TeraRecordReader implements RecordReader<Text, Text> {                private LineRecordReader in;            private LongWritable junk = new LongWritable();            private Text line = new Text();            private static int KEY_LENGTH = 10;                public TeraRecordReader(Configuration job, FileSplit split) throws IOException {                in = new LineRecordReader(job, split);            }                public void close() throws IOException {                in.close();            }                public Text createKey() {                // TODO Auto-generated method stub                 return new Text();            }                public Text createValue() {                return new Text();            }                public long getPos() throws IOException {                // TODO Auto-generated method stub                 return in.getPos();            }                public float getProgress() throws IOException {                // TODO Auto-generated method stub                 return in.getProgress();            }                public boolean next(Text arg0, Text arg1) throws IOException {                if (in.next(junk, line)) {   //调用父类方法,将value值赋给key                 // if (line.getLength() < KEY_LENGTH) {                         arg0.set(line);                        arg1.clear();    //                } else {     //                    byte[] bytes = 
line.getBytes(); // 默认知道读取要比较值的前10个字节 作为key     //                                                    // 后面的字节作为value;     //                    arg0.set(bytes, 0, KEY_LENGTH);     //                    arg1.set(bytes, KEY_LENGTH, line.getLength() - KEY_LENGTH);     //                }                     return true;                } else {                    return false;                }            }            }            @Override        public InputSplit[] getSplits(JobConf conf, int splits) throws IOException {            if (conf == lastConf) {                return lastResult;            }            lastConf = conf;            lastResult = super.getSplits(lastConf, splits);            return lastResult;            }            public org.apache.hadoop.mapred.RecordReader<Text, Text> getRecordReader(InputSplit arg0, JobConf arg1,                Reporter arg2) throws IOException {            return new TeraRecordReader(arg1, (FileSplit) arg0);        }        }  
package com.hadoop.totalsort;    import java.io.IOException;  import java.net.URI;  import java.util.ArrayList;  import java.util.List;    import org.apache.hadoop.conf.Configured;  import org.apache.hadoop.filecache.DistributedCache;  import org.apache.hadoop.fs.FileSystem;  import org.apache.hadoop.fs.Path;  import org.apache.hadoop.io.NullWritable;  import org.apache.hadoop.io.SequenceFile;  import org.apache.hadoop.io.Text;  import org.apache.hadoop.mapred.FileOutputFormat;  import org.apache.hadoop.mapred.JobClient;  import org.apache.hadoop.mapred.JobConf;  import org.apache.hadoop.mapred.Partitioner;  import org.apache.hadoop.mapred.TextOutputFormat;  import org.apache.hadoop.util.Tool;  import org.apache.hadoop.util.ToolRunner;    public class SamplerSort extends Configured implements Tool {             // 自定义的Partitioner         public static class TotalOrderPartitioner implements Partitioner<Text, Text> {                private Text[] splitPoints;                public TotalOrderPartitioner() {            }                public int getPartition(Text arg0, Text arg1, int arg2) {                // TODO Auto-generated method stub                 return findPartition(arg0);            }                public void configure(JobConf arg0) {                try {                    FileSystem fs = FileSystem.getLocal(arg0);                    Path partFile = new Path(SamplerInputFormat.PARTITION_FILENAME);                    splitPoints = readPartitions(fs, partFile, arg0); // 读取采集文件                 } catch (IOException ie) {                    throw new IllegalArgumentException("can't read paritions file", ie);                }                }                public int findPartition(Text key) // 分配可以到多个reduce             {                int len = splitPoints.length;                for (int i = 0; i < len; i++) {                    int res = key.compareTo(splitPoints[i]);                    if (res > 0 && i < len - 1) {                        continue;         
           } else if (res == 0) {                        return i;                    } else if (res < 0) {                        return i;                    } else if (res > 0 && i == len - 1) {                        return i + 1;                    }                }                return 0;            }                private static Text[] readPartitions(FileSystem fs, Path p, JobConf job) throws IOException {                SequenceFile.Reader reader = new SequenceFile.Reader(fs, p, job);                List<Text> parts = new ArrayList<Text>();                Text key = new Text();                NullWritable value = NullWritable.get();                while (reader.next(key, value)) {                    parts.add(key);                }                reader.close();                return parts.toArray(new Text[parts.size()]);            }            }            public int run(String[] args) throws Exception {            JobConf job = (JobConf) getConf();           // job.set(name, value);             Path inputDir = new Path(args[0]);            inputDir = inputDir.makeQualified(inputDir.getFileSystem(job));            Path partitionFile = new Path(inputDir, SamplerInputFormat.PARTITION_FILENAME);                URI partitionUri = new URI(partitionFile.toString() +                    "#" + SamplerInputFormat.PARTITION_FILENAME);                SamplerInputFormat.setInputPaths(job, new Path(args[0]));            FileOutputFormat.setOutputPath(job, new Path(args[1]));                job.setJobName("SamplerTotalSort");            job.setJarByClass(SamplerSort.class);            job.setOutputKeyClass(Text.class);            job.setOutputValueClass(Text.class);            job.setInputFormat(SamplerInputFormat.class);            job.setOutputFormat(TextOutputFormat.class);            job.setPartitionerClass(TotalOrderPartitioner.class);            job.setNumReduceTasks(4);                SamplerInputFormat.writePartitionFile(job, partitionFile); // 数据采集并写入文件      
           DistributedCache.addCacheFile(partitionUri, job); // 将这个文件作为共享文件             DistributedCache.createSymlink(job);            // SamplerInputFormat.setFinalSync(job, true);             JobClient.runJob(job);            return 0;        }            public static void main(String[] args) throws Exception {            int res = ToolRunner.run(new JobConf(), new SamplerSort(), args);            System.exit(res);        }        }