import argparse
import pickle
import random
import time
from copy import deepcopy
from multiprocessing.connection import Listener
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

def torch_gc():
    """Release cached accelerator memory, if torch is installed."""
    try:
        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception:
                pass
    except Exception:
        pass

class RPCHandler:
    """Dispatches pickled (func_name, args, kwargs) requests to registered functions."""

    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message: a pickled (func_name, args, kwargs) tuple
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a pickled response; note the result
                # (or the raised exception) must itself be picklable
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass

def rpc_server(hdlr, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            client = sock.accept()
            # Serve each client on its own daemon thread
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("[EXCEPTION]:", str(e))

models = []
tokenizer = None

def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            model_inputs.input_ids,
            **conf
        )
        # Keep only the newly generated tokens, dropping the echoed prompt
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)

def chat_streamly(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = deepcopy(gen_conf)
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        # TextIteratorStreamer (not TextStreamer, which only prints to stdout
        # and is not iterable) lets this generator yield text as it is produced
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True)
        conf["inputs"] = model_inputs.input_ids
        conf["streamer"] = streamer
        conf["max_new_tokens"] = conf.pop("max_tokens", 256)
        # Generation runs on a background thread while we drain the streamer
        thread = Thread(target=model.generate, kwargs=conf)
        thread.start()
        for new_text in streamer:
            yield new_text
    except Exception as e:
        yield "**ERROR**: " + str(e)

def Model():
    """Pick one of the loaded model replicas at random."""
    global models
    random.seed(time.time())
    return random.choice(models)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)
    handler.register_function(chat_streamly)

    # Load the model replica(s); "auto" lets transformers place the weights
    # and choose the dtype
    models = []
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype="auto")
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    rpc_server(handler, ("0.0.0.0", args.port),
               authkey=b"infiniflow-token4kevinhu")
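
# Launch sketch (script and model names are hypothetical; the checkpoint must
# ship a chat template for apply_chat_template to work):
#   python rpc_server.py --model_name Qwen/Qwen2-7B-Instruct --port 7860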