
rpc_server.py

#
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
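# Stand-alone RPC server that loads a Hugging Face causal language model and
# exposes chat()/chat_streamly() to clients over multiprocessing.connection
# sockets.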
import argparse
import pickle
import random
import time
from copy import deepcopy
from multiprocessing.connection import Listener
from threading import Thread

# TextIteratorStreamer (rather than TextStreamer) is required because
# chat_streamly() iterates over the streamer; TextStreamer only prints to
# stdout and is not iterable.
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


def torch_gc():
    # Best-effort release of cached accelerator memory; silently a no-op if
    # torch is missing or no CUDA/MPS device is available.
    try:
        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception:
                pass
    except Exception:
        pass


class RPCHandler:
    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass


def rpc_server(hdlr, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            # One daemon thread per client connection
            client = sock.accept()
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("[EXCEPTION]:", str(e))


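# A minimal client sketch (an illustrative assumption, not part of this file):
# requests are double-pickled, i.e. the (name, args, kwargs) tuple is pickled
# by hand and the resulting bytes are sent via Connection.send(), mirroring
# handle_connection() above.
#
#     from multiprocessing.connection import Client
#     import pickle
#     conn = Client(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu")
#     conn.send(pickle.dumps(("chat",
#                             ([{"role": "user", "content": "Hi!"}],
#                              {"max_tokens": 128, "temperature": 0.1}),
#                             {})))
#     print(pickle.loads(conn.recv()))

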
models = []
tokenizer = None


def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            model_inputs.input_ids,
            **conf
        )
        # Drop the prompt tokens so only the completion is decoded
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)


def chat_streamly(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        # generate() expects max_new_tokens; pop() with a default avoids a
        # KeyError when the client omits max_tokens
        conf["max_new_tokens"] = conf.pop("max_tokens", 256)
        # Run generation in a background thread and yield text as it arrives
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
    except Exception as e:
        yield "**ERROR**: " + str(e)


def Model():
    # Pick one of the loaded model replicas at random
    global models
    random.seed(time.time())
    return random.choice(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)
    handler.register_function(chat_streamly)

    models = []
    # Load a single replica; widen the range to serve several copies
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto')
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    rpc_server(handler, ('0.0.0.0', args.port),
               authkey=b'infiniflow-token4kevinhu')
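
# Example launch (the model id below is only an illustration; any Hugging
# Face causal-LM checkpoint that ships a chat template should work):
#     python rpc_server.py --model_name Qwen/Qwen2-7B-Instruct --port 7860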