瀏覽代碼

accelerate tokenize (#3244)

### What problem does this PR solve?


### Type of change

- [x] Performance Improvement
tags/v0.14.0
Kevin Hu 1 年之前
父節點
當前提交
fbcc0bb408
沒有連結到貢獻者的電子郵件帳戶。
共有 1 個檔案被更改,包括 40 行新增25 行删除
  1. 40
    25
      rag/nlp/rag_tokenizer.py

+ 40
- 25
rag/nlp/rag_tokenizer.py 查看文件

@@ -281,34 +281,49 @@ class RagTokenizer:
print("[FW]", tks, s)
print("[BW]", tks1, s1)

diff = [0 for _ in range(max(len(tks1), len(tks)))]
for i in range(min(len(tks1), len(tks))):
if tks[i] != tks1[i]:
diff[i] = 1

if s1 > s:
tks = tks1

i = 0
while i < len(tks):
s = i
while s < len(tks) and diff[s] == 0:
s += 1
if s == len(tks):
res.append(" ".join(tks[i:]))
break
if s > i:
res.append(" ".join(tks[i:s]))

e = s
while e < len(tks) and e - s < 5 and diff[e] == 1:
e += 1

i, j, _i, _j = 0, 0, 0, 0
same = 0
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
if same > 0: res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1

while i < len(tks1) and j < len(tks):
tk1, tk = "".join(tks1[_i:i]), "".join(tks[_j:j])
if tk1 != tk:
if len(tk1) > len(tk):
j += 1
else:
i += 1
continue

if tks1[i] != tks[j]:
i += 1
j += 1
continue
# backward tokens from_i to i are different from forward tokens from _j to j.
tkslist = []
self.dfs_("".join(tks[s:e + 1]), 0, [], tkslist)
self.dfs_("".join(tks[_j:j]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))

i = e + 1
same = 1
while i + same < len(tks1) and j + same < len(tks) and tks1[i + same] == tks[j + same]:
same += 1
res.append(" ".join(tks[j: j + same]))
_i = i + same
_j = j + same
j = _j + 1
i = _i + 1

if _i < len(tks1):
assert _j < len(tks)
assert "".join(tks1[_i:]) == "".join(tks[_j:])
tkslist = []
self.dfs_("".join(tks[_j:]), 0, [], tkslist)
res.append(" ".join(self.sortTks_(tkslist)[0][0]))

res = " ".join(self.english_normalize_(res))
if self.DEBUG:

Loading…
取消
儲存