You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

compare_profiling.py 1.5 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. def read_file(manual, auto):
  2. manual_prof = dict()
  3. auto_prof = dict()
  4. cur_op = ""
  5. def _process_line(line, is_manual):
  6. nonlocal cur_op
  7. prof = manual_prof if is_manual else auto_prof
  8. op_key = "Operater:"
  9. time_key = "gpu(0): exec="
  10. if op_key in line:
  11. cur_op = line.split(op_key)[1].replace("\n", "").strip(" ")
  12. if time_key in line:
  13. cur_time = line.split(time_key)[1].replace(
  14. "\n", "").replace("sec/op", "").strip(" ")
  15. if cur_op not in prof:
  16. prof[cur_op] = [cur_time]
  17. else:
  18. prof[cur_op].append(cur_time)
  19. with open(manual, "r") as f:
  20. for line in f:
  21. _process_line(line, True)
  22. with open(auto, "r") as f:
  23. for line in f:
  24. _process_line(line, False)
  25. compare(manual_prof, auto_prof)
  26. def compare(manual_prof, auto_prof):
  27. for k, m_times in manual_prof.items():
  28. a_times = auto_prof.get(k)
  29. if not a_times:
  30. print("Time for {} is not found in auto schedule".format(k))
  31. continue
  32. print("operator: {}".format(k))
  33. for m, a in zip(m_times, a_times):
  34. print("manual {} vs auto {}, manual / auto = {:.2f}%".format(m, a, float(m) / float(a) * 100))
  35. print("")
  36. if __name__ == '__main__':
  37. import sys
  38. if len(sys.argv) != 3:
  39. print("Usage: python compare_profiling.py manual_log_name auto_log_name")
  40. sys.exit()
  41. read_file(sys.argv[1], sys.argv[2])