import argparse
import pickle
import random
import time
from multiprocessing.connection import Listener
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer


class RPCHandler:
    """Dispatches pickled (func_name, args, kwargs) requests to registered functions."""

    def __init__(self):
        self._functions = {}

    def register_function(self, func):
        self._functions[func.__name__] = func

    def handle_connection(self, connection):
        # NOTE: requests and replies travel as pickled payloads; only use this
        # between trusted hosts. The authkey gives HMAC-based authentication,
        # not encryption.
        try:
            while True:
                # Receive a message
                func_name, args, kwargs = pickle.loads(connection.recv())
                # Run the RPC and send a response
                try:
                    r = self._functions[func_name](*args, **kwargs)
                    connection.send(pickle.dumps(r))
                except Exception as e:
                    # Ship the exception back to the client instead of crashing
                    connection.send(pickle.dumps(e))
        except EOFError:
            pass


def rpc_server(hdlr, address, authkey):
    # Accept connections forever, serving each client on its own daemon thread.
    sock = Listener(address, authkey=authkey)
    while True:
        try:
            client = sock.accept()
            t = Thread(target=hdlr.handle_connection, args=(client,))
            t.daemon = True
            t.start()
        except Exception as e:
            print("【EXCEPTION】:", str(e))


models = []
tokenizer = None


def chat(messages, gen_conf):
    global tokenizer
    model = Model()
    try:
        conf = {
            "max_new_tokens": int(gen_conf.get("max_tokens", 256)),
            "temperature": float(gen_conf.get("temperature", 0.1)),
        }
        print(messages, conf)
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            model_inputs.input_ids,
            **conf
        )
        # Strip the prompt tokens so only the newly generated completion remains.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        return str(e)


def Model():
    # Pick one of the loaded model replicas at random (only one is loaded below).
    global models
    random.seed(time.time())
    return random.choice(models)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, help="Model name")
    parser.add_argument(
        "--port",
        default=7860,
        type=int,
        help="RPC serving port")
    args = parser.parse_args()

    handler = RPCHandler()
    handler.register_function(chat)

    models = []
    for _ in range(1):
        m = AutoModelForCausalLM.from_pretrained(args.model_name,
                                                 device_map="auto",
                                                 torch_dtype="auto")
        models.append(m)
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    # Run the server
    rpc_server(handler, ("0.0.0.0", args.port),
               authkey=b'infiniflow-token4kevinhu')
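
# ---------------------------------------------------------------------------
# Usage sketch (illustrative; not part of the original file). The server is
# started with a HuggingFace model name, e.g.:
#
#   python rpc_server.py --model_name Qwen/Qwen2-7B-Instruct --port 7860
#
# (the script filename and model id here are assumptions, not taken from the
# source). The wire protocol above is: the client sends a pickled
# (func_name, args, kwargs) tuple and reads back a pickled result or
# exception, so a minimal client could look like this:
#
#   from multiprocessing.connection import Client
#   import pickle
#
#   conn = Client(("127.0.0.1", 7860), authkey=b'infiniflow-token4kevinhu')
#   messages = [{"role": "user", "content": "Hello!"}]
#   gen_conf = {"max_tokens": 128, "temperature": 0.1}
#   conn.send(pickle.dumps(("chat", (messages, gen_conf), {})))
#   reply = pickle.loads(conn.recv())
#   if isinstance(reply, Exception):
#       raise reply
#   print(reply)
# ---------------------------------------------------------------------------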