结合PageRank算法用Java实现文本相似度

jopen 10年前

目标

尝试了一下把PageRank算法结合了文本相似度计算。直觉上是想把一个list里,和大家都比较靠拢的文本可能最后的PageRank值会比较大。因为如 果最后计算的PageRank值大,说明有比较多的文本和他的相似度值比较高,或者有更多的文本向他靠拢。这样是不是就可以得到一些相对核心的文本,或者 相对代表性的文本?如果是要在整堆文本里切分一些关键的词做token,那么每个token在每份文本里的权重就可以不一样,那么是否就可以得到比较核心 的token,来给这些文本打标签?当然,分词切词的时候都是要用工具过滤掉stopword的。

我也只是想尝试一下这个想法,就简单实现了整个过程。可能实现上还有问题。我的结果是最后大家的PageRank值都非常接近。如:

    5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638  

选5,10,20,50都差不多。都非常接近。主 要在设置PageRank定制迭代的那个DISTANCE值,值越接近0,迭代次数越多,经过很多次游走之后,文本之间的关系都很相近,各自的 pagerank值相差不大。如果调成0.5这样的级别,可能迭代了4次左右就停下来了,互相之间差距会大一些。具体看自己的需求来控制这个距离参数了

 

代码实现

文本之间的相似度计算用的是余弦距离,先哈希过。下面是计算两个List<String>的余弦距离代码:

    package dcd.academic.recommend;                import java.util.ArrayList;        import java.util.HashMap;        import java.util.Iterator;        import java.util.Map;                import dcd.academic.util.StdOutUtil;                public class CosineDis {                    public static double getSimilarity(ArrayList<String> doc1, ArrayList<String> doc2) {                if (doc1 != null && doc1.size() > 0 && doc2 != null && doc2.size() > 0) {                            Map<Long, int[]> AlgorithmMap = new HashMap<Long, int[]>();                            for (int i = 0; i < doc1.size(); i++) {                        String d1 = doc1.get(i);                        long sIndex = hashId(d1);                        int[] fq = AlgorithmMap.get(sIndex);                        if (fq != null) {                            fq[0]++;                        } else {                            fq = new int[2];                            fq[0] = 1;                            fq[1] = 0;                            AlgorithmMap.put(sIndex, fq);                        }                    }                            for (int i = 0; i < doc2.size(); i++) {                        String d2 = doc2.get(i);                        long sIndex = hashId(d2);                        int[] fq = AlgorithmMap.get(sIndex);                        if (fq != null) {                            fq[1]++;                        } else {                            fq = new int[2];                            fq[0] = 0;                            fq[1] = 1;                            AlgorithmMap.put(sIndex, fq);                        }                            }                            Iterator<Long> iterator = AlgorithmMap.keySet().iterator();                    double sqdoc1 = 0;                    double sqdoc2 = 0;                    double denominator = 0;                    while (iterator.hasNext()) {                        int[] c = AlgorithmMap.get(iterator.next());                        denominator += c[0] * c[1];                        sqdoc1 += c[0] * c[0];                        sqdoc2 += c[1] * c[1];                    }                            return denominator / Math.sqrt(sqdoc1 * sqdoc2);                } else {                    return 0;                }            }                    public static long hashId(String s) {                long seed = 131; // 31 131 1313 13131 131313 etc.. BKDRHash                long hash = 0;                for (int i = 0; i < s.length(); i++) {                    hash = (hash * seed) + s.charAt(i);                }                return hash;            }                    public static void main(String[] args) {                ArrayList<String> t1 = new ArrayList<String>();                ArrayList<String> t2 = new ArrayList<String>();                t1.add("sa");                t1.add("dfg");                t1.add("df");                        t2.add("gfd");                t2.add("sa");                                StdOutUtil.out(getSimilarity(t1, t2));            }        }  
</div> </div>

利用上面这个类,根据文本之间的相似度,为每份文本计算得到一个向量(最后要归一一下),用来初始化PageRank的起始矩阵。我用的数据是我 solr里的论文标题+摘要的文本,我是通过SolrjHelper这个类去取得了一个List<String>。你想替换的话把这部分换成 自己想测试的String列就可以了。下面是读取数据,生成向量给PageRank类的代码:

    package dcd.academic.recommend;                import java.io.IOException;        import java.net.UnknownHostException;        import java.util.ArrayList;        import java.util.List;                import dcd.academic.mongodb.MyMongoClient;        import dcd.academic.solrj.SolrjHelper;        import dcd.academic.util.StdOutUtil;        import dcd.academic.util.StringUtil;                import com.mongodb.BasicDBList;        import com.mongodb.BasicDBObject;        import com.mongodb.DBCollection;        import com.mongodb.DBCursor;        import com.mongodb.DBObject;                public class BtwPublication {                        public static final int NUM = 20;                        public static void main(String[] args) throws IOException{                BtwPublication bp = new BtwPublication();                //bp.updatePublicationForComma();                PageRank pageRank = new PageRank(bp.getPagerankS("random"));                pageRank.doPagerank();            }                        public double getDist(String pub1, String pub2) throws IOException {                if (pub1 != null && pub2 != null) {                    ArrayList<String> doc1 = StringUtil.getTokens(pub1);                    ArrayList<String> doc2 = StringUtil.getTokens(pub2);                    return CosineDis.getSimilarity(doc1, doc2);                } else {                    return 0;                }            }                    //  public List<Map<String, String>> getPubs(String name) {        //              //  }                        public List<List<Double>> getPagerankS(String text) throws IOException {                SolrjHelper helper = new SolrjHelper(1);                List<String> pubs = helper.getPubsByTitle(text, 0, NUM);                List<List<Double>> s = new ArrayList<List<Double>>();                for (String pub : pubs) {                    List<Double> tmp_row = new ArrayList<Double>();                    double total = 0.0;                    for (String other : pubs) {                        if (!pub.equals(other)) {                            double tmp = getDist(pub, other);                            tmp_row.add(tmp);                            total += tmp;                        } else {                            tmp_row.add(0.0);                        }                    }                    s.add(getNormalizedRow(tmp_row, total));                }                return s;            }                        public List<Double> getNormalizedRow(List<Double> row, double d) {                List<Double> res = new ArrayList<Double>();                for (int i = 0; i < row.size(); i ++) {                    res.add(row.get(i) / d);                }                StdOutUtil.out(res.toString());                return res;            }        }  
</div> </div> 最后这个是PageRank类,部分参数可以自己再设置:
    package dcd.academic.recommend;                import java.util.ArrayList;        import java.util.List;        import java.util.Random;                import dcd.academic.util.StdOutUtil;                public class PageRank {            private static final double ALPHA = 0.85;            private static final double DISTANCE = 0.0000001;            private static final double MUL = 10;                        public static int SIZE;            public static List<List<Double>> s;                        PageRank(List<List<Double>> s) {                this.SIZE = s.get(0).size();                this.s = s;            }                        public static void doPagerank() {                List<Double> q = new ArrayList<Double>();                for (int i = 0; i < SIZE; i ++) {                    q.add(new Random().nextDouble()*MUL);                }                System.out.println("初始的向量q为:");                printVec(q);                System.out.println("初始的矩阵G为:");                printMatrix(getG(ALPHA));                List<Double> pageRank = calPageRank(q, ALPHA);                System.out.println("PageRank为:");                printVec(pageRank);                System.out.println();            }                    /**            * 打印输出一个矩阵            *             * @param m            */            public static void printMatrix(List<List<Double>> m) {                for (int i = 0; i < m.size(); i++) {                    for (int j = 0; j < m.get(i).size(); j++) {                        System.out.print(m.get(i).get(j) + ", ");                    }                    System.out.println();                }            }                    /**            * 打印输出一个向量            *             * @param v            */            public static void printVec(List<Double> v) {                for (int i = 0; i < v.size(); i++) {                    System.out.print(v.get(i) + ", ");                }                System.out.println();            }                    /**            * 获得一个初始的随机向量q            *             * @param n            *            向量q的维数            * @return 一个随机的向量q,每一维是0-5之间的随机数            */            public static List<Double> getInitQ(int n) {                Random random = new Random();                List<Double> q = new ArrayList<Double>();                for (int i = 0; i < n; i++) {                    q.add(new Double(5 * random.nextDouble()));                }                return q;            }                    /**            * 计算两个向量的距离            *             * @param q1            *            第一个向量            * @param q2            *            第二个向量            * @return 它们的距离            */            public static double calDistance(List<Double> q1, List<Double> q2) {                double sum = 0;                        if (q1.size() != q2.size()) {                    return -1;                }                        for (int i = 0; i < q1.size(); i++) {                    sum += Math.pow(q1.get(i).doubleValue() - q2.get(i).doubleValue(),                            2);                }                return Math.sqrt(sum);            }                    /**            * 计算pagerank            *             * @param q1            *            初始向量            * @param a            *            alpha的值            * @return pagerank的结果            */            public static List<Double> calPageRank(List<Double> q1, double a) {                        List<List<Double>> g = getG(a);                List<Double> q = null;                while (true) {                    q = vectorMulMatrix(g, q1);                    double dis = calDistance(q, q1);                    System.out.println(dis);                    if (dis <= DISTANCE) {                        System.out.println("q1:");                        printVec(q1);                        System.out.println("q:");                        printVec(q);                        break;                    }                    q1 = q;                }                return q;            }                    /**            * 计算获得初始的G矩阵            *             * @param a            *            为alpha的值,0.85            * @return 初始矩阵G            */            public static List<List<Double>> getG(double a) {                List<List<Double>> aS = numberMulMatrix(s, a);                List<List<Double>> nU = numberMulMatrix(getU(), (1 - a) / SIZE);                List<List<Double>> g = addMatrix(aS, nU);                return g;            }                    /**            * 计算一个矩阵乘以一个向量            *             * @param m            *            一个矩阵            * @param v            *            一个向量            * @return 返回一个新的向量            */            public static List<Double> vectorMulMatrix(List<List<Double>> m,                    List<Double> v) {                if (m == null || v == null || m.size() <= 0                        || m.get(0).size() != v.size()) {                    return null;                }                        List<Double> list = new ArrayList<Double>();                for (int i = 0; i < m.size(); i++) {                    double sum = 0;                    for (int j = 0; j < m.get(i).size(); j++) {                        double temp = m.get(i).get(j).doubleValue()                                * v.get(j).doubleValue();                        sum += temp;                    }                    list.add(sum);                }                        return list;            }                    /**            * 计算两个矩阵的和            *             * @param list1            *            第一个矩阵            * @param list2            *            第二个矩阵            * @return 两个矩阵的和            */            public static List<List<Double>> addMatrix(List<List<Double>> list1,                    List<List<Double>> list2) {                List<List<Double>> list = new ArrayList<List<Double>>();                if (list1.size() != list2.size() || list1.size() <= 0                        || list2.size() <= 0) {                    return null;                }                for (int i = 0; i < list1.size(); i++) {                    list.add(new ArrayList<Double>());                    for (int j = 0; j < list1.get(i).size(); j++) {                        double temp = list1.get(i).get(j).doubleValue()                                + list2.get(i).get(j).doubleValue();                        list.get(i).add(new Double(temp));                    }                }                return list;            }                    /**            * 计算一个数乘以矩阵            *             * @param s            *            矩阵s            * @param a            *            double类型的数            * @return 一个新的矩阵            */            public static List<List<Double>> numberMulMatrix(List<List<Double>> s,                    double a) {                List<List<Double>> list = new ArrayList<List<Double>>();                        for (int i = 0; i < s.size(); i++) {                    list.add(new ArrayList<Double>());                    for (int j = 0; j < s.get(i).size(); j++) {                        double temp = a * s.get(i).get(j).doubleValue();                        list.get(i).add(new Double(temp));                    }                }                return list;            }                    /**            * 初始化U矩阵,全1            *             * @return U            */            public static List<List<Double>> getU() {                List<Double> row = new ArrayList<Double>();                for (int i = 0; i < SIZE; i ++) {                    row.add(new Double(1));                }                        List<List<Double>> s = new ArrayList<List<Double>>();                for (int j = 0; j < SIZE; j ++) {                    s.add(row);                }                return s;            }        }  
</div> </div>

下面是我一次实验结果的数据,我设置了五分文本,这样看起来比较短:

 
[0.0, 0.09968643574761415, 0.2601130421632277, 0.31094706119099713, 0.32925346089816093]    [0.1315115598803241, 0.0, 0.23650307622882252, 0.2827229880685279, 0.34926237582232544]    [0.13521235055030142, 0.09318868159350341, 0.0, 0.3996835314966943, 0.3719154363595009]    [0.1389453620825689, 0.0957614822411479, 0.34357346750710194, 0.0, 0.4217196881691813]    [0.14612484353723476, 0.11749453142051332, 0.31752920814285096, 0.4188514168994011, 0.0]    初始的向量q为:    8.007763265073303, 3.1232982446687387, 1.1722525763669134, 5.906625842576609, 9.019220483814852,     初始的矩阵G为:    0.030000000000000006, 0.11473347038547205, 0.2510960858387436, 0.2943050020123476, 0.30986544176343683,     0.14178482589827548, 0.030000000000000006, 0.23102761479449913, 0.2703145398582487, 0.3268730194489766,     0.1449304979677562, 0.10921037935447789, 0.030000000000000006, 0.36973100177219015, 0.3461281209055758,     0.14810355777018358, 0.11139725990497573, 0.3220374473810367, 0.030000000000000006, 0.38846173494380415,     0.15420611700664955, 0.12987035170743633, 0.29989982692142336, 0.38602370436449096, 0.030000000000000006,     8.215210604296416    2.1786836521210637    0.6343362349619535    0.19024536572818584    0.05836227176176904    0.018354791916908083    0.0059297512567364945    0.0019669982458251243    6.679891158687752E-4    2.312017647733628E-4    8.117199104238135E-5    2.8787511843006215E-5    1.0279598478348542E-5    3.6872987746593366E-6    1.3264993458811192E-6    4.780938295685138E-7    1.7251588746973008E-7    6.229666266632005E-8    q1:    5.62674207030434, 5.626742074589739, 5.626742063777632, 5.626742101012727, 5.626742037269133,     q:    5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638,     PageRank为:    5.626742067958352, 5.626742066427658, 5.626742070495978, 5.626742056768215, 5.626742079766638,   
</div> </div> 来自:http://blog.csdn.net/pelick/article/details/8847457