You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

split_strategy.cc 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "tools/optimizer/parallel/split_strategy.h"
  17. #include <vector>
  18. #include <unordered_map>
  19. #include <string>
  20. namespace mindspore {
  21. namespace opt {
  22. std::unordered_map<std::string, opt::SplitStrategy> ParserSplitStrategy(SplitMode parallel_mode) {
  23. std::unordered_map<std::string, opt::SplitStrategy> split_strategys;
  24. if (kSplitRatio.empty() || kSplitDefaultRatio.empty() || kSplitDevTypes.empty()) {
  25. return split_strategys;
  26. }
  27. if (kSplitRatio.size() != kSplitDevTypes.size()) {
  28. return split_strategys;
  29. }
  30. std::vector<std::vector<int64_t>> split_feature_map;
  31. std::vector<std::vector<int64_t>> split_weight;
  32. switch (parallel_mode) {
  33. case SplitN:
  34. split_feature_map = {kSplitRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
  35. split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
  36. break;
  37. case SplitH:
  38. split_feature_map = {kSplitDefaultRatio, kSplitRatio, kSplitDefaultRatio, kSplitDefaultRatio};
  39. split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
  40. break;
  41. case SplitCIN:
  42. split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitRatio};
  43. split_weight = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitRatio};
  44. break;
  45. case SplitCOUT:
  46. split_feature_map = {kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
  47. split_weight = {kSplitRatio, kSplitDefaultRatio, kSplitDefaultRatio, kSplitDefaultRatio};
  48. break;
  49. default:
  50. return split_strategys;
  51. }
  52. opt::Strategys strategys = {split_feature_map, split_weight};
  53. split_strategys[opt::kSplitOp] = {strategys, kSplitDevTypes, kSplitDevTypes.size()};
  54. return split_strategys;
  55. }
  56. } // namespace opt
  57. } // namespace mindspore