
rpc_server.py

import argparse
import pickle
import random
import time
from multiprocessing.connection import Listener
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer


def torch_gc():
    """Free cached accelerator memory; silently do nothing if torch is
    missing or no accelerator is present."""
    try:
        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception:
                pass
    except Exception:
        pass


class RPCHandler:
    """Dispatch pickled (func_name, args, kwargs) requests to registered
    functions, handling one connection per thread."""

    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    # Ship the exception back; the client decides how to handle it.
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass


def rpc_server(hdlr, address, authkey):
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            client = sock.accept()
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("【EXCEPTION】:", str(e))


models = []
tokenizer = None


def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        torch_gc()
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            model_inputs.input_ids,
            **conf
        )
        # Drop the prompt tokens so only the completion is decoded.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)


def Model():
    """Pick one of the loaded model replicas at random."""
    global models
    random.seed(time.time())
    return random.choice(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)

    # Load the model replica(s) that chat() draws from via Model().
    models = []
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype='auto')
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    rpc_server(handler, ('0.0.0.0', args.port),
               authkey=b'infiniflow-token4kevinhu')
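
The server speaks a simple protocol over multiprocessing.connection: each request is a pickled (func_name, args, kwargs) tuple, each reply a pickled result or exception. A minimal client sketch under that assumption follows; the RPCProxy class, the file name, and the example chat arguments are illustrative and not part of this repository file.

# client_example.py -- illustrative sketch, not part of the repository.
# Assumes rpc_server.py is running on localhost:7860 with the authkey below.
import pickle
from multiprocessing.connection import Client


class RPCProxy:
    """Thin client-side proxy: forwards attribute calls as pickled
    (func_name, args, kwargs) tuples, mirroring RPCHandler's protocol."""

    def __init__(self, connection):
        self._connection = connection

    def __getattr__(self, name):
        def do_rpc(*args, **kwargs):
            self._connection.send(pickle.dumps((name, args, kwargs)))
            result = pickle.loads(self._connection.recv())
            if isinstance(result, Exception):
                # The server pickles exceptions instead of raising them.
                raise result
            return result
        return do_rpc


if __name__ == "__main__":
    conn = Client(('localhost', 7860), authkey=b'infiniflow-token4kevinhu')
    proxy = RPCProxy(conn)
    answer = proxy.chat(
        [{"role": "user", "content": "Hello!"}],
        {"max_tokens": 128, "temperature": 0.3},
    )
    print(answer)

To try it, start the server first with the flags it defines, e.g. python rpc_server.py --model_name <model> --port 7860 (the model name is a placeholder), then run the client above.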