|
|
|
@@ -16,6 +16,9 @@ |
|
|
|
|
|
|
|
#include "frontend/optimizer/irpass/less_batch_normalization.h" |
|
|
|
|
|
|
|
#include <set> |
|
|
|
#include <unordered_map> |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace opt { |
|
|
|
namespace irpass { |
|
|
|
@@ -302,7 +305,7 @@ bool IsRealRemoveParameterNode(const FuncGraphManagerPtr &manager, const AnfNode |
|
|
|
if (IsNotRealUseNode(node)) { |
|
|
|
const auto &cnode = node->cast<CNodePtr>(); |
|
|
|
const auto &new_cnode = ConvertRemoveNodeToVirtualNode(cnode); |
|
|
|
manager->Replace(cnode, new_cnode); |
|
|
|
(void)manager->Replace(cnode, new_cnode); |
|
|
|
continue; |
|
|
|
} |
|
|
|
need_remove = false; |
|
|
|
@@ -322,17 +325,18 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager |
|
|
|
MS_EXCEPTION_IF_NULL(root_graph); |
|
|
|
|
|
|
|
std::vector<AnfNodePtr> real_remove_parameter_list; |
|
|
|
std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(), |
|
|
|
std::back_inserter(real_remove_parameter_list), |
|
|
|
[&manager](const AnfNodePtr ¶m) { return IsRealRemoveParameterNode(manager, param); }); |
|
|
|
(void)std::copy_if(remove_parameter_list.begin(), remove_parameter_list.end(), |
|
|
|
std::back_inserter(real_remove_parameter_list), |
|
|
|
[&manager](const AnfNodePtr ¶m) { return IsRealRemoveParameterNode(manager, param); }); |
|
|
|
|
|
|
|
auto root_parameters = root_graph->parameters(); |
|
|
|
size_t origin_param_count = root_parameters.size(); |
|
|
|
root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(), |
|
|
|
[&real_remove_parameter_list](const AnfNodePtr &node) { |
|
|
|
return NeedRemove(node->cast<ParameterPtr>(), real_remove_parameter_list); |
|
|
|
}), |
|
|
|
root_parameters.end()); |
|
|
|
(void)root_parameters.erase(std::remove_if(root_parameters.begin(), root_parameters.end(), |
|
|
|
[&real_remove_parameter_list](const AnfNodePtr &node) { |
|
|
|
return NeedRemove(node->cast<ParameterPtr>(), |
|
|
|
real_remove_parameter_list); |
|
|
|
}), |
|
|
|
root_parameters.end()); |
|
|
|
size_t remove_param_count = origin_param_count - root_parameters.size(); |
|
|
|
size_t hyper_param_count = root_graph->hyper_param_count(); |
|
|
|
if (remove_param_count > hyper_param_count) { |
|
|
|
@@ -346,12 +350,12 @@ void RemoveBatchNormalizetionNotUseParameters(const FuncGraphManagerPtr &manager |
|
|
|
} // namespace |
|
|
|
|
|
|
|
bool LessBatchNormalization::MatchStructureNode(const CNodePtr &cnode, const int32_t index, |
|
|
|
const kStructureTuple &patternTuple) { |
|
|
|
const kStructureTuple &patternTuple) const { |
|
|
|
if (index < 0) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
const auto &use_pattern = std::get<1>(patternTuple); |
|
|
|
int32_t use_index = index % use_pattern.size(); |
|
|
|
int32_t use_index = index % static_cast<int32_t>(use_pattern.size()); |
|
|
|
if (!IsPrimitiveCNode(cnode, use_pattern[use_index])) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
@@ -391,7 +395,7 @@ void LessBatchNormalization::IsRemoveNode(const CNodePtr &cnode, const std::vect |
|
|
|
} |
|
|
|
const auto &start_end_pair = std::get<2>(match_pattern.at(match_branch_)); |
|
|
|
if (match_node_ >= start_end_pair.first && match_node_ <= start_end_pair.second) { |
|
|
|
remove_node_list_.insert(cnode); |
|
|
|
(void)remove_node_list_.insert(cnode); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
@@ -408,7 +412,7 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con |
|
|
|
size_t sum_match_node = 0; |
|
|
|
std::for_each(current_pattern.begin(), current_pattern.end(), [&](const kStructureTuple &t) { |
|
|
|
sum_match_node += std::get<0>(t); |
|
|
|
total_match_node_.emplace_back(sum_match_node); |
|
|
|
(void)total_match_node_.emplace_back(sum_match_node); |
|
|
|
}); |
|
|
|
auto cnode = node->cast<CNodePtr>(); |
|
|
|
if (cnode == nullptr || cnode->inputs().empty()) { |
|
|
|
@@ -434,16 +438,16 @@ AnfNodePtr LessBatchNormalization::operator()(const OptimizerPtr &optimizer, con |
|
|
|
for (auto &iter : remove_node_list_) { |
|
|
|
// Need to remove batchnorm's parameter input. |
|
|
|
if (IsPrimitiveCNode(iter, prim::kPrimBatchNorm)) { |
|
|
|
std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(), |
|
|
|
std::back_inserter(remove_load_list), |
|
|
|
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); }); |
|
|
|
std::transform( |
|
|
|
(void)std::copy_if(iter->inputs().begin() + kBNParametersStartIndex, iter->inputs().end(), |
|
|
|
std::back_inserter(remove_load_list), |
|
|
|
[](const AnfNodePtr &node) { return IsPrimitiveCNode(node, prim::kPrimLoad); }); |
|
|
|
(void)std::transform( |
|
|
|
remove_load_list.begin(), remove_load_list.end(), std::back_inserter(remove_parameter_list), |
|
|
|
[](const AnfNodePtr &node) { return node->cast<CNodePtr>()->input(kValidResidualStructureIndex); }); |
|
|
|
} |
|
|
|
// Remove useless node. |
|
|
|
auto input_cnode = iter->input(kValidResidualStructureIndex); |
|
|
|
manager->Replace(iter, input_cnode); |
|
|
|
(void)manager->Replace(iter, input_cnode); |
|
|
|
} |
|
|
|
RemoveBatchNormalizetionNotUseParameters(manager, remove_parameter_list); |
|
|
|
|
|
|
|
@@ -471,7 +475,7 @@ void LessBatchNormalization::Visit(const CNodePtr &cnode) { |
|
|
|
void LessBatchNormalization::Reset() { |
|
|
|
remove_node_list_.clear(); |
|
|
|
total_match_node_.clear(); |
|
|
|
total_match_node_.emplace_back(0); |
|
|
|
(void)total_match_node_.emplace_back(0); |
|
|
|
match_node_ = 0; |
|
|
|
match_branch_ = 0; |
|
|
|
is_match_ = false; |
|
|
|
|