```python
import argparse

import torch
from docarray import BaseDoc
from jina import Deployment, Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class Prompt(BaseDoc):
    message: list[dict]  # chat-format messages, e.g. [{"role": "user", "content": "..."}]
    gen_conf: dict  # generation parameters forwarded to GenerationConfig


class Generation(BaseDoc):
    text: str


class TokenStreamingExecutor(Executor):
    def __init__(self, model_name: str, **kwargs):
        super().__init__(**kwargs)
        # Load the tokenizer and model once at startup; device_map="auto"
        # places the weights on available GPU(s), falling back to CPU.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        # Render the chat messages with the model's chat template and append
        # the assistant prompt so generation starts at the reply.
        text = self.tokenizer.apply_chat_template(
            doc.message, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        generated_ids = self.model.generate(
            **inputs, generation_config=generation_config
        )
        # Drop the prompt tokens so only the completion is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = self.tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        # Single-item stream: the full completion arrives in one message.
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = self.tokenizer.apply_chat_template(
            doc.message, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        input_len = inputs["input_ids"].shape[1]
        max_new_tokens = doc.gen_conf.pop("max_new_tokens", 512)
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=self.tokenizer.eos_token_id,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Generate one token at a time and stream the decoded completion so
        # far; stop early once the model emits the end-of-sequence token.
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **inputs, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == self.tokenizer.eos_token_id:
                break
            yield Generation(
                text=self.tokenizer.decode(
                    output[0][input_len:], skip_special_tokens=True
                )
            )
            # Feed the extended sequence back in as the next step's input.
            inputs = {
                "input_ids": output,
                "attention_mask": torch.ones_like(output),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name", type=str, required=True, help="Model name or path"
    )
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    # Pass the model name through uses_with so the executor loads the model
    # and tokenizer itself, even when Jina starts it in a separate process.
    with Deployment(
        uses=TokenStreamingExecutor,
        uses_with={"model_name": args.model_name},
        port=args.port,
        protocol="grpc",
    ) as dep:
        dep.block()
```
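To call these endpoints, a client opens a stream and consumes `Generation` documents as they arrive. Below is a minimal client sketch, assuming the server above is saved as `server.py` and is serving locally on the default port; the module name, port, and example message are illustrative assumptions, not part of the server code.

```python
# Minimal streaming-client sketch; "server" is an assumed module name for
# the file above, and the port must match the one the Deployment uses.
import asyncio

from jina import Client

from server import Generation, Prompt  # hypothetical module name


async def main():
    client = Client(port=12345, protocol="grpc", asyncio=True)
    # Each streamed Generation carries the text decoded so far, since
    # /stream re-decodes everything after the prompt on every step.
    async for doc in client.stream_doc(
        on="/stream",
        inputs=Prompt(
            message=[{"role": "user", "content": "What is gRPC?"}],
            gen_conf={"max_new_tokens": 64},
        ),
        return_type=Generation,
    ):
        print(doc.text, flush=True)


asyncio.run(main())
```

The `/chat` endpoint is consumed the same way, since it also yields its result as a (single-item) stream; point `on="/chat"` at it to receive the whole completion in one message.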