Java双数组Trie树的实现方案总结

jopen 9年前

传统的Trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,爆炸起来的空间根本无法接受。

双数组Trie就是优化了空间的Trie树,原理本文就不讲了,请参考An Efficient Implementation of Trie Structures,本程序的编写也是参考这篇论文的。

关于几点论文没有提及的细节和与论文不一一致的实现:

1.对于插入字符串,如果有一个字符串是另一个字符串的子串的话,我是将结束符也作为一条边,产生一个新的结点,这个结点新节点的Base我置为0

所以一个字符串结束也有2中情况:一个是Base值为负,存储剩余字符(可能只有一个结束符)到Tail数组;另一个是Base为0。

所以在查询的时候要考虑一下这两种情况

2.对于第一种冲突(论文中的Case 3),可能要将Tail中的字符串取出一部分,作为边放到索引中。论文是使用将尾串左移的方式,我的方式直接修改Base值,而不是移动尾串。

下面是java实现的代码,可以处理相同字符串插入,子串的插入等情况

/*     * Name:   Double Array Trie     * Author: Yaguang Ding     * Mail: dingyaguang117@gmail.com     * Blog: blog.csdn.net/dingyaguang117     * Date:   2012/5/21     * Note: a word ends may be either of these two case:     * 1. Base[cur_p] == pos  ( pos<0 and Tail[-pos] == 'END_CHAR' )     * 2. Check[Base[cur_p] + Code('END_CHAR')] ==  cur_p     */     import java.util.ArrayList;    import java.util.HashMap;    import java.util.Map;    import java.util.Arrays;      public class DoubleArrayTrie {        final char END_CHAR = '\0';        final int DEFAULT_LEN = 1024;        int Base[]  = new int [DEFAULT_LEN];        int Check[] = new int [DEFAULT_LEN];        char Tail[] = new char [DEFAULT_LEN];        int Pos = 1;        Map<Character ,Integer> CharMap = new HashMap<Character,Integer>();        ArrayList<Character> CharList = new ArrayList<Character>();          public DoubleArrayTrie()        {            Base[1] = 1;              CharMap.put(END_CHAR,1);            CharList.add(END_CHAR);            CharList.add(END_CHAR);            for(int i=0;i<26;++i)            {                CharMap.put((char)('a'+i),CharMap.size()+1);                CharList.add((char)('a'+i));            }          }        private void Extend_Array()        {            Base = Arrays.copyOf(Base, Base.length*2);            Check = Arrays.copyOf(Check, Check.length*2);        }          private void Extend_Tail()        {            Tail = Arrays.copyOf(Tail, Tail.length*2);        }          private int GetCharCode(char c)        {            if (!CharMap.containsKey(c))            {                CharMap.put(c,CharMap.size()+1);                CharList.add(c);            }            return CharMap.get(c);        }        private int CopyToTailArray(String s,int p)        {            int _Pos = Pos;            while(s.length()-p+1 > Tail.length-Pos)            {                Extend_Tail();            }            for(int i=p; i<s.length();++i)            {                Tail[_Pos] = s.charAt(i);                _Pos++;            }            return _Pos;        }          private int x_check(Integer []set)        {            for(int i=1; ; ++i)            {                boolean flag = true;                for(int j=0;j<set.length;++j)                {                    int cur_p = i+set[j];                    if(cur_p>= Base.length) Extend_Array();                    if(Base[cur_p]!= 0 || Check[cur_p]!= 0)                    {                        flag = false;                        break;                    }                }                if (flag) return i;            }        }          private ArrayList<Integer> GetChildList(int p)        {            ArrayList<Integer> ret = new ArrayList<Integer>();            for(int i=1; i<=CharMap.size();++i)            {                if(Base[p]+i >= Check.length) break;                if(Check[Base[p]+i] == p)                {                    ret.add(i);                }            }            return ret;        }          private boolean TailContainString(int start,String s2)        {            for(int i=0;i<s2.length();++i)            {                if(s2.charAt(i) != Tail[i+start]) return false;            }              return true;        }        private boolean TailMatchString(int start,String s2)        {            s2 += END_CHAR;            for(int i=0;i<s2.length();++i)            {                if(s2.charAt(i) != Tail[i+start]) return false;            }            return true;        }          public void Insert(String s) throws Exception        {            s += END_CHAR;              int pre_p = 1;            int cur_p;            for(int i=0; i<s.length(); ++i)            {                //获取状态位置                cur_p = Base[pre_p]+GetCharCode(s.charAt(i));                //如果长度超过现有,拓展数组                if (cur_p >= Base.length) Extend_Array();                  //空闲状态                if(Base[cur_p] == 0 && Check[cur_p] == 0)                {                    Base[cur_p] = -Pos;                    Check[cur_p] = pre_p;                    Pos = CopyToTailArray(s,i+1);                    break;                }else               //已存在状态                if(Base[cur_p] > 0 && Check[cur_p] == pre_p)                {                    pre_p = cur_p;                    continue;                }else               //冲突 1:遇到 Base[cur_p]小于0的,即遇到一个被压缩存到Tail中的字符串                if(Base[cur_p] < 0 && Check[cur_p] == pre_p)                {                    int head = -Base[cur_p];                      if(s.charAt(i+1)== END_CHAR && Tail[head]==END_CHAR)    //插入重复字符串                    {                        break;                    }                      //公共字母的情况,因为上一个判断已经排除了结束符,所以一定是2个都不是结束符                    if (Tail[head] == s.charAt(i+1))                    {                        int avail_base = x_check(new Integer[]{GetCharCode(s.charAt(i+1))});                        Base[cur_p] = avail_base;                          Check[avail_base+GetCharCode(s.charAt(i+1))] = cur_p;                        Base[avail_base+GetCharCode(s.charAt(i+1))] = -(head+1);                        pre_p = cur_p;                        continue;                    }                    else                   {                        //2个字母不相同的情况,可能有一个为结束符                        int avail_base ;                        avail_base = x_check(new Integer[]{GetCharCode(s.charAt(i+1)),GetCharCode(Tail[head])});                          Base[cur_p] = avail_base;                          Check[avail_base+GetCharCode(Tail[head])] = cur_p;                        Check[avail_base+GetCharCode(s.charAt(i+1))] = cur_p;                          //Tail 为END_FLAG 的情况                        if(Tail[head] == END_CHAR)                            Base[avail_base+GetCharCode(Tail[head])] = 0;                        else                           Base[avail_base+GetCharCode(Tail[head])] = -(head+1);                        if(s.charAt(i+1) == END_CHAR)                             Base[avail_base+GetCharCode(s.charAt(i+1))] = 0;                        else                           Base[avail_base+GetCharCode(s.charAt(i+1))] = -Pos;                          Pos = CopyToTailArray(s,i+2);                        break;                    }                }else               //冲突2:当前结点已经被占用,需要调整pre的base                if(Check[cur_p] != pre_p)                {                    ArrayList<Integer> list1 = GetChildList(pre_p);                    int toBeAdjust;                    ArrayList<Integer> list = null;                    if(true)                    {                        toBeAdjust = pre_p;                        list = list1;                    }                      int origin_base = Base[toBeAdjust];                    list.add(GetCharCode(s.charAt(i)));                    int avail_base = x_check((Integer[])list.toArray(new Integer[list.size()]));                    list.remove(list.size()-1);                      Base[toBeAdjust] = avail_base;                    for(int j=0; j<list.size(); ++j)                    {                        //BUG                         int tmp1 = origin_base + list.get(j);                        int tmp2 = avail_base + list.get(j);                          Base[tmp2] = Base[tmp1];                        Check[tmp2] = Check[tmp1];                          //有后续                        if(Base[tmp1] > 0)                        {                            ArrayList<Integer> subsequence = GetChildList(tmp1);                            for(int k=0; k<subsequence.size(); ++k)                            {                                Check[Base[tmp1]+subsequence.get(k)] = tmp2;                            }                        }                          Base[tmp1] = 0;                        Check[tmp1] = 0;                    }                      //更新新的cur_p                    cur_p = Base[pre_p]+GetCharCode(s.charAt(i));                      if(s.charAt(i) == END_CHAR)                        Base[cur_p] = 0;                    else                       Base[cur_p] = -Pos;                    Check[cur_p] = pre_p;                    Pos = CopyToTailArray(s,i+1);                    break;                }            }        }          public boolean Exists(String word)        {            int pre_p = 1;            int cur_p = 0;              for(int i=0;i<word.length();++i)            {                cur_p = Base[pre_p]+GetCharCode(word.charAt(i));                if(Check[cur_p] != pre_p) return false;                if(Base[cur_p] < 0)                {                    if(TailMatchString(-Base[cur_p],word.substring(i+1)))                        return true;                    return false;                }                pre_p = cur_p;            }            if(Check[Base[cur_p]+GetCharCode(END_CHAR)] == cur_p)                return true;            return false;        }          //内部函数,返回匹配单词的最靠后的Base index,        class FindStruct        {            int p;            String prefix="";        }        private FindStruct Find(String word)        {            int pre_p = 1;            int cur_p = 0;            FindStruct fs = new FindStruct();            for(int i=0;i<word.length();++i)            {                // BUG                fs.prefix += word.charAt(i);                cur_p = Base[pre_p]+GetCharCode(word.charAt(i));                if(Check[cur_p] != pre_p)                {                    fs.p = -1;                    return fs;                }                if(Base[cur_p] < 0)                {                    if(TailContainString(-Base[cur_p],word.substring(i+1)))                    {                        fs.p = cur_p;                        return fs;                    }                    fs.p = -1;                    return fs;                }                pre_p = cur_p;            }            fs.p =  cur_p;            return fs;        }          public ArrayList<String> GetAllChildWord(int index)        {            ArrayList<String> result = new ArrayList<String>();            if(Base[index] == 0)            {                result.add("");                return result;            }            if(Base[index] < 0)            {                String r="";                for(int i=-Base[index];Tail[i]!=END_CHAR;++i)                {                    r+= Tail[i];                }                result.add(r);                return result;            }            for(int i=1;i<=CharMap.size();++i)            {                if(Check[Base[index]+i] == index)                {                    for(String s:GetAllChildWord(Base[index]+i))                    {                        result.add(CharList.get(i)+s);                    }                    //result.addAll(GetAllChildWord(Base[index]+i));                }            }            return result;        }          public ArrayList<String> FindAllWords(String word)        {            ArrayList<String> result = new ArrayList<String>();            String prefix = "";            FindStruct fs = Find(word);            int p = fs.p;            if (p == -1) return result;            if(Base[p]<0)            {                String r="";                for(int i=-Base[p];Tail[i]!=END_CHAR;++i)                {                    r+= Tail[i];                }                result.add(fs.prefix+r);                return result;            }              if(Base[p] > 0)            {                ArrayList<String> r =  GetAllChildWord(p);                for(int i=0;i<r.size();++i)                {                    r.set(i, fs.prefix+r.get(i));                }                return r;            }              return result;        }      }

测试代码:

import java.io.BufferedReader;    import java.io.FileInputStream;    import java.io.IOException;    import java.io.InputStream;    import java.io.InputStreamReader;    import java.util.ArrayList;    import java.util.Scanner;      import javax.xml.crypto.Data;      public class Main {          public static void main(String[] args) throws Exception {            ArrayList<String> words = new ArrayList<String>();            BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream("E:/兔子的试验学习中心[课内]/ACM大赛/ACM第四届校赛/E命令提示/words3.dic")));            String s;            int num = 0;            while((s=reader.readLine()) != null)            {                words.add(s);                num ++;            }            DoubleArrayTrie dat = new DoubleArrayTrie();              for(String word: words)            {                dat.Insert(word);            }              System.out.println(dat.Base.length);            System.out.println(dat.Tail.length);              Scanner sc = new Scanner(System.in);            while(sc.hasNext())            {                String word = sc.next();                System.out.println(dat.Exists(word));                System.out.println(dat.FindAllWords(word));            }          }      }

下面是测试结果,构造6W英文单词的DAT,大概需要20秒

我增长数组的时候是每次长度增加到2倍,初始1024

Base和Check数组的长度为131072

Tail的长度为262144

c194afffc082d9786ef7ad40e7e3281c.png
来源:dingyaguang117