
rpc_server.py

import argparse
import pickle
import random
import time
from multiprocessing.connection import Listener
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer


class RPCHandler:
    """Dispatches pickled (func_name, args, kwargs) requests to registered functions."""

    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    # Ship the exception back to the caller instead of
                    # killing the worker thread
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass


def rpc_server(hdlr, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            client = sock.accept()
            # Serve each client on its own daemon thread
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("【EXCEPTION】:", str(e))


models = []
tokenizer = None


def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            model_inputs.input_ids,
            **conf
        )
        # Strip the prompt tokens so only the newly generated text is decoded
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)


def Model():
    # Pick a model replica at random as crude load balancing
    global models
    random.seed(time.time())
    return random.choice(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)

    # Load a single model replica; widen the range to serve several copies
    models = []
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto')
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    rpc_server(handler, ('0.0.0.0', args.port),
               authkey=b'infiniflow-token4kevinhu')
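
For reference, a minimal client sketch for calling this server follows. It is not part of the repository file above: the address, port, and authkey simply mirror the defaults hard-coded in rpc_server.py, and it assumes the server was started with something like `python rpc_server.py --model_name <hf-model-id>`. The wire format follows the server's loop: a pickled (func_name, args, kwargs) tuple goes out, and a pickled result (or a pickled exception, for example when func_name is unregistered) comes back. Note that the tuple is pickled explicitly even though Connection.send would pickle it anyway, because the server symmetrically calls pickle.loads on what recv returns.

import pickle
from multiprocessing.connection import Client


def main():
    # Address and authkey match the defaults in rpc_server.py
    conn = Client(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu")
    messages = [{"role": "user", "content": "Hello!"}]
    gen_conf = {"max_tokens": 128, "temperature": 0.1}
    # Send a pickled (func_name, args, kwargs) tuple, as the server expects
    conn.send(pickle.dumps(("chat", (messages, gen_conf), {})))
    result = pickle.loads(conn.recv())
    if isinstance(result, Exception):
        raise result
    print(result)


if __name__ == "__main__":
    main()

Since the channel exchanges pickles, both ends must trust each other; the authkey handshake on the multiprocessing connection is what gates access, so the server should not be exposed beyond a trusted network.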