You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_converter.py 2.6 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright 2019 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. Function:
  17. Test mindconverter to convert user's PyTorch network script.
  18. Usage:
  19. pytest tests/st/func/mindconverter
  20. """
  21. import difflib
  22. import os
  23. import sys
  24. import pytest
  25. from mindinsight.mindconverter.converter import main
  @pytest.mark.usefixtures('create_output_dir')
  class TestConverter:
      """System tests for MindConverter's PyTorch-to-MindSpore script conversion."""

      @classmethod
      def setup_class(cls):
          """Prepend the test-data directory to ``sys.path`` for this test class."""
          cls.script_dir = os.path.join(os.path.dirname(__file__), 'data')
          sys.path.insert(0, cls.script_dir)

      @classmethod
      def teardown_class(cls):
          """Remove the entry that ``setup_class`` added to ``sys.path``."""
          sys.path.remove(cls.script_dir)

      @staticmethod
      def _matched_ratio(converted_lines, expected_lines):
          """Return the percentage of expected lines that the converter reproduced.

          Args:
              converted_lines (list[str]): Lines of the converted script.
              expected_lines (list[str]): Lines of the reference (expected) script.

          Returns:
              float: ``100 - 100 * missing / len(expected_lines)``, where ``missing``
              is the number of lines that ``difflib.ndiff`` marks as present only
              in ``expected_lines``.

          Raises:
              ValueError: If ``expected_lines`` is empty, because no ratio can be
                  computed (the original code raised ZeroDivisionError here).
          """
          if not expected_lines:
              raise ValueError('Expected script is empty; cannot compute conversion ratio.')
          # ndiff prefixes: '+ ' = only in expected, '- ' = only in converted,
          # '  ' = common, '? ' = intraline hints. Only expected-but-missing
          # lines count against the conversion.
          missing = sum(1 for line in difflib.ndiff(converted_lines, expected_lines)
                        if line.startswith('+ '))
          return 100 - (missing * 100) / len(expected_lines)

      @pytest.mark.level0
      @pytest.mark.platform_arm_ascend_training
      @pytest.mark.platform_x86_gpu_training
      @pytest.mark.platform_x86_ascend_training
      @pytest.mark.platform_x86_cpu
      @pytest.mark.env_single
      def test_convert_lenet(self, output):
          """Convert the PyTorch LeNet script and compare it with the expected MindSpore script.

          Args:
              output: pytest fixture supplying the output directory (presumably
                  defined in a conftest alongside ``create_output_dir`` - confirm).
          """
          script_filename = "lenet_script.py"
          expect_filename = "lenet_converted.py"
          files_config = {
              'root_path': self.script_dir,
              'in_files': [os.path.join(self.script_dir, script_filename)],
              'outfile_dir': output,
              'report_dir': output
          }
          main(files_config)

          converted_path = os.path.join(output, script_filename)
          assert os.path.isfile(converted_path)
          # Explicit encoding so the comparison does not depend on the platform locale.
          with open(converted_path, encoding='utf-8') as converted_f:
              converted_source = converted_f.readlines()
          with open(os.path.join(self.script_dir, expect_filename), encoding='utf-8') as expect_f:
              expect_source = expect_f.readlines()

          converted_ratio = self._matched_ratio(converted_source, expect_source)
          # At least 80% of the expected lines must appear in the converted output.
          assert converted_ratio >= 80, \
              'Only {:.1f}% of the expected lines were produced.'.format(converted_ratio)