
jina_server.py 3.0KB

from jina import Deployment
from docarray import BaseDoc
from jina import Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
import argparse
import torch


class Prompt(BaseDoc):
    """Request schema: chat messages plus a generation-config dict."""
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    """Response schema: a piece of generated text."""
    text: str


# Module-level globals, populated in __main__ before the deployment starts.
tokenizer = None
model_name = ""


class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        # Non-streaming endpoint: produce the full completion in a single call.
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        # Strip the prompt tokens so only the newly generated tokens remain.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)

    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        # Streaming endpoint: generate one token per step and yield the decoded
        # completion produced so far after each step.
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        input = tokenizer([text], return_tensors="pt")
        input_len = input["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **input, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            # Feed the extended sequence back in as the prompt for the next token.
            input = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    # Serve the executor over gRPC until interrupted.
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
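
For context, a minimal client sketch (not part of the file above) showing how the /stream endpoint could be consumed with Jina's documented streaming-client API. It assumes the server was started with something like "python jina_server.py --model_name Qwen/Qwen2-7B-Instruct --port 12345" (the model name is only an example), and that this script sits next to jina_server.py so the Prompt and Generation schemas can be imported.

# client_example.py -- a sketch under the assumptions stated above.
import asyncio

from jina import Client
from jina_server import Prompt, Generation


async def main():
    # Connect to the gRPC deployment started by jina_server.py.
    client = Client(port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(
        message=[{"role": "user", "content": "Say hello in one sentence."}],
        gen_conf={"max_new_tokens": 64},
    )
    # stream_doc hits the /stream endpoint and yields one Generation per step;
    # each doc.text holds the decoded completion produced so far.
    async for doc in client.stream_doc(
        on="/stream", inputs=prompt, return_type=Generation
    ):
        print(doc.text)


if __name__ == "__main__":
    asyncio.run(main())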