You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

train.cc 3.7 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "coder/train.h"
  17. #include <memory>
  18. #include <set>
  19. #include <array>
  20. #include <queue>
  21. #include <string>
  22. #include <vector>
  23. #include <algorithm>
  24. #include "schema/ops_generated.h"
  25. #include "src/common/prim_util.h"
  26. namespace mindspore::lite::micro {
  27. std::set<OperatorCoder *> FindInferenceOpcoders(OperatorCoder *edge) {
  28. std::set<OperatorCoder *> subgraph;
  29. std::queue<OperatorCoder *> to_visit;
  30. to_visit.push(edge);
  31. while (!to_visit.empty()) {
  32. size_t size = to_visit.size();
  33. for (size_t i = 0; i < size; ++i) {
  34. OperatorCoder *curr = to_visit.front();
  35. to_visit.pop();
  36. if (subgraph.find(curr) != subgraph.end()) {
  37. continue;
  38. }
  39. subgraph.insert(curr);
  40. for (const auto &op : curr->input_ops()) {
  41. to_visit.push(op);
  42. }
  43. }
  44. }
  45. auto item = subgraph.find(edge);
  46. if (item == subgraph.end()) {
  47. MS_LOG(ERROR) << "failed to find the edge in the subgraph";
  48. return subgraph;
  49. }
  50. // erase edge operator coder from subgraph
  51. subgraph.erase(item);
  52. return subgraph;
  53. }
  54. int Train::TransformGraphForTrain(CoderContext *context, const std::vector<std::unique_ptr<OperatorCoder>> &op_coders,
  55. int schema_version) {
  56. if (context == nullptr) {
  57. MS_LOG(INFO) << "input context invalid";
  58. return RET_ERROR;
  59. }
  60. const std::array<int, 6> loss_types = {schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits,
  61. schema::PrimitiveType_BinaryCrossEntropy,
  62. schema::PrimitiveType_SmoothL1Loss,
  63. schema::PrimitiveType_SmoothL1LossGrad,
  64. schema::PrimitiveType_SigmoidCrossEntropyWithLogits,
  65. schema::PrimitiveType_SigmoidCrossEntropyWithLogitsGrad};
  66. OperatorCoder *loss_op = nullptr;
  67. for (const auto &opcoder : op_coders) {
  68. const Model::Node *node = opcoder->node();
  69. int primitive_type = GetPrimitiveType(node->primitive_, schema_version);
  70. auto item = std::find(loss_types.begin(), loss_types.end(), primitive_type);
  71. if (item != loss_types.end()) {
  72. loss_op = opcoder.get();
  73. break;
  74. }
  75. }
  76. MS_CHECK_PTR(loss_op);
  77. size_t op_num = op_coders.size();
  78. std::vector<std::string> code_blocks = context->code_blocks();
  79. if (op_num != code_blocks.size()) {
  80. MS_LOG(INFO) << "the number of code blocks and op coders is not equal";
  81. return RET_ERROR;
  82. }
  83. std::set<OperatorCoder *> inference_ops = FindInferenceOpcoders(loss_op);
  84. std::vector<std::string> inferences_blocks;
  85. std::vector<std::string> train_blocks;
  86. for (size_t i = 0; i < op_num; ++i) {
  87. auto &opcoder = op_coders.at(i);
  88. std::string block = code_blocks.at(i);
  89. if (inference_ops.find(opcoder.get()) != inference_ops.end()) {
  90. inferences_blocks.push_back(block);
  91. }
  92. train_blocks.push_back(block);
  93. }
  94. context->set_inference_blocks(inferences_blocks);
  95. context->set_train_blocks(train_blocks);
  96. return RET_OK;
  97. }
  98. } // namespace mindspore::lite::micro