
jina_server.py 3.6KB
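jina_server.py serves a Hugging Face causal language model behind a Jina gRPC Deployment, exposing a one-shot /chat endpoint and a token-streaming /stream endpoint. A typical launch looks like the following (the model name here is only an illustrative placeholder; any local path or Hub chat model accepted by --model_name works):

python jina_server.py --model_name Qwen/Qwen2.5-7B-Instruct --port 12345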

#
# Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import argparse

import torch
from docarray import BaseDoc
from jina import Deployment, Executor, requests
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class Prompt(BaseDoc):
    # An OpenAI-style chat history plus generation options (temperature, etc.).
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):
    # A piece of (or the whole of) the generated completion.
    text: str


# Set from the CLI arguments below, before the Deployment starts, so the
# executor can read them when it is constructed.
tokenizer = None
model_name = ""
class TokenStreamingExecutor(Executor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="auto", torch_dtype="auto"
        )

    @requests(on="/chat")
    async def generate(self, doc: Prompt, **kwargs) -> Generation:
        # Render the chat history into a single prompt string.
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        generated_ids = self.model.generate(
            inputs.input_ids, generation_config=generation_config
        )
        # Strip the prompt tokens, keeping only the completion.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(inputs.input_ids, generated_ids)
        ]
        response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        yield Generation(text=response)
    @requests(on="/stream")
    async def task(self, doc: Prompt, **kwargs) -> Generation:
        text = tokenizer.apply_chat_template(
            doc.message,
            tokenize=False,
        )
        inputs = tokenizer([text], return_tensors="pt")
        input_len = inputs["input_ids"].shape[1]
        max_new_tokens = 512
        if "max_new_tokens" in doc.gen_conf:
            max_new_tokens = doc.gen_conf.pop("max_new_tokens")
        generation_config = GenerationConfig(
            **doc.gen_conf,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )
        # Decode one token per step: each iteration extends the sequence by a
        # single token, yields the cumulative completion so far, and feeds the
        # extended sequence back in as the next step's input.
        for _ in range(max_new_tokens):
            output = self.model.generate(
                **inputs, max_new_tokens=1, generation_config=generation_config
            )
            if output[0][-1] == tokenizer.eos_token_id:
                break
            yield Generation(
                text=tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
            )
            inputs = {
                "input_ids": output,
                "attention_mask": torch.ones(1, len(output[0])),
            }
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name or path")
    parser.add_argument("--port", default=12345, type=int, help="Jina serving port")
    args = parser.parse_args()
    model_name = args.model_name
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    with Deployment(
        uses=TokenStreamingExecutor, port=args.port, protocol="grpc"
    ) as dep:
        dep.block()
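
For reference, a minimal client-side sketch; this is not part of the repository file. It assumes the server above is running locally on its default port and mirrors the Prompt and Generation schemas so the gRPC payloads line up. Jina streaming endpoints are consumed with an async Client via stream_doc:

# client_example.py -- hypothetical companion script, not shipped with the repo.
import asyncio

from docarray import BaseDoc
from jina import Client


class Prompt(BaseDoc):  # mirrors the schema in jina_server.py
    message: list[dict]
    gen_conf: dict


class Generation(BaseDoc):  # mirrors the schema in jina_server.py
    text: str


async def main():
    client = Client(host="localhost", port=12345, protocol="grpc", asyncio=True)
    prompt = Prompt(
        message=[{"role": "user", "content": "Say hello in one sentence."}],
        gen_conf={"max_new_tokens": 64},
    )
    # Each streamed Generation carries the cumulative completion so far,
    # since the server decodes everything past the prompt on every step.
    async for doc in client.stream_doc(
        on="/stream", inputs=prompt, return_type=Generation
    ):
        print(doc.text)


if __name__ == "__main__":
    asyncio.run(main())

One design consequence worth noting: /stream re-invokes model.generate with max_new_tokens=1 each step and re-feeds the whole extended sequence, so the prompt is re-encoded on every step. This keeps the loop simple at the cost of quadratic decode work compared with reusing the KV cache across steps.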