You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fl_restful_tool.py 12 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
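Function:
Command-line tool for controlling a federated learning cluster through its RESTful interface.
Usage:
python fl_restful_tool.py [--http_type HTTP_TYPE] [--ip IP] [--port PORT] [--request_name REQUEST_NAME]
    [--server_num SERVER_NUM] [--instance_param INSTANCE_PARAM] [--metrics_file_path METRICS_FILE_PATH]
REQUEST_NAME is one of: scale, nodes, getInstanceDetail, newInstance, queryInstance,
enableFLS, disableFLS, state, scaleoutRollback.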
  20. """
  21. import argparse
  22. import json
  23. import os
  24. import warnings
  25. from enum import Enum
  26. import requests
  27. class Status(Enum):
  28. """
  29. Response Status
  30. """
  31. SUCCESS = "0"
  32. FAILED = "1"
  33. class Restful(Enum):
  34. """
  35. Define restful interface constant
  36. """
  37. SCALE = "scale"
  38. SCALE_OUT = "scaleout"
  39. SCALE_IN = "scalein"
  40. NODES = "nodes"
  41. GET_INSTANCE_DETAIL = "getInstanceDetail"
  42. NEW_INSTANCE = "newInstance"
  43. QUERY_INSTANCE = "queryInstance"
  44. ENABLE_FLS = "enableFLS"
  45. DISABLE_FLS = "disableFLS"
  46. STATE = "state"
  47. SCALE_OUT_ROLLBACK = "scaleoutRollback"
  48. warnings.filterwarnings('ignore')
  49. parser = argparse.ArgumentParser()
  50. parser.add_argument("--http_type", type=str, default="http", help="http or https")
  51. parser.add_argument("--ip", type=str, default="127.0.0.1")
  52. parser.add_argument("--port", type=int, default=6666)
  53. parser.add_argument("--request_name", type=str, default="")
  54. parser.add_argument("--server_num", type=int, default=0)
  55. parser.add_argument("--instance_param", type=str, default="")
  56. parser.add_argument("--metrics_file_path", type=str, default="/opt/huawei/mindspore/hybrid_albert/metrics.json")
  57. args, _ = parser.parse_known_args()
  58. http_type = args.http_type
  59. ip = args.ip
  60. port = args.port
  61. request_name = args.request_name
  62. server_num = args.server_num
  63. instance_param = args.instance_param
  64. metrics_file_path = args.metrics_file_path
  65. headers = {'Content-Type': 'application/json'}
  66. session = requests.Session()
  67. base_url = http_type + "://" + ip + ":" + str(port) + "/"
  68. def call_scale():
  69. """
  70. call cluster scale out or scale in
  71. """
  72. if server_num == 0:
  73. return process_self_define_json(Status.FAILED.value, "error. server_num is 0")
  74. node_ids = json.loads(call_nodes())["result"]
  75. cluster_abstract_node_num = len(node_ids)
  76. if cluster_abstract_node_num == 0:
  77. return process_self_define_json(Status.FAILED.value, "error. cluster abstract node num is 0")
  78. cluster_server_node_num = 0
  79. cluster_worker_node_num = 0
  80. cluster_server_node_base_name = ''
  81. for i in range(0, cluster_abstract_node_num):
  82. if node_ids[i]['role'] == 'WORKER':
  83. cluster_worker_node_num = cluster_worker_node_num + 1
  84. elif node_ids[i]['role'] == 'SERVER':
  85. cluster_server_node_num = cluster_server_node_num + 1
  86. cluster_server_node_name = str(node_ids[i]['nodeId'])
  87. index = cluster_server_node_name.rindex('-')
  88. cluster_server_node_base_name = cluster_server_node_name[0:index]
  89. else:
  90. pass
  91. if cluster_server_node_num == server_num:
  92. return process_self_define_json(Status.FAILED.value, "error. cluster server num is same with server_num.")
  93. if cluster_server_node_num > server_num:
  94. scale_in_len = cluster_server_node_num - server_num
  95. scale_in_node_ids = []
  96. for index in range(cluster_server_node_num - scale_in_len, cluster_server_node_num):
  97. scale_in_node_name = cluster_server_node_base_name + "-" + str(index)
  98. scale_in_node_ids.append(scale_in_node_name)
  99. return call_scalein(scale_in_node_ids)
  100. return call_scaleout(server_num - cluster_server_node_num)
  101. def call_scaleout(scale_out_server_num, scale_out_worker_num=0):
  102. """
  103. call scaleout
  104. """
  105. url = base_url + Restful.SCALE_OUT.value
  106. data = {"server_num": scale_out_server_num, "worker_num": scale_out_worker_num}
  107. res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
  108. res_json = json.loads(res.text)
  109. if res_json["code"] == Status.FAILED.value:
  110. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  111. result = "scale out server num is " + str(scale_out_server_num)
  112. return process_result_json(Status.SUCCESS.value, res_json["message"], result)
  113. def call_scaleout_rollback():
  114. """
  115. call scaleout rollback
  116. """
  117. url = base_url + Restful.SCALE_OUT_ROLLBACK.value
  118. res = session.get(url, verify=False)
  119. res_json = json.loads(res.text)
  120. if res_json["code"] == Status.FAILED.value:
  121. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  122. return process_self_define_json(Status.SUCCESS.value, res_json["message"])
  123. def call_scalein(scale_in_node_ids):
  124. """
  125. call cluster to scale in
  126. """
  127. if not scale_in_node_ids:
  128. return process_self_define_json(Status.FAILED.value, "error. node ids is empty.")
  129. url = base_url + Restful.SCALE_IN.value
  130. data = {"node_ids": scale_in_node_ids}
  131. res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
  132. res_json = json.loads(res.text)
  133. if res_json["code"] == Status.FAILED.value:
  134. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  135. result = "scale in node ids is " + str(scale_in_node_ids)
  136. return process_result_json(Status.SUCCESS.value, res_json["message"], result)
  137. def call_nodes():
  138. """
  139. get nodes info
  140. """
  141. url = base_url + Restful.NODES.value
  142. res = session.get(url, verify=False)
  143. res_json = json.loads(res.text)
  144. if res_json["code"] == Status.FAILED.value:
  145. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  146. return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["nodeIds"])
  147. def call_get_instance_detail():
  148. """
  149. get cluster instance detail
  150. """
  151. if not os.path.exists(metrics_file_path):
  152. return process_self_define_json(Status.FAILED.value, "error. metrics file is not existed.")
  153. ans_json_obj = {}
  154. metrics_auc_list = []
  155. metrics_loss_list = []
  156. iteration_execution_time_list = []
  157. client_visited_info_list = []
  158. with open(metrics_file_path, 'r') as f:
  159. metrics_list = f.readlines()
  160. if not metrics_list:
  161. return process_self_define_json(Status.FAILED.value, "error. metrics file has no content")
  162. for metrics in metrics_list:
  163. json_obj = json.loads(metrics)
  164. iteration_execution_time_list.append(json_obj['iterationExecutionTime'])
  165. client_visited_info_list.append(json_obj['clientVisitedInfo'])
  166. metrics_auc_list.append(json_obj['metricsAuc'])
  167. metrics_loss_list.append(json_obj['metricsLoss'])
  168. last_metrics = metrics_list[len(metrics_list) - 1]
  169. last_metrics_obj = json.loads(last_metrics)
  170. ans_json_obj["code"] = Status.SUCCESS.value
  171. ans_json_obj["describe"] = "get instance metrics detail successful."
  172. ans_json_obj["result"] = {}
  173. ans_json_result = ans_json_obj.get("result")
  174. ans_json_result['currentIteration'] = last_metrics_obj['currentIteration']
  175. ans_json_result['flIterationNum'] = last_metrics_obj['flIterationNum']
  176. ans_json_result['flName'] = last_metrics_obj['flName']
  177. ans_json_result['instanceStatus'] = last_metrics_obj['instanceStatus']
  178. ans_json_result['iterationExecutionTime'] = iteration_execution_time_list
  179. ans_json_result['clientVisitedInfo'] = client_visited_info_list
  180. ans_json_result['metricsAuc'] = metrics_auc_list
  181. ans_json_result['metricsLoss'] = metrics_loss_list
  182. return json.dumps(ans_json_obj)
  183. def call_new_instance():
  184. """
  185. call cluster new instance
  186. """
  187. if instance_param == "":
  188. return process_self_define_json(Status.FAILED.value, "error. instance_param is empty.")
  189. instance_param_list = instance_param.split(sep=",")
  190. instance_param_json_obj = {}
  191. url = base_url + Restful.NEW_INSTANCE.value
  192. for cur in instance_param_list:
  193. pair = cur.split(sep="=")
  194. instance_param_json_obj[pair[0]] = float(pair[1])
  195. data = json.dumps(instance_param_json_obj)
  196. res = session.post(url, verify=False, data=data)
  197. res_json = json.loads(res.text)
  198. if res_json["code"] == Status.FAILED.value:
  199. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  200. return process_self_define_json(Status.SUCCESS.value, res_json["message"])
  201. def call_query_instance():
  202. """
  203. query cluster instance
  204. """
  205. url = base_url + Restful.QUERY_INSTANCE.value
  206. res = session.post(url, verify=False)
  207. res_json = json.loads(res.text)
  208. if res_json["code"] == Status.FAILED.value:
  209. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  210. return process_result_json(Status.SUCCESS.value, res_json["message"], res_json["result"])
  211. def call_enable_fls():
  212. """
  213. enable cluster fls
  214. """
  215. url = base_url + Restful.ENABLE_FLS.value
  216. res = session.post(url, verify=False)
  217. res_json = json.loads(res.text)
  218. if res_json["code"] == Status.FAILED.value:
  219. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  220. return process_self_define_json(Status.SUCCESS.value, res_json["message"])
  221. def call_disable_fls():
  222. """
  223. disable cluster fls
  224. """
  225. url = base_url + Restful.DISABLE_FLS.value
  226. res = session.post(url, verify=False)
  227. res_json = json.loads(res.text)
  228. if res_json["code"] == Status.FAILED.value:
  229. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  230. return process_self_define_json(Status.SUCCESS.value, res_json["message"])
  231. def call_state():
  232. """
  233. get cluster state
  234. """
  235. url = base_url + Restful.STATE.value
  236. res = session.get(url, verify=False)
  237. res_json = json.loads(res.text)
  238. if res_json["code"] == Status.FAILED.value:
  239. return process_self_define_json(Status.FAILED.value, res_json["error_message"])
  240. result = res_json['cluster_state']
  241. return process_result_json(Status.SUCCESS.value, res_json["message"], result)
  242. def process_result_json(code, describe, result):
  243. """
  244. process result json
  245. """
  246. result_dict = {"code": code, "describe": describe, "result": result}
  247. return json.dumps(result_dict)
  248. def process_self_define_json(code, describe):
  249. """
  250. process self define json
  251. """
  252. result_dict = {"code": code, "describe": describe}
  253. return json.dumps(result_dict)
  254. if __name__ == '__main__':
  255. if request_name == Restful.SCALE.value:
  256. print(call_scale())
  257. elif request_name == Restful.NODES.value:
  258. print(call_nodes())
  259. elif request_name == Restful.GET_INSTANCE_DETAIL.value:
  260. print(call_get_instance_detail())
  261. elif request_name == Restful.NEW_INSTANCE.value:
  262. print(call_new_instance())
  263. elif request_name == Restful.QUERY_INSTANCE.value:
  264. print(call_query_instance())
  265. elif request_name == Restful.ENABLE_FLS.value:
  266. print(call_enable_fls())
  267. elif request_name == Restful.DISABLE_FLS.value:
  268. print(call_disable_fls())
  269. elif request_name == Restful.STATE.value:
  270. print(call_state())
  271. elif request_name == Restful.SCALE_OUT_ROLLBACK.value:
  272. print(call_scaleout_rollback())
  273. else:
  274. print(process_self_define_json(1, "error. request_name is not found!"))