You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

subgraph.cpp 3.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. /**
  2. * \file imperative/src/impl/subgraph.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
  9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
#include "megbrain/imperative/subgraph.h"

#include <algorithm>
#include <stdexcept>
  13. namespace mgb {
  14. namespace imperative {
  15. void Subgraph::remove_unused_exprs() {
  16. std::unordered_set<size_t> required_vars = {outputs.begin(), outputs.end()};
  17. required_vars.erase(0);
  18. for (auto iter = exprs.rbegin(); iter != exprs.rend(); ++iter) {
  19. auto& expr = *iter;
  20. bool required = false;
  21. for (auto output : expr.outputs) {
  22. if (required_vars.count(output)) {
  23. required = true;
  24. break;
  25. }
  26. }
  27. if (required) {
  28. required_vars.insert(expr.inputs.begin(), expr.inputs.end());
  29. } else {
  30. expr.op = nullptr;
  31. }
  32. }
  33. exprs.erase(
  34. std::remove_if(
  35. exprs.begin(), exprs.end(),
  36. [](auto expr) { return expr.op == nullptr; }),
  37. exprs.end());
  38. }
  39. SmallVector<bool> Subgraph::gen_input_mask() {
  40. std::unordered_set<size_t> unused_inputs = {inputs.begin(), inputs.end()};
  41. for (auto&& expr : exprs) {
  42. for (auto&& input : expr.inputs) {
  43. unused_inputs.erase(input);
  44. }
  45. }
  46. for (auto&& output : outputs) {
  47. unused_inputs.erase(output);
  48. }
  49. unused_inputs.insert(0);
  50. SmallVector<bool> mask(inputs.size(), true);
  51. for (size_t i = 0; i < inputs.size(); ++i) {
  52. if (unused_inputs.count(inputs[i])) {
  53. mask[i] = false;
  54. }
  55. }
  56. return mask;
  57. }
  58. SmallVector<bool> Subgraph::gen_output_mask() {
  59. std::unordered_set<size_t> invalid_outputs = {outputs.begin(), outputs.end()};
  60. for (auto&& input : inputs) {
  61. invalid_outputs.erase(input);
  62. }
  63. for (auto&& expr : exprs) {
  64. for (auto&& output : expr.outputs) {
  65. invalid_outputs.erase(output);
  66. }
  67. }
  68. for (auto&& constant : constants) {
  69. invalid_outputs.erase(constant.first);
  70. }
  71. invalid_outputs.insert(0);
  72. SmallVector<bool> mask(outputs.size(), true);
  73. for (size_t i = 0; i < outputs.size(); ++i) {
  74. if (invalid_outputs.count(outputs[i])) {
  75. mask[i] = false;
  76. }
  77. }
  78. return mask;
  79. }
  80. void Subgraph::replace_vars(const std::unordered_map<size_t, size_t>& replace_map) {
  81. // FIXME: preprocess replace_map
  82. auto replace_var = [&](var_t& var) {
  83. // TODO: detect infinite loop
  84. while (replace_map.count(var)) {
  85. var = replace_map.at(var);
  86. }
  87. };
  88. for (auto& expr : exprs) {
  89. for (auto& input : expr.inputs) {
  90. replace_var(input);
  91. }
  92. }
  93. for (auto& output : outputs) {
  94. replace_var(output);
  95. }
  96. }
  97. std::string EncodedSubgraph::repr() const {
  98. std::string buffer;
  99. buffer.push_back('|');
  100. for (size_t i = 0; i < input_mask.size(); ++i) {
  101. buffer.push_back(input_mask[i] ? '#' : ' ');
  102. }
  103. buffer.push_back('|');
  104. buffer.push_back('\n');
  105. buffer.append(graph.repr());
  106. buffer.push_back('|');
  107. for (size_t i = 0; i < output_mask.size(); ++i) {
  108. buffer.push_back(output_mask[i] ? '#' : ' ');
  109. }
  110. buffer.push_back('|');
  111. return buffer;
  112. }
  113. size_t EncodedSubgraph::hash() const {
  114. return std::hash<std::string>{}(repr());
  115. }
  116. } // namespace imperative
  117. } // namespace mgb