|
|
|
@@ -1,13 +1,13 @@ |
|
|
|
from concurrent.futures import ProcessPoolExecutor |
|
|
|
from os.path import abspath, dirname, join |
|
|
|
from threading import Lock |
|
|
|
from typing import Any, cast |
|
|
|
|
|
|
|
import gevent.threadpool # type: ignore |
|
|
|
from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer # type: ignore |
|
|
|
|
|
|
|
# Module-level shared state for the GPT-2 tokenizer wrapper.
# NOTE(review): this excerpt looks like a diff hunk shown without +/- markers,
# so `_pool` and `_executor` below may belong to different revisions of the
# file rather than coexisting -- confirm against the real file.
#
# Slot for the tokenizer object; starts out as None. The code that assigns it
# is not visible in this excerpt -- presumably it is created lazily on first
# use (TODO confirm).
_tokenizer: Any = None |
|
|
|
# threading.Lock instance. Its users are outside this excerpt; presumably it
# serialises creation of `_tokenizer` (TODO confirm).
_lock = Lock() |
|
|
|
# gevent thread pool constructed with a size argument of 1; get_num_tokens
# hands work to it through `_pool.spawn(...)`.
_pool = gevent.threadpool.ThreadPool(1) |
|
|
|
# Process pool limited to a single worker (max_workers=1); get_num_tokens
# hands work to it through `_executor.submit(...)`.
_executor = ProcessPoolExecutor(max_workers=1) |
|
|
|
|
|
|
|
|
|
|
|
class GPT2Tokenizer: |
|
|
|
@@ -22,8 +22,8 @@ class GPT2Tokenizer: |
|
|
|
|
|
|
|
# Count the tokens in `text` by running GPT2Tokenizer._get_num_tokens_by_gpt2
# on a module-level worker pool and waiting for its result.
#
# Parameters: text -- the string handed to _get_num_tokens_by_gpt2.
# Returns:    the value produced by that helper, annotated as int.
#
# NOTE(review): this span reads like a diff hunk rendered without +/- markers.
# It contains two alternative dispatch paths (gevent `_pool` and
# concurrent.futures `_executor`). Taken literally, both would run and only the
# second result would be returned. Most likely exactly one pair of lines is
# live in the real file -- confirm which before changing anything here.
@staticmethod |
|
|
|
def get_num_tokens(text: str) -> int: |
|
|
|
# Variant A: schedule the helper on the gevent thread pool.
future = _pool.spawn(GPT2Tokenizer._get_num_tokens_by_gpt2, text) |
|
|
|
# Wait for the spawned task to finish and fetch its result (block=True).
result = future.get(block=True) |
|
|
|
# Variant B: submit the same helper to the single-worker process pool.
# This rebinds `future`, so the handle from variant A is no longer referenced.
future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) |
|
|
|
# Future.result() blocks until the call completes; an exception raised by
# the helper is re-raised here. This overwrites `result` from variant A.
result = future.result() |
|
|
|
# typing.cast only informs the type checker; no runtime conversion happens.
return cast(int, result) |
|
|
|
|
|
|
|
@staticmethod |