From ea8c8361d67c7eb13e48fe4b0050dd0493497e13 Mon Sep 17 00:00:00 2001
From: lichen_101010
Date: Wed, 23 Sep 2020 17:36:15 -0400
Subject: [PATCH] add output filter for BatchNorm operator
Add some comments
addressed John's comments
CI check
CI check part2
---
.../runtime/device/gpu/gpu_kernel_runtime.cc | 26 ++++++++++++++++++-
1 file changed, 25 insertions(+), 1 deletion(-)
diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index c6c766e322..a6e6f60e29 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -76,6 +76,24 @@ bool GPUKernelRuntime::Init() {
}
namespace {
+
+std::vector CheckRealOutput(const std::string &node_name, const size_t &output_size) {
+ // define a vector containing real output number
+ std::vector real_outputs;
+ // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
+ // can add the filter list for more operators here....
+ if (node_name == "FusedBatchNorm") {
+ MS_LOG(INFO) << "loading node named FusedBatchNorm.";
+ real_outputs.insert(real_outputs.end(), {0, 3, 4});
+ } else {
+ // by default, TensorLoader will load all outputs
+ for (size_t j = 0; j < output_size; ++j) {
+ real_outputs.push_back(j);
+ }
+ }
+ return real_outputs;
+}
+
void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
const std::vector &kernel_inputs,
const std::vector &kernel_workspaces,
@@ -120,7 +138,13 @@ void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
// get outputs
auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
- for (size_t j = 0; j < output_size; ++j) {
+ auto node_name = AnfAlgo::GetCNodeName(kernel);
+
+ std::vector real_outputs;
+ real_outputs = CheckRealOutput(node_name, output_size);
+
+ for (std::vector::iterator it = real_outputs.begin(); it != real_outputs.end(); ++it) {
+ auto j = *it;
auto addr = kernel_outputs[j];
auto type = AnfAlgo::GetOutputInferDataType(kernel, j);
auto format = kOpFormat_DEFAULT;