You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

split_strategy.h 3.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <cstddef>
  #include <cstdint>
  #include <map>
  #include <set>
  #include <string>
  #include <unordered_map>
  #include <utility>
  #include <vector>
  #include "schema/ops_generated.h"
  #include "base/core_ops.h"
  #include "include/lite_types.h"
  25. #ifndef MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
  26. #define MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_
  27. namespace mindspore {
  28. namespace opt {
  // Name suffix associated with parallel (split) operators.
  constexpr auto PARALLEL_NAME_SUFFIX = "_parallel";
  constexpr auto kParallelPrimitiveIndex = 0;
  // Default split ratio: two entries, both zero.
  const std::vector<int64_t> kSplitDefaultRatio = {0, 0};
  // Device names a user may split across: only "cpu" and "gpu"; NPU is not supported.
  // NOTE(review): kSupportSplitedDevices still lists "npu" - confirm which of the two is authoritative.
  const std::vector<std::string> kSplitDevTypes = {"cpu", "gpu"};
  // Three-level nested container of int64_t values describing split strategies.
  using Strategys = std::vector<std::vector<std::vector<int64_t>>>;
  // Sentinel meaning "no device type".
  constexpr auto kDeviceTypeNone = -1;

  // Axis indices inside a strategy; the strategy format is NHWC-KHWC.
  // kAxisN and kAxisCOut are both 0: N leads the NHWC part, COut (K) leads the KHWC part.
  constexpr int32_t kAxisN = 0;
  constexpr int32_t kAxisCIn = 3;
  constexpr int32_t kAxisCOut = 0;
  constexpr int32_t kAxisH = 1;
  constexpr int32_t kAxisW = 2;

  constexpr auto kDefaultBatch = 1;

  // Positions of N / H / W / C in a four-dimensional shape (NHWC order).
  constexpr auto kShapeN = 0;
  constexpr auto kShapeH = 1;
  constexpr auto kShapeW = 2;
  constexpr auto kShapeC = 3;

  // Positions of the H and W entries in a two-element list.
  // NOTE(review): presumably used for kernel size / stride / dilation - confirm against the users.
  constexpr auto kIndexH = 0;
  constexpr auto kIndexW = 1;

  // Positions of the up / down / left / right entries in a four-element pad list.
  constexpr auto kPadUp = 0;
  constexpr auto kPadDown = 1;
  constexpr auto kPadLeft = 2;
  constexpr auto kPadRight = 3;
  // Selects how an operator is partitioned.
  // The enumerator names follow the axis names used by the kAxis* constants (N, H, CIn, COut).
  // NOTE(review): how each mode maps to a concrete strategy is decided in the implementation,
  // not in this header - confirm there before relying on the numeric values.
  enum SplitMode {
    NoSplit = 0,    // do not split
    SplitN = 1,     // split on N
    SplitH = 2,     // split on H
    SplitCIN = 3,   // split on CIn
    SplitCOUT = 4,  // split on COut
  };
  60. struct SplitStrategy {
  61. Strategys strategys;
  62. std::vector<std::string> dev_types;
  63. size_t dev_num;
  64. SplitMode split_mode_;
  65. };
  66. // this is a map for key: <primitive,is_depth_wise> value: parallel_op_name
  67. const std::map<std::pair<PrimitivePtr, bool>, std::string> kParallelOpNames = {
  68. {{prim::kPrimConv2D, false}, "Conv2D"},
  69. {{prim::kPrimConv2DFusion, false}, "Conv2D"},
  70. {{prim::kPrimConv2D, true}, "DepthwiseConv2D"},
  71. {{prim::kPrimConv2DFusion, true}, "DepthwiseConv2D"}};
  72. const std::map<std::string, lite::DeviceType> kSupportSplitedDevices = {
  73. {"cpu", lite::DeviceType::DT_CPU}, {"gpu", lite::DeviceType::DT_GPU}, {"npu", lite::DeviceType::DT_NPU}};
  74. // this is a map for key: primitive value: schema_primitive_id
  75. const std::unordered_map<PrimitivePtr, std::pair<schema::PrimitiveType, TypeId>> kParallelSchemaId = {
  76. {prim::kPrimConv2D, {schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32}},
  77. {prim::kPrimConv2DFusion, {schema::PrimitiveType_Conv2DFusion, kNumberTypeFloat32}}};
  // Artificial restriction: when the user splits a conv, its total FLOPs are required to satisfy
  //   2 * output_H * output_W * (in_C * kW * kH + 1) * out_C >= 100
  // using the approximations
  //   FLOPs ~= output_H * output_W * (in_C * kW * kH) * out_C
  //   FLOPs ~= (input_h / stride_h) * (input_w / stride_w) * (in_C * kW * kH) * out_C
  // e.g. (12 / 1) * (12 / 1) * (1 * 3 * 3) * 128 / 1024 = 162 kFLOPs
  // NOTE(review): presumably kUserFLOPs is the threshold (the 100 above) and kPerFlops the
  // divisor that turns FLOPs into kFLOPs (the 1024 above) - confirm against ApproximateFLOPs' callers.
  constexpr auto kUserFLOPs = 100;
  constexpr auto kPerFlops = 1024;
  // Returns an approximate FLOPs figure computed from strides, an input shape and a weight shape.
  // NOTE(review): the definition is not in this header; presumably it follows the approximation
  // written next to kUserFLOPs and expects NHWC input / KHWC weight shapes - confirm.
  int64_t ApproximateFLOPs(const std::vector<int64_t> &strides, const std::vector<int64_t> &input_shape,
                           const std::vector<int64_t> &weight_shape);
  87. std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(
  88. const std::vector<int64_t> &parallel_compute_rates, const std::vector<std::string> &parallel_devices,
  89. SplitMode split_mode);
  90. } // namespace opt
  91. } // namespace mindspore
  92. #endif // MINDSPORE_LITE_SRC_PASS_PARALLEL_SPLIT_STRATEGY_H_