You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fl_restful_tool.py 11 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
  1. #!/usr/bin/env python3
  2. # coding=UTF-8
  3. # Copyright 2020 Huawei Technologies Co., Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. # ============================================================================
  17. """
  18. Function:
  19. Used to control the federated learning cluster.
  20. Usage:
  21. python fl_restful_tool.py [--http_type http|https] [--ip IP] [--port PORT] [--request_name NAME] [--server_num N] [--instance_param "k1=v1,k2=v2"] [--metrics_file_path PATH]
  22. """
  23. import argparse
  24. import json
  25. import os
  26. import warnings
  27. import requests
  28. class Status:
  29. success = "0"
  30. failed = "1"
  31. class Restful:
  32. """
  33. Define restful interface constant
  34. """
  35. scale = "scale"
  36. scaleout = "scaleout"
  37. scalein = "scalein"
  38. nodes = "nodes"
  39. getInstanceDetail = "getInstanceDetail"
  40. newInstance = "newInstance"
  41. queryInstance = "queryInstance"
  42. enableFLS = "enableFLS"
  43. disableFLS = "disableFLS"
  44. state = "state"
  45. scaleoutRollback = "scaleoutRollback"
  46. warnings.filterwarnings('ignore')
  47. parser = argparse.ArgumentParser()
  48. parser.add_argument("--http_type", type=str, default="http", help="http or https")
  49. parser.add_argument("--ip", type=str, default="127.0.0.1")
  50. parser.add_argument("--port", type=int, default=6666)
  51. # scaleout scalein nodes
  52. parser.add_argument("--request_name", type=str, default="")
  53. parser.add_argument("--server_num", type=int, default=0)
  54. # "start_fl_job_threshold=20,start_fl_job_time_window=2000..."
  55. parser.add_argument("--instance_param", type=str, default="")
  56. parser.add_argument("--metrics_file_path", type=str, default="/opt/huawei/mindspore/hybrid_albert/metrics.json")
  57. args, _ = parser.parse_known_args()
  58. http_type = args.http_type
  59. ip = args.ip
  60. port = args.port
  61. request_name = args.request_name
  62. server_num = args.server_num
  63. instance_param = args.instance_param
  64. metrics_file_path = args.metrics_file_path
  65. headers = {'Content-Type': 'application/json'}
  66. session = requests.Session()
  67. base_url = http_type + "://" + ip + ":" + str(port) + "/"
  68. def call_scale():
  69. """
  70. call cluster scale out or scale in
  71. """
  72. if server_num == 0:
  73. return process_self_define_json(Status.failed, "error. server_num is 0")
  74. node_ids = json.loads(call_nodes())["result"]
  75. cluster_abstract_node_num = len(node_ids)
  76. if cluster_abstract_node_num == 0:
  77. return process_self_define_json(Status.failed, "error. cluster abstract node num is 0")
  78. cluster_server_node_num = 0
  79. cluster_worker_node_num = 0
  80. cluster_server_node_base_name = ''
  81. for i in range(0, cluster_abstract_node_num):
  82. if node_ids[i]['role'] == 'WORKER':
  83. cluster_worker_node_num = cluster_worker_node_num + 1
  84. elif node_ids[i]['role'] == 'SERVER':
  85. cluster_server_node_num = cluster_server_node_num + 1
  86. cluster_server_node_name = str(node_ids[i]['nodeId'])
  87. index = cluster_server_node_name.rindex('-')
  88. cluster_server_node_base_name = cluster_server_node_name[0:index]
  89. else:
  90. pass
  91. if cluster_server_node_num == server_num:
  92. return process_self_define_json(Status.failed, "error. cluster server num is same with server_num.")
  93. if cluster_server_node_num > server_num:
  94. scale_in_len = cluster_server_node_num - server_num
  95. scale_in_node_ids = []
  96. for index in range(cluster_server_node_num - scale_in_len, cluster_server_node_num):
  97. scale_in_node_name = cluster_server_node_base_name + "-" + str(index)
  98. scale_in_node_ids.append(scale_in_node_name)
  99. return call_scalein(scale_in_node_ids)
  100. return call_scaleout(server_num - cluster_server_node_num)
  101. def call_scaleout(scale_out_server_num, scale_out_worker_num=0):
  102. url = base_url + "scaleout"
  103. data = {"server_num": scale_out_server_num, "worker_num": scale_out_worker_num}
  104. res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
  105. res_json = json.loads(res.text)
  106. if res_json["code"] == Status.failed:
  107. return process_self_define_json(Status.failed, res_json["error_message"])
  108. result = "scale out server num is " + str(scale_out_server_num)
  109. return process_result_json(Status.success, res_json["message"], result)
  110. def call_scaleout_rollback():
  111. url = base_url + Restful.scaleoutRollback
  112. res = session.get(url, verify=False)
  113. res_json = json.loads(res.text)
  114. if res_json["code"] == Status.failed:
  115. return process_self_define_json(Status.failed, res_json["error_message"])
  116. return process_self_define_json(Status.success, res_json["message"])
  117. def call_scalein(scale_in_node_ids):
  118. """
  119. call cluster to scale in
  120. """
  121. if not scale_in_node_ids:
  122. return process_self_define_json(Status.failed, "error. node ids is empty.")
  123. url = base_url + "scalein"
  124. data = {"node_ids": scale_in_node_ids}
  125. res = session.post(url, headers=headers, verify=False, data=json.dumps(data))
  126. res_json = json.loads(res.text)
  127. if res_json["code"] == Status.failed:
  128. return process_self_define_json(Status.failed, res_json["error_message"])
  129. result = "scale in node ids is " + str(scale_in_node_ids)
  130. return process_result_json(Status.success, res_json["message"], result)
  131. def call_nodes():
  132. url = base_url + Restful.nodes
  133. res = session.get(url, verify=False)
  134. res_json = json.loads(res.text)
  135. if res_json["code"] == Status.failed:
  136. return process_self_define_json(Status.failed, res_json["error_message"])
  137. return process_result_json(Status.success, res_json["message"], res_json["nodeIds"])
  138. def call_get_instance_detail():
  139. """
  140. call cluster get instance detail
  141. """
  142. if not os.path.exists(metrics_file_path):
  143. return process_self_define_json(Status.failed, "error. metrics file is not existed.")
  144. ans_json_obj = {}
  145. joined_client_num_list = []
  146. rejected_client_num_list = []
  147. metrics_auc_list = []
  148. metrics_loss_list = []
  149. iteration_execution_time_list = []
  150. with open(metrics_file_path, 'r') as f:
  151. metrics_list = f.readlines()
  152. if not metrics_list:
  153. return process_self_define_json(Status.failed, "error. metrics file has no content")
  154. for metrics in metrics_list:
  155. json_obj = json.loads(metrics)
  156. iteration_execution_time_list.append(json_obj['iterationExecutionTime'])
  157. joined_client_num_list.append(json_obj['joinedClientNum'])
  158. rejected_client_num_list.append(json_obj['rejectedClientNum'])
  159. metrics_auc_list.append(json_obj['metricsAuc'])
  160. metrics_loss_list.append(json_obj['metricsLoss'])
  161. last_metrics = metrics_list[len(metrics_list) - 1]
  162. last_metrics_obj = json.loads(last_metrics)
  163. ans_json_obj["code"] = Status.success
  164. ans_json_obj["describe"] = "get instance metrics detail successful."
  165. ans_json_obj["result"] = {}
  166. ans_json_obj["result"]['currentIteration'] = last_metrics_obj['currentIteration']
  167. ans_json_obj["result"]['flIterationNum'] = last_metrics_obj['flIterationNum']
  168. ans_json_obj["result"]['flName'] = last_metrics_obj['flName']
  169. ans_json_obj["result"]['instanceStatus'] = last_metrics_obj['instanceStatus']
  170. ans_json_obj["result"]['iterationExecutionTime'] = iteration_execution_time_list
  171. ans_json_obj["result"]['joinedClientNum'] = joined_client_num_list
  172. ans_json_obj["result"]['rejectedClientNum'] = rejected_client_num_list
  173. ans_json_obj["result"]['metricsAuc'] = metrics_auc_list
  174. ans_json_obj["result"]['metricsLoss'] = metrics_loss_list
  175. return json.dumps(ans_json_obj)
  176. def call_new_instance():
  177. """
  178. call cluster new instance
  179. """
  180. if instance_param == "":
  181. return process_self_define_json(Status.failed, "error. instance_param is empty.")
  182. instance_param_list = instance_param.split(sep=",")
  183. instance_param_json_obj = {}
  184. url = base_url + Restful.newInstance
  185. for cur in instance_param_list:
  186. pair = cur.split(sep="=")
  187. instance_param_json_obj[pair[0]] = float(pair[1])
  188. data = json.dumps(instance_param_json_obj)
  189. res = session.post(url, verify=False, data=data)
  190. res_json = json.loads(res.text)
  191. if res_json["code"] == Status.failed:
  192. return process_self_define_json(Status.failed, res_json["error_message"])
  193. return process_self_define_json(Status.success, res_json["message"])
  194. def call_query_instance():
  195. url = base_url + Restful.queryInstance
  196. res = session.post(url, verify=False)
  197. res_json = json.loads(res.text)
  198. if res_json["code"] == Status.failed:
  199. return process_self_define_json(Status.failed, res_json["error_message"])
  200. return process_result_json(Status.success, res_json["message"], res_json["result"])
  201. def call_enable_fls():
  202. url = base_url + Restful.enableFLS
  203. res = session.post(url, verify=False)
  204. res_json = json.loads(res.text)
  205. if res_json["code"] == Status.failed:
  206. return process_self_define_json(Status.failed, res_json["error_message"])
  207. return process_self_define_json(Status.success, res_json["message"])
  208. def call_disable_fls():
  209. url = base_url + Restful.disableFLS
  210. res = session.post(url, verify=False)
  211. res_json = json.loads(res.text)
  212. if res_json["code"] == Status.failed:
  213. return process_self_define_json(Status.failed, res_json["error_message"])
  214. return process_self_define_json(Status.success, res_json["message"])
  215. def call_state():
  216. url = base_url + Restful.state
  217. res = session.get(url, verify=False)
  218. res_json = json.loads(res.text)
  219. if res_json["code"] == Status.failed:
  220. return process_self_define_json(Status.failed, res_json["error_message"])
  221. result = res_json['cluster_state']
  222. return process_result_json(Status.success, res_json["message"], result)
  223. def process_result_json(code, describe, result):
  224. result_dict = {"code": code, "describe": describe, "result": result}
  225. return json.dumps(result_dict)
  226. def process_self_define_json(code, describe):
  227. result_dict = {"code": code, "describe": describe}
  228. return json.dumps(result_dict)
  229. if __name__ == '__main__':
  230. if request_name == Restful.scale:
  231. print(call_scale())
  232. elif request_name == Restful.nodes:
  233. print(call_nodes())
  234. elif request_name == Restful.getInstanceDetail:
  235. print(call_get_instance_detail())
  236. elif request_name == Restful.newInstance:
  237. print(call_new_instance())
  238. elif request_name == Restful.queryInstance:
  239. print(call_query_instance())
  240. elif request_name == Restful.enableFLS:
  241. print(call_enable_fls())
  242. elif request_name == Restful.disableFLS:
  243. print(call_disable_fls())
  244. elif request_name == Restful.state:
  245. print(call_state())
  246. elif request_name == Restful.scaleoutRollback:
  247. print(call_scaleout_rollback())
  248. else:
  249. print(process_self_define_json(1, "error. request_name is not found!"))