|
|
|
@@ -116,55 +116,86 @@ class RagTokenizer: |
|
|
|
def _tradi2simp(self, line): |
|
|
|
return HanziConv.toSimplified(line) |
|
|
|
|
|
|
|
def dfs_(self, chars, s, preTks, tkslist): |
|
|
|
def dfs_(self, chars, s, preTks, tkslist, _depth=0, _memo=None): |
|
|
|
if _memo is None: |
|
|
|
_memo = {} |
|
|
|
MAX_DEPTH = 10 |
|
|
|
if _depth > MAX_DEPTH: |
|
|
|
if s < len(chars): |
|
|
|
copy_pretks = copy.deepcopy(preTks) |
|
|
|
remaining = "".join(chars[s:]) |
|
|
|
copy_pretks.append((remaining, (-12, ''))) |
|
|
|
tkslist.append(copy_pretks) |
|
|
|
return s |
|
|
|
|
|
|
|
state_key = (s, tuple(tk[0] for tk in preTks)) if preTks else (s, None) |
|
|
|
if state_key in _memo: |
|
|
|
return _memo[state_key] |
|
|
|
|
|
|
|
res = s |
|
|
|
if len(tkslist) >= 2048: |
|
|
|
return res |
|
|
|
# if s > MAX_L or s>= len(chars): |
|
|
|
if s >= len(chars): |
|
|
|
tkslist.append(preTks) |
|
|
|
return res |
|
|
|
|
|
|
|
# pruning |
|
|
|
_memo[state_key] = s |
|
|
|
return s |
|
|
|
if s < len(chars) - 4: |
|
|
|
is_repetitive = True |
|
|
|
char_to_check = chars[s] |
|
|
|
for i in range(1, 5): |
|
|
|
if s + i >= len(chars) or chars[s + i] != char_to_check: |
|
|
|
is_repetitive = False |
|
|
|
break |
|
|
|
if is_repetitive: |
|
|
|
end = s |
|
|
|
while end < len(chars) and chars[end] == char_to_check: |
|
|
|
end += 1 |
|
|
|
mid = s + min(10, end - s) |
|
|
|
t = "".join(chars[s:mid]) |
|
|
|
k = self.key_(t) |
|
|
|
copy_pretks = copy.deepcopy(preTks) |
|
|
|
if k in self.trie_: |
|
|
|
copy_pretks.append((t, self.trie_[k])) |
|
|
|
else: |
|
|
|
copy_pretks.append((t, (-12, ''))) |
|
|
|
next_res = self.dfs_(chars, mid, copy_pretks, tkslist, _depth + 1, _memo) |
|
|
|
res = max(res, next_res) |
|
|
|
_memo[state_key] = res |
|
|
|
return res |
|
|
|
|
|
|
|
S = s + 1 |
|
|
|
if s + 2 <= len(chars): |
|
|
|
t1, t2 = "".join(chars[s:s + 1]), "".join(chars[s:s + 2]) |
|
|
|
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix( |
|
|
|
self.key_(t2)): |
|
|
|
t1 = "".join(chars[s:s + 1]) |
|
|
|
t2 = "".join(chars[s:s + 2]) |
|
|
|
if self.trie_.has_keys_with_prefix(self.key_(t1)) and not self.trie_.has_keys_with_prefix(self.key_(t2)): |
|
|
|
S = s + 2 |
|
|
|
if len(preTks) > 2 and len( |
|
|
|
preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1: |
|
|
|
if len(preTks) > 2 and len(preTks[-1][0]) == 1 and len(preTks[-2][0]) == 1 and len(preTks[-3][0]) == 1: |
|
|
|
t1 = preTks[-1][0] + "".join(chars[s:s + 1]) |
|
|
|
if self.trie_.has_keys_with_prefix(self.key_(t1)): |
|
|
|
S = s + 2 |
|
|
|
|
|
|
|
################ |
|
|
|
|
|
|
|
for e in range(S, len(chars) + 1): |
|
|
|
t = "".join(chars[s:e]) |
|
|
|
k = self.key_(t) |
|
|
|
|
|
|
|
if e > s + 1 and not self.trie_.has_keys_with_prefix(k): |
|
|
|
break |
|
|
|
|
|
|
|
if k in self.trie_: |
|
|
|
pretks = copy.deepcopy(preTks) |
|
|
|
if k in self.trie_: |
|
|
|
pretks.append((t, self.trie_[k])) |
|
|
|
else: |
|
|
|
pretks.append((t, (-12, ''))) |
|
|
|
res = max(res, self.dfs_(chars, e, pretks, tkslist)) |
|
|
|
|
|
|
|
pretks.append((t, self.trie_[k])) |
|
|
|
res = max(res, self.dfs_(chars, e, pretks, tkslist, _depth + 1, _memo)) |
|
|
|
|
|
|
|
if res > s: |
|
|
|
_memo[state_key] = res |
|
|
|
return res |
|
|
|
|
|
|
|
|
|
|
|
t = "".join(chars[s:s + 1]) |
|
|
|
k = self.key_(t) |
|
|
|
copy_pretks = copy.deepcopy(preTks) |
|
|
|
if k in self.trie_: |
|
|
|
preTks.append((t, self.trie_[k])) |
|
|
|
copy_pretks.append((t, self.trie_[k])) |
|
|
|
else: |
|
|
|
preTks.append((t, (-12, ''))) |
|
|
|
|
|
|
|
return self.dfs_(chars, s + 1, preTks, tkslist) |
|
|
|
copy_pretks.append((t, (-12, ''))) |
|
|
|
result = self.dfs_(chars, s + 1, copy_pretks, tkslist, _depth + 1, _memo) |
|
|
|
_memo[state_key] = result |
|
|
|
return result |
|
|
|
|
|
|
|
def freq(self, tk): |
|
|
|
k = self.key_(tk) |