You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.py 1.7 kB

5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import logging
  16. logging.getLogger().setLevel(logging.INFO)
  17. def completeJson(dict_json_origin, inputs, attr):
  18. dict_json = dict_json_origin
  19. if len(inputs) == 1:
  20. dict_json["input_desc"] = inputs
  21. else:
  22. list_list = []
  23. for i in range(len(inputs)):
  24. list_list.append([inputs[i]])
  25. dict_json["input_desc"] = list_list
  26. dict_json["attr"] = attr
  27. dict_json["op"] = dict_json["name"] + "_test"
  28. input_index = 0
  29. for input_desc_elem in dict_json["input_desc"]:
  30. dict_json["input_desc"][input_index][0]["tensor_name"] = "input_" + dict_json["input_desc"][input_index][0][
  31. "name"]
  32. input_index += 1
  33. return dict_json
  34. def cce_run(run_func, *args, **kwargs):
  35. module = args[-1]['mod']
  36. if module == None:
  37. logging.info("build module failed")
  38. sys.exit(1)
  39. else:
  40. input, output, expect, runres = run_func(*args, **kwargs)
  41. if not runres:
  42. logging.info("test result wrong")
  43. sys.exit(1)
  44. else:
  45. logging.info("result right")