You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

simulator.py 12 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import argparse
  16. import time
  17. import datetime
  18. import random
  19. import sys
  20. import requests
  21. import flatbuffers
  22. import numpy as np
  23. from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
  24. RequestUpdateModel, ResponseUpdateModel,
  25. FeatureMap, RequestGetModel, ResponseGetModel)
  26. parser = argparse.ArgumentParser()
  27. parser.add_argument("--pid", type=int, default=0)
  28. parser.add_argument("--http_ip", type=str, default="10.113.216.106")
  29. parser.add_argument("--http_port", type=int, default=6666)
  30. parser.add_argument("--use_elb", type=bool, default=False)
  31. parser.add_argument("--server_num", type=int, default=1)
  32. args, _ = parser.parse_known_args()
  33. pid = args.pid
  34. http_ip = args.http_ip
  35. http_port = args.http_port
  36. use_elb = args.use_elb
  37. server_num = args.server_num
  38. str_fl_id = 'fl_lenet_' + str(pid)
  39. server_not_available_rsp = ["The cluster is in safemode.",
  40. "The server's training job is disabled or finished."]
  41. def generate_port():
  42. if not use_elb:
  43. return http_port
  44. port = random.randint(0, 100000) % server_num + http_port
  45. return port
  46. def build_start_fl_job():
  47. start_fl_job_builder = flatbuffers.Builder(1024)
  48. fl_name = start_fl_job_builder.CreateString('fl_test_job')
  49. fl_id = start_fl_job_builder.CreateString(str_fl_id)
  50. data_size = 32
  51. timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18')
  52. RequestFLJob.RequestFLJobStart(start_fl_job_builder)
  53. RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name)
  54. RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id)
  55. RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size)
  56. RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp)
  57. fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder)
  58. start_fl_job_builder.Finish(fl_job_req)
  59. buf = start_fl_job_builder.Output()
  60. return buf
  61. def build_feature_map(builder, names, lengths):
  62. if len(names) != len(lengths):
  63. return None
  64. feature_maps = []
  65. np_data = []
  66. for j, _ in enumerate(names):
  67. name = names[j]
  68. length = lengths[j]
  69. weight_full_name = builder.CreateString(name)
  70. FeatureMap.FeatureMapStartDataVector(builder, length)
  71. weight = np.random.rand(length) * 32
  72. np_data.append(weight)
  73. for idx in range(length - 1, -1, -1):
  74. builder.PrependFloat32(weight[idx])
  75. data = builder.EndVector(length)
  76. FeatureMap.FeatureMapStart(builder)
  77. FeatureMap.FeatureMapAddData(builder, data)
  78. FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
  79. feature_map = FeatureMap.FeatureMapEnd(builder)
  80. feature_maps.append(feature_map)
  81. return feature_maps, np_data
  82. def build_update_model(iteration):
  83. builder_update_model = flatbuffers.Builder(1)
  84. fl_name = builder_update_model.CreateString('fl_test_job')
  85. fl_id = builder_update_model.CreateString(str_fl_id)
  86. timestamp = builder_update_model.CreateString('2020/11/16/19/18')
  87. feature_maps, np_data = build_feature_map(builder_update_model,
  88. ["conv1.weight", "conv2.weight", "fc1.weight",
  89. "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
  90. [450, 2400, 48000, 10080, 5208, 120, 84, 62])
  91. RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, 1)
  92. for single_feature_map in feature_maps:
  93. builder_update_model.PrependUOffsetTRelative(single_feature_map)
  94. feature_map = builder_update_model.EndVector(len(feature_maps))
  95. RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
  96. RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
  97. RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
  98. RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
  99. RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
  100. RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
  101. req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
  102. builder_update_model.Finish(req_update_model)
  103. buf = builder_update_model.Output()
  104. return buf, np_data
  105. def build_get_model(iteration):
  106. builder_get_model = flatbuffers.Builder(1)
  107. fl_name = builder_get_model.CreateString('fl_test_job')
  108. timestamp = builder_get_model.CreateString('2020/12/16/19/18')
  109. RequestGetModel.RequestGetModelStart(builder_get_model)
  110. RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name)
  111. RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration)
  112. RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp)
  113. req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model)
  114. builder_get_model.Finish(req_get_model)
  115. buf = builder_get_model.Output()
  116. return buf
  117. def datetime_to_timestamp(datetime_obj):
  118. local_timestamp = time.mktime(datetime_obj.timetuple()) * 1000.0 + datetime_obj.microsecond // 1000.0
  119. return local_timestamp
  120. weight_to_idx = {
  121. "conv1.weight": 0,
  122. "conv2.weight": 1,
  123. "fc1.weight": 2,
  124. "fc2.weight": 3,
  125. "fc3.weight": 4,
  126. "fc1.bias": 5,
  127. "fc2.bias": 6,
  128. "fc3.bias": 7
  129. }
  130. session = requests.Session()
  131. current_iteration = 1
  132. np.random.seed(0)
  133. def start_fl_job():
  134. start_fl_job_result = {}
  135. iteration = 0
  136. url = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
  137. print("Start fl job url is ", url)
  138. x = session.post(url, data=build_start_fl_job())
  139. if x.text in server_not_available_rsp:
  140. start_fl_job_result['reason'] = "Restart iteration."
  141. start_fl_job_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
  142. print("Start fl job when safemode.")
  143. return start_fl_job_result, iteration
  144. rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
  145. iteration = rsp_fl_job.Iteration()
  146. if rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
  147. if rsp_fl_job.Retcode() == ResponseCode.ResponseCode.OutOfTime:
  148. start_fl_job_result['reason'] = "Restart iteration."
  149. start_fl_job_result['next_ts'] = int(rsp_fl_job.NextReqTime().decode('utf-8'))
  150. print("Start fl job out of time. Next request at ",
  151. start_fl_job_result['next_ts'], "reason:", rsp_fl_job.Reason())
  152. else:
  153. print("Start fl job failed, return code is ", rsp_fl_job.Retcode())
  154. sys.exit()
  155. else:
  156. start_fl_job_result['reason'] = "Success"
  157. start_fl_job_result['next_ts'] = 0
  158. return start_fl_job_result, iteration
  159. def update_model(iteration):
  160. update_model_result = {}
  161. url = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
  162. print("Update model url:", url, ", iteration:", iteration)
  163. update_model_buf, update_model_np_data = build_update_model(iteration)
  164. x = session.post(url, data=update_model_buf)
  165. if x.text in server_not_available_rsp:
  166. update_model_result['reason'] = "Restart iteration."
  167. update_model_result['next_ts'] = datetime_to_timestamp(datetime.datetime.now()) + 500
  168. print("Update model when safemode.")
  169. return update_model_result, update_model_np_data
  170. rsp_update_model = ResponseUpdateModel.ResponseUpdateModel.GetRootAsResponseUpdateModel(x.content, 0)
  171. if rsp_update_model.Retcode() != ResponseCode.ResponseCode.SUCCEED:
  172. if rsp_update_model.Retcode() == ResponseCode.ResponseCode.OutOfTime:
  173. update_model_result['reason'] = "Restart iteration."
  174. update_model_result['next_ts'] = int(rsp_update_model.NextReqTime().decode('utf-8'))
  175. print("Update model out of time. Next request at ",
  176. update_model_result['next_ts'], "reason:", rsp_update_model.Reason())
  177. else:
  178. print("Update model failed, return code is ", rsp_update_model.Retcode())
  179. sys.exit()
  180. else:
  181. update_model_result['reason'] = "Success"
  182. update_model_result['next_ts'] = 0
  183. return update_model_result, update_model_np_data
  184. def get_model(iteration, update_model_data):
  185. get_model_result = {}
  186. url = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
  187. print("Get model url:", url, ", iteration:", iteration)
  188. while True:
  189. x = session.post(url, data=build_get_model(iteration))
  190. if x.text in server_not_available_rsp:
  191. print("Get model when safemode.")
  192. time.sleep(0.5)
  193. continue
  194. rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
  195. ret_code = rsp_get_model.Retcode()
  196. if ret_code == ResponseCode.ResponseCode.SUCCEED:
  197. break
  198. elif ret_code == ResponseCode.ResponseCode.SucNotReady:
  199. time.sleep(0.5)
  200. continue
  201. else:
  202. print("Get model failed, return code is ", rsp_get_model.Retcode())
  203. sys.exit()
  204. for i in range(0, 1):
  205. print(rsp_get_model.FeatureMap(i).WeightFullname())
  206. origin = update_model_data[weight_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
  207. after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
  208. print("Before update model", args.pid, origin[0:10])
  209. print("After get model", args.pid, after[0:10])
  210. sys.stdout.flush()
  211. get_model_result['reason'] = "Success"
  212. get_model_result['next_ts'] = 0
  213. return get_model_result
  214. while True:
  215. result, current_iteration = start_fl_job()
  216. sys.stdout.flush()
  217. if result['reason'] == "Restart iteration.":
  218. current_ts = datetime_to_timestamp(datetime.datetime.now())
  219. duration = result['next_ts'] - current_ts
  220. if duration >= 0:
  221. time.sleep(duration / 1000)
  222. continue
  223. result, update_data = update_model(current_iteration)
  224. sys.stdout.flush()
  225. if result['reason'] == "Restart iteration.":
  226. current_ts = datetime_to_timestamp(datetime.datetime.now())
  227. duration = result['next_ts'] - current_ts
  228. if duration >= 0:
  229. time.sleep(duration / 1000)
  230. continue
  231. result = get_model(current_iteration, update_data)
  232. sys.stdout.flush()
  233. if result['reason'] == "Restart iteration.":
  234. current_ts = datetime_to_timestamp(datetime.datetime.now())
  235. duration = result['next_ts'] - current_ts
  236. if duration >= 0:
  237. time.sleep(duration / 1000)
  238. continue
  239. print("")
  240. sys.stdout.flush()