You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

activation_test.cc 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <string>
  17. #include <list>
  18. #include <vector>
  19. #include "common/common_test.h"
  20. #include "frontend/parallel/strategy.h"
  21. #include "frontend/parallel/ops_info/activation_info.h"
  22. #include "frontend/parallel/device_manager.h"
namespace mindspore {
namespace parallel {
// Forward declarations; the full definitions come from activation_info.h.
class Activation;
class Softmax;
using ActivationPtr = std::shared_ptr<ActivationInfo>;
using SoftmaxPtr = std::shared_ptr<Softmax>;
// Shared fixtures: created in TestActivation::SetUp() and used by the TEST_F
// bodies below. Namespace-scope globals so the TEST_F macros can reach them.
ActivationPtr act_ptr_;
SoftmaxPtr soft_ptr_;
// Test fixture: SetUp() (defined below) builds the global device manager and
// the ActivationInfo/Softmax operator-info objects before each test case.
class TestActivation : public UT::Common {
 public:
  TestActivation() {}
  void SetUp();
  void TearDown() {}
};
  37. void TestActivation::SetUp() {
  38. RankList dev_list;
  39. for (int32_t i = 0; i < 1050; i++) {
  40. dev_list.push_back(i);
  41. }
  42. RankList stage_map;
  43. stage_map.push_back(1024);
  44. stage_map.push_back(26);
  45. int32_t local_dev = 0;
  46. // create a new g_device_manager
  47. g_device_manager = std::make_shared<DeviceManager>();
  48. g_device_manager->Init(dev_list, local_dev, stage_map, "hccl");
  49. ValuePtr relu = MakeValue(std::string("relu"));
  50. std::unordered_map<std::string, ValuePtr> relu_attr = {{"activation_type", relu}};
  51. ValuePtr sm = MakeValue(std::string("softmax"));
  52. ValuePtr axix = MakeValue(std::int32_t(2));
  53. std::unordered_map<std::string, ValuePtr> softmax_attr = {{"activation_type", sm}, {"axis", axix}};
  54. Shapes relu_inputs_shape = {{2, 4, 8, 16}};
  55. Shapes relu_outputs_shape = {{2, 4, 8, 16}};
  56. Shapes sm_inputs_shape = {{8, 8, 8, 16}};
  57. Shapes sm_outputs_shape = {{8, 8, 8, 16}};
  58. act_ptr_ = std::make_shared<ActivationInfo>("relu_info", relu_inputs_shape, relu_outputs_shape, relu_attr);
  59. soft_ptr_ = std::make_shared<Softmax>("softmax_info", sm_inputs_shape, sm_outputs_shape, softmax_attr);
  60. }
  61. TEST_F(TestActivation, test_activation_strategies) {
  62. ASSERT_EQ(act_ptr_->GenerateStrategies(0), Status::SUCCESS);
  63. std::vector<std::shared_ptr<StrategyWithCost>> sc = act_ptr_->GetStrategyCost();
  64. for (const auto& swc : sc) {
  65. ASSERT_NE(swc, nullptr);
  66. ASSERT_GT(swc->cost_list.size(), 0);
  67. StrategyPtr sp = swc->strategy_ptr;
  68. ASSERT_NE(sp, nullptr);
  69. Cost cost = *(swc->cost_list[0]);
  70. act_ptr_->InitForCostModel(sp);
  71. std::vector<TensorInfo> inputs_info = act_ptr_->inputs_tensor_info();
  72. std::vector<TensorInfo> outputs_info = act_ptr_->outputs_tensor_info();
  73. ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
  74. cost.computation_cost_);
  75. ASSERT_DOUBLE_EQ(act_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
  76. cost.communication_cost_);
  77. }
  78. }
  79. TEST_F(TestActivation, test_softmax_strategies) {
  80. ASSERT_EQ(soft_ptr_->GenerateStrategies(0), Status::SUCCESS);
  81. std::vector<std::shared_ptr<StrategyWithCost>> sc = soft_ptr_->GetStrategyCost();
  82. for (const auto& swc : sc) {
  83. ASSERT_NE(swc, nullptr);
  84. ASSERT_GT(swc->cost_list.size(), 0);
  85. StrategyPtr sp = swc->strategy_ptr;
  86. ASSERT_NE(sp, nullptr);
  87. Cost cost = *(swc->cost_list[0]);
  88. Strategys stra = sp->GetInputDim();
  89. ASSERT_GT(stra.size(), 0);
  90. Dimensions input0_stra = stra[0];
  91. ASSERT_GT(input0_stra.size(), 2);
  92. ASSERT_EQ(input0_stra[2], 1);
  93. soft_ptr_->InitForCostModel(sp);
  94. std::vector<TensorInfo> inputs_info = soft_ptr_->inputs_tensor_info();
  95. std::vector<TensorInfo> outputs_info = soft_ptr_->outputs_tensor_info();
  96. ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetComputationCost(inputs_info, outputs_info, sp->GetInputStage()),
  97. cost.computation_cost_);
  98. ASSERT_DOUBLE_EQ(soft_ptr_->operator_cost()->GetCommCost(inputs_info, outputs_info, sp->GetInputStage()),
  99. cost.communication_cost_);
  100. }
  101. }
  102. } // namespace parallel
  103. } // namespace mindspore