|
|
|
@@ -89,8 +89,8 @@ std::vector<int> CheckRealOutput(const std::string &node_name, const size_t &out |
|
|
|
std::vector<int> real_outputs; |
|
|
|
// P.FusedBatchNorm is used for training; P.BatchNorm is used for inference |
|
|
|
// can add the filter list for more operators here.... |
|
|
|
if (node_name == "FusedBatchNorm") { |
|
|
|
MS_LOG(INFO) << "loading node named FusedBatchNorm."; |
|
|
|
if (node_name == "FusedBatchNorm" || node_name == "BatchNorm") { |
|
|
|
MS_LOG(INFO) << "loading node named " << node_name; |
|
|
|
real_outputs.insert(real_outputs.end(), {0, 3, 4}); |
|
|
|
} else { |
|
|
|
// by default, TensorLoader will load all outputs |
|
|
|
|