```python
from jina import Deployment, Executor, requests
from docarray import BaseDoc
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        # Non-streaming endpoint: generate the full completion in one call.
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        # Strip the prompt tokens so only the newly generated tokens are decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]

        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        # Streaming endpoint: generate one token per step and yield the partial
        # decoding back to the client after every step.
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        input_len = inputs["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **inputs, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            # Feed the extended sequence back in to produce the next token.
            inputs = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
```
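
Assuming the script above is saved as `deploy_local_llm.py` (the file name is illustrative), it can be started with something like `python deploy_local_llm.py --model_name <model-id-or-path> --port 12345`. Below is a minimal client sketch, not part of the script above, for consuming the `/stream` endpoint over gRPC; it assumes Jina's asyncio `Client.stream_doc` streaming API and redefines the `Prompt` and `Generation` schemas so the example is self-contained.

```python
import asyncio

from docarray import BaseDoc
from jina import Client


# Same schemas as on the server side, so requests and responses deserialize cleanly.
class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


async def main():
    # asyncio=True is required for the streaming client API.
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(
        message=[{"role": "user", "content": "Say hello in one sentence."}],
        gen_conf={"max_new_tokens": 64},
    )
    # Each yielded Generation carries the text decoded so far.
    async for doc in client.stream_doc(
        on="/stream", inputs=prompt, return_type=Generation
    ):
        print(doc.text)


if __name__ == "__main__":
    asyncio.run(main())
```

The non-streaming `/chat` endpoint can be called the same way with a regular `client.post(on="/chat", ...)` request instead of `stream_doc`.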
 
 