from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    """Request document: a chat history plus generation parameters."""
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    """Response document: a piece of generated text."""
    text: str


# Set from the CLI arguments in the __main__ block below.
tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        """Generate the full reply in one shot and return it as a single document."""
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        # Strip the prompt tokens so only the newly generated part is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        """Stream the reply by generating one new token per iteration."""
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        input_len = inputs["input_ids"].shape[1]

        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")

        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

        for _ in range(max_new_tokens):
            # Generate exactly one new token on top of the current sequence.
            output = self.model.generate(
                **inputs, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            # Yield everything generated so far after the original prompt.
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            # Feed the extended sequence back in as the next step's input.
            inputs = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()

    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Serve the executor over gRPC until interrupted.
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
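
For reference, a minimal client sketch for the /stream endpoint, assuming the deployment above is running locally on the default port 12345 and that Jina's asyncio Client with stream_doc is used; the chat history, gen_conf values, and the client.py file name are illustrative, not part of the script above.

# client.py (hypothetical companion script)
import asyncio

from docarray import BaseDoc
from jina import Client


class Prompt(BaseDoc):
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    text: str


# Illustrative request payload.
chat_history = [{"role": "user", "content": "Tell me a joke about streaming."}]
gen_conf = {"do_sample": True, "temperature": 0.7, "max_new_tokens": 128}


async def main():
    client = Client(port=12345, protocol="grpc", asyncio=True)
    # Each yielded Generation carries the text produced so far after the prompt.
    async for doc in client.stream_doc(
        on="/stream",
        inputs=Prompt(message=chat_history, gen_conf=gen_conf),
        return_type=Generation,
    ):
        print(doc.text, flush=True)


if __name__ == "__main__":
    asyncio.run(main())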