Nevar pievienot vairāk kā 25 tēmas Tēmai ir jāsākas ar burtu vai ciparu, tā var saturēt domu zīmes ('-') un var būt līdz 35 simboliem gara.

service_registry.py 6.6KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #
  2. # Copyright 2019 The FATE Authors. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. import socket
  17. from pathlib import Path
  18. from web_server import utils
  19. from .db_models import DB, ServiceRegistryInfo, ServerRegistryInfo
  20. from .reload_config_base import ReloadConfigBase
  21. class ServiceRegistry(ReloadConfigBase):
  22. @classmethod
  23. @DB.connection_context()
  24. def load_service(cls, **kwargs) -> [ServiceRegistryInfo]:
  25. service_registry_list = ServiceRegistryInfo.query(**kwargs)
  26. return [service for service in service_registry_list]
  27. @classmethod
  28. @DB.connection_context()
  29. def save_service_info(cls, server_name, service_name, uri, method="POST", server_info=None, params=None, data=None, headers=None, protocol="http"):
  30. if not server_info:
  31. server_list = ServerRegistry.query_server_info_from_db(server_name=server_name)
  32. if not server_list:
  33. raise Exception(f"no found server {server_name}")
  34. server_info = server_list[0]
  35. url = f"{server_info.f_protocol}://{server_info.f_host}:{server_info.f_port}{uri}"
  36. else:
  37. url = f"{server_info.get('protocol', protocol)}://{server_info.get('host')}:{server_info.get('port')}{uri}"
  38. service_info = {
  39. "f_server_name": server_name,
  40. "f_service_name": service_name,
  41. "f_url": url,
  42. "f_method": method,
  43. "f_params": params if params else {},
  44. "f_data": data if data else {},
  45. "f_headers": headers if headers else {}
  46. }
  47. entity_model, status = ServiceRegistryInfo.get_or_create(
  48. f_server_name=server_name,
  49. f_service_name=service_name,
  50. defaults=service_info)
  51. if status is False:
  52. for key in service_info:
  53. setattr(entity_model, key, service_info[key])
  54. entity_model.save(force_insert=False)
  55. class ServerRegistry(ReloadConfigBase):
  56. FATEBOARD = None
  57. FATE_ON_STANDALONE = None
  58. FATE_ON_EGGROLL = None
  59. FATE_ON_SPARK = None
  60. MODEL_STORE_ADDRESS = None
  61. SERVINGS = None
  62. FATEMANAGER = None
  63. STUDIO = None
  64. @classmethod
  65. def load(cls):
  66. cls.load_server_info_from_conf()
  67. cls.load_server_info_from_db()
  68. @classmethod
  69. def load_server_info_from_conf(cls):
  70. path = Path(utils.file_utils.get_project_base_directory()) / 'conf' / utils.SERVICE_CONF
  71. conf = utils.file_utils.load_yaml_conf(path)
  72. if not isinstance(conf, dict):
  73. raise ValueError('invalid config file')
  74. local_path = path.with_name(f'local.{utils.SERVICE_CONF}')
  75. if local_path.exists():
  76. local_conf = utils.file_utils.load_yaml_conf(local_path)
  77. if not isinstance(local_conf, dict):
  78. raise ValueError('invalid local config file')
  79. conf.update(local_conf)
  80. for k, v in conf.items():
  81. if isinstance(v, dict):
  82. setattr(cls, k.upper(), v)
  83. @classmethod
  84. def register(cls, server_name, server_info):
  85. cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol=server_info.get("protocol", "http"))
  86. setattr(cls, server_name, server_info)
  87. @classmethod
  88. def save(cls, service_config):
  89. update_server = {}
  90. for server_name, server_info in service_config.items():
  91. cls.parameter_check(server_info)
  92. api_info = server_info.pop("api", {})
  93. for service_name, info in api_info.items():
  94. ServiceRegistry.save_service_info(server_name, service_name, uri=info.get('uri'), method=info.get('method', 'POST'), server_info=server_info)
  95. cls.save_server_info_to_db(server_name, server_info.get("host"), server_info.get("port"), protocol="http")
  96. setattr(cls, server_name.upper(), server_info)
  97. return update_server
  98. @classmethod
  99. def parameter_check(cls, service_info):
  100. if "host" in service_info and "port" in service_info:
  101. cls.connection_test(service_info.get("host"), service_info.get("port"))
  102. @classmethod
  103. def connection_test(cls, ip, port):
  104. s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  105. result = s.connect_ex((ip, port))
  106. if result != 0:
  107. raise ConnectionRefusedError(f"connection refused: host {ip}, port {port}")
  108. @classmethod
  109. def query(cls, service_name, default=None):
  110. service_info = getattr(cls, service_name, default)
  111. if not service_info:
  112. service_info = utils.get_base_config(service_name, default)
  113. return service_info
  114. @classmethod
  115. @DB.connection_context()
  116. def query_server_info_from_db(cls, server_name=None) -> [ServerRegistryInfo]:
  117. if server_name:
  118. server_list = ServerRegistryInfo.select().where(ServerRegistryInfo.f_server_name==server_name.upper())
  119. else:
  120. server_list = ServerRegistryInfo.select()
  121. return [server for server in server_list]
  122. @classmethod
  123. @DB.connection_context()
  124. def load_server_info_from_db(cls):
  125. for server in cls.query_server_info_from_db():
  126. server_info = {
  127. "host": server.f_host,
  128. "port": server.f_port,
  129. "protocol": server.f_protocol
  130. }
  131. setattr(cls, server.f_server_name.upper(), server_info)
  132. @classmethod
  133. @DB.connection_context()
  134. def save_server_info_to_db(cls, server_name, host, port, protocol="http"):
  135. server_info = {
  136. "f_server_name": server_name,
  137. "f_host": host,
  138. "f_port": port,
  139. "f_protocol": protocol
  140. }
  141. entity_model, status = ServerRegistryInfo.get_or_create(
  142. f_server_name=server_name,
  143. defaults=server_info)
  144. if status is False:
  145. for key in server_info:
  146. setattr(entity_model, key, server_info[key])
  147. entity_model.save(force_insert=False)