You may select up to 25 topics. Topics must start with a letter or number, may include hyphens (-), and must be no longer than 35 characters.

rpc_server.py 3.1KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. import argparse
  2. import pickle
  3. import random
  4. import time
  5. from multiprocessing.connection import Listener
  6. from threading import Thread
  7. from transformers import AutoModelForCausalLM, AutoTokenizer
  8. class RPCHandler:
  9. def __init__(self):
  10. self._functions = {}
  11. def register_function(self, func):
  12. self._functions[func.__name__] = func
  13. def handle_connection(self, connection):
  14. try:
  15. while True:
  16. # Receive a message
  17. func_name, args, kwargs = pickle.loads(connection.recv())
  18. # Run the RPC and send a response
  19. try:
  20. r = self._functions[func_name](*args, **kwargs)
  21. connection.send(pickle.dumps(r))
  22. except Exception as e:
  23. connection.send(pickle.dumps(e))
  24. except EOFError:
  25. pass
  26. def rpc_server(hdlr, address, authkey):
  27. sock = Listener(address, authkey=authkey)
  28. while True:
  29. try:
  30. client = sock.accept()
  31. t = Thread(target=hdlr.handle_connection, args=(client,))
  32. t.daemon = True
  33. t.start()
  34. except Exception as e:
  35. print("【EXCEPTION】:", str(e))
# Pool of loaded causal-LM instances; populated in the __main__ block
# and read by Model() / chat().
models = []
# Shared tokenizer for every model in the pool; set in the __main__ block.
tokenizer = None
  38. def chat(messages, gen_conf):
  39. global tokenizer
  40. model = Model()
  41. try:
  42. conf = {
  43. "max_new_tokens": int(
  44. gen_conf.get(
  45. "max_tokens", 256)), "temperature": float(
  46. gen_conf.get(
  47. "temperature", 0.1))}
  48. print(messages, conf)
  49. text = tokenizer.apply_chat_template(
  50. messages,
  51. tokenize=False,
  52. add_generation_prompt=True
  53. )
  54. model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
  55. generated_ids = model.generate(
  56. model_inputs.input_ids,
  57. **conf
  58. )
  59. generated_ids = [
  60. output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
  61. ]
  62. return tokenizer.batch_decode(
  63. generated_ids, skip_special_tokens=True)[0]
  64. except Exception as e:
  65. return str(e)
  66. def Model():
  67. global models
  68. random.seed(time.time())
  69. return random.choice(models)
# Script entry point: load the model(s) and tokenizer, then serve RPC
# requests forever on the given port.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    # Only chat() is exposed to remote callers.
    handler.register_function(chat)

    models = []
    # NOTE(review): range(1) loads exactly one model copy; the loop shape
    # suggests a configurable pool size was intended — confirm intent.
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto')
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    # NOTE(review): the authkey is hard-coded in source; it is the only
    # protection on the pickle-based RPC channel — consider loading it
    # from configuration or the environment.
    rpc_server(handler, ('0.0.0.0', args.port),
               authkey=b'infiniflow-token4kevinhu')