You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

process_data.py 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright 2021 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. import os
  16. import sys
  17. import json
  18. import matplotlib.ticker as ticker
  19. import matplotlib.pyplot as plt
  20. import openpyxl as opx
  21. def parse_arguments():
  22. log_path = sys.argv[1]
  23. log_data = sys.argv[2]
  24. me_report = sys.argv[3]
  25. n_days = sys.argv[4]
  26. assert n_days.isdigit()
  27. return log_path, log_data, me_report, int(n_days)
  28. def read_data(log_data, me_report_path, n_days):
  29. with open(log_data) as f:
  30. log = json.load(f)
  31. wb = opx.load_workbook(me_report_path)
  32. sheet = wb["Sheet"]
  33. n_row = sheet.max_row
  34. date = [cell[0].value for cell in sheet["A2":"A%d" % n_row]]
  35. reid_data = [float(cell[0].value) for cell in sheet["B2":"B%d" % n_row]]
  36. bert_data = [float(cell[0].value) for cell in sheet["C2":"C%d" % n_row]]
  37. resnet_data = [float(cell[0].value) for cell in sheet["D2":"D%d" % n_row]]
  38. gpt_data = [float(cell[0].value) for cell in sheet["E43":"E%d" % n_row]]
  39. if n_days > 0:
  40. date = date[-n_days:]
  41. reid_data = reid_data[-n_days:]
  42. bert_data = bert_data[-n_days:]
  43. resnet_data = resnet_data[-n_days:]
  44. gpt_data = gpt_data[-n_days:]
  45. return log, date, reid_data, bert_data, resnet_data, gpt_data
  46. def draw_figure(x_data, y_data, labels, title, out, height=24, width=8, tick_space=2):
  47. print("Generating figure to: %s" % out)
  48. plt.figure(figsize=(height, width))
  49. for y, label in zip(y_data, labels):
  50. x = x_data[-len(y):]
  51. n_data = len(x)
  52. assert len(x) == len(
  53. y), "assume len(x) == len(y), while %d != %d" % (len(x), len(y))
  54. plt.plot(x, y, linewidth=2, marker='o', markersize=5, label=label)
  55. ax = plt.gca()
  56. ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_space))
  57. for i in range(n_data):
  58. if i % 2 == 0:
  59. plt.text(x[i], y[i], y[i], ha='center',
  60. va='bottom', fontsize=8)
  61. plt.title(title)
  62. plt.xlabel("Date")
  63. plt.ylabel("Time(s)")
  64. plt.grid()
  65. plt.legend()
  66. plt.savefig(out)
  67. def generate_report(log, labels, log_path):
  68. for label in labels:
  69. fname = log[label]["min_file"]
  70. fname_path = os.path.join(log_path, fname)
  71. out_path = os.path.join(log_path, "reports", label+"_me.log")
  72. print("Generating report to: %s" % out_path)
  73. os.system("grep -A 230 'TotalTime = ' %s > %s" %
  74. (fname_path, out_path))
  75. def process_data():
  76. log_path, log_data, me_report, n_days = parse_arguments()
  77. log, date, reid_data, bert_data, resnet_data, gpt_data = read_data(
  78. log_data, me_report, n_days)
  79. draw_figure(date,
  80. [reid_data, bert_data, gpt_data],
  81. ["ReID", "BERT", "GPT"],
  82. "ReID&BERT&GPT",
  83. os.path.join(log_path, "reports", "reid_bert_gpt.png")
  84. )
  85. draw_figure(date, [resnet_data], ["ResNet"], "ResNet",
  86. os.path.join(log_path, "reports", "resnet.png"))
  87. generate_report(log, list(log.keys()), log_path)
  88. if __name__ == "__main__":
  89. process_data()