You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

simulator.py 8.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import argparse
  16. import time
  17. import random
  18. import sys
  19. import requests
  20. import flatbuffers
  21. import numpy as np
  22. from mindspore.schema import (RequestFLJob, ResponseFLJob, ResponseCode,
  23. RequestUpdateModel, FeatureMap, RequestGetModel, ResponseGetModel)
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("--pid", type=int, default=0)
  26. parser.add_argument("--http_ip", type=str, default="10.113.216.106")
  27. parser.add_argument("--http_port", type=int, default=6666)
  28. parser.add_argument("--use_elb", type=bool, default=False)
  29. parser.add_argument("--server_num", type=int, default=1)
  30. args, _ = parser.parse_known_args()
  31. pid = args.pid
  32. http_ip = args.http_ip
  33. http_port = args.http_port
  34. use_elb = args.use_elb
  35. server_num = args.server_num
  36. str_fl_id = 'fl_lenet_' + str(pid)
  37. def generate_port():
  38. if not use_elb:
  39. return http_port
  40. port = random.randint(0, 100000) % server_num + http_port
  41. return port
  42. def build_start_fl_job(iteration):
  43. start_fl_job_builder = flatbuffers.Builder(1024)
  44. fl_name = start_fl_job_builder.CreateString('fl_test_job')
  45. fl_id = start_fl_job_builder.CreateString(str_fl_id)
  46. data_size = 32
  47. timestamp = start_fl_job_builder.CreateString('2020/11/16/19/18')
  48. RequestFLJob.RequestFLJobStart(start_fl_job_builder)
  49. RequestFLJob.RequestFLJobAddFlName(start_fl_job_builder, fl_name)
  50. RequestFLJob.RequestFLJobAddFlId(start_fl_job_builder, fl_id)
  51. RequestFLJob.RequestFLJobAddIteration(start_fl_job_builder, iteration)
  52. RequestFLJob.RequestFLJobAddDataSize(start_fl_job_builder, data_size)
  53. RequestFLJob.RequestFLJobAddTimestamp(start_fl_job_builder, timestamp)
  54. fl_job_req = RequestFLJob.RequestFLJobEnd(start_fl_job_builder)
  55. start_fl_job_builder.Finish(fl_job_req)
  56. buf = start_fl_job_builder.Output()
  57. return buf
  58. def build_feature_map(builder, names, lengths):
  59. if len(names) != len(lengths):
  60. return None
  61. feature_maps = []
  62. np_data = []
  63. for j, _ in enumerate(names):
  64. name = names[j]
  65. length = lengths[j]
  66. weight_full_name = builder.CreateString(name)
  67. FeatureMap.FeatureMapStartDataVector(builder, length)
  68. weight = np.random.rand(length) * 32
  69. np_data.append(weight)
  70. for idx in range(length - 1, -1, -1):
  71. builder.PrependFloat32(weight[idx])
  72. data = builder.EndVector(length)
  73. FeatureMap.FeatureMapStart(builder)
  74. FeatureMap.FeatureMapAddData(builder, data)
  75. FeatureMap.FeatureMapAddWeightFullname(builder, weight_full_name)
  76. feature_map = FeatureMap.FeatureMapEnd(builder)
  77. feature_maps.append(feature_map)
  78. return feature_maps, np_data
  79. def build_update_model(iteration):
  80. builder_update_model = flatbuffers.Builder(1)
  81. fl_name = builder_update_model.CreateString('fl_test_job')
  82. fl_id = builder_update_model.CreateString(str_fl_id)
  83. timestamp = builder_update_model.CreateString('2020/11/16/19/18')
  84. feature_maps, np_data = build_feature_map(builder_update_model,
  85. ["conv1.weight", "conv2.weight", "fc1.weight",
  86. "fc2.weight", "fc3.weight", "fc1.bias", "fc2.bias", "fc3.bias"],
  87. [450, 2400, 48000, 10080, 5208, 120, 84, 62])
  88. RequestUpdateModel.RequestUpdateModelStartFeatureMapVector(builder_update_model, 1)
  89. for single_feature_map in feature_maps:
  90. builder_update_model.PrependUOffsetTRelative(single_feature_map)
  91. feature_map = builder_update_model.EndVector(len(feature_maps))
  92. RequestUpdateModel.RequestUpdateModelStart(builder_update_model)
  93. RequestUpdateModel.RequestUpdateModelAddFlName(builder_update_model, fl_name)
  94. RequestUpdateModel.RequestUpdateModelAddFlId(builder_update_model, fl_id)
  95. RequestUpdateModel.RequestUpdateModelAddIteration(builder_update_model, iteration)
  96. RequestUpdateModel.RequestUpdateModelAddFeatureMap(builder_update_model, feature_map)
  97. RequestUpdateModel.RequestUpdateModelAddTimestamp(builder_update_model, timestamp)
  98. req_update_model = RequestUpdateModel.RequestUpdateModelEnd(builder_update_model)
  99. builder_update_model.Finish(req_update_model)
  100. buf = builder_update_model.Output()
  101. return buf, np_data
  102. def build_get_model(iteration):
  103. builder_get_model = flatbuffers.Builder(1)
  104. fl_name = builder_get_model.CreateString('fl_test_job')
  105. timestamp = builder_get_model.CreateString('2020/12/16/19/18')
  106. RequestGetModel.RequestGetModelStart(builder_get_model)
  107. RequestGetModel.RequestGetModelAddFlName(builder_get_model, fl_name)
  108. RequestGetModel.RequestGetModelAddIteration(builder_get_model, iteration)
  109. RequestGetModel.RequestGetModelAddTimestamp(builder_get_model, timestamp)
  110. req_get_model = RequestGetModel.RequestGetModelEnd(builder_get_model)
  111. builder_get_model.Finish(req_get_model)
  112. buf = builder_get_model.Output()
  113. return buf
  114. weight_name_to_idx = {
  115. "conv1.weight": 0,
  116. "conv2.weight": 1,
  117. "fc1.weight": 2,
  118. "fc2.weight": 3,
  119. "fc3.weight": 4,
  120. "fc1.bias": 5,
  121. "fc2.bias": 6,
  122. "fc3.bias": 7
  123. }
  124. session = requests.Session()
  125. current_iteration = 1
  126. url = "http://" + http_ip + ":" + str(generate_port())
  127. np.random.seed(0)
  128. while True:
  129. url1 = "http://" + http_ip + ":" + str(generate_port()) + '/startFLJob'
  130. print("start url is ", url1)
  131. x = requests.post(url1, data=build_start_fl_job(current_iteration))
  132. rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
  133. print("start fl job iteration:", current_iteration, ", id:", args.pid)
  134. while rsp_fl_job.Retcode() != ResponseCode.ResponseCode.SUCCEED:
  135. x = requests.post(url1, data=build_start_fl_job(current_iteration))
  136. rsp_fl_job = rsp_fl_job = ResponseFLJob.ResponseFLJob.GetRootAsResponseFLJob(x.content, 0)
  137. print("epoch is", rsp_fl_job.FlPlanConfig().Epochs())
  138. sys.stdout.flush()
  139. url2 = "http://" + http_ip + ":" + str(generate_port()) + '/updateModel'
  140. print("req update model iteration:", current_iteration, ", id:", args.pid)
  141. update_model_buf, update_model_np_data = build_update_model(current_iteration)
  142. x = session.post(url2, data=update_model_buf)
  143. print("rsp update model iteration:", current_iteration, ", id:", args.pid)
  144. sys.stdout.flush()
  145. url3 = "http://" + http_ip + ":" + str(generate_port()) + '/getModel'
  146. print("req get model iteration:", current_iteration, ", id:", args.pid)
  147. x = session.post(url3, data=build_get_model(current_iteration))
  148. rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
  149. print("rsp get model iteration:", current_iteration, ", id:", args.pid, rsp_get_model.Retcode())
  150. sys.stdout.flush()
  151. repeat_time = 0
  152. while rsp_get_model.Retcode() == ResponseCode.ResponseCode.SucNotReady:
  153. time.sleep(0.1)
  154. x = session.post(url3, data=build_get_model(current_iteration))
  155. rsp_get_model = ResponseGetModel.ResponseGetModel.GetRootAsResponseGetModel(x.content, 0)
  156. repeat_time += 1
  157. if repeat_time > 1000:
  158. print("GetModel try timeout ", args.pid)
  159. sys.exit(0)
  160. for i in range(0, 1):
  161. print(rsp_get_model.FeatureMap(i).WeightFullname())
  162. origin = update_model_np_data[weight_name_to_idx[rsp_get_model.FeatureMap(i).WeightFullname().decode('utf-8')]]
  163. after = rsp_get_model.FeatureMap(i).DataAsNumpy() * 32
  164. print("Before update model", args.pid, origin[0:10])
  165. print("After get model", args.pid, after[0:10])
  166. sys.stdout.flush()
  167. assert np.allclose(origin, after, rtol=1e-05, atol=1e-05)
  168. current_iteration += 1