From ea8c8361d67c7eb13e48fe4b0050dd0493497e13 Mon Sep 17 00:00:00 2001
From: lichen_101010
Date: Wed, 23 Sep 2020 17:36:15 -0400
Subject: [PATCH] add output filter for BatchNorm operator

Add some comments

addressed John's comments

CI check

CI check part2
---
 .../runtime/device/gpu/gpu_kernel_runtime.cc | 26 ++++++++++++++++++-
 1 file changed, 25 insertions(+), 1 deletion(-)

diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
index c6c766e322..a6e6f60e29 100644
--- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
+++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc
@@ -76,6 +76,24 @@ bool GPUKernelRuntime::Init() {
 }
 
 namespace {
+
+std::vector<size_t> CheckRealOutput(const std::string &node_name, const size_t &output_size) {
+  // define a vector containing real output number
+  std::vector<size_t> real_outputs;
+  // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference
+  // can add the filter list for more operators here....
+  if (node_name == "FusedBatchNorm") {
+    MS_LOG(INFO) << "loading node named FusedBatchNorm.";
+    real_outputs.insert(real_outputs.end(), {0, 3, 4});
+  } else {
+    // by default, TensorLoader will load all outputs
+    for (size_t j = 0; j < output_size; ++j) {
+      real_outputs.push_back(j);
+    }
+  }
+  return real_outputs;
+}
+
 void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
                     const std::vector &kernel_inputs,
                     const std::vector &kernel_workspaces,
@@ -120,7 +138,13 @@ void LoadKernelData(Debugger *debugger, const CNodePtr &kernel,
 
   // get outputs
   auto output_size = AnfAlgo::GetOutputTensorNum(kernel);
-  for (size_t j = 0; j < output_size; ++j) {
+  auto node_name = AnfAlgo::GetCNodeName(kernel);
+
+  std::vector<size_t> real_outputs;
+  real_outputs = CheckRealOutput(node_name, output_size);
+
+  for (std::vector<size_t>::iterator it = real_outputs.begin(); it != real_outputs.end(); ++it) {
+    auto j = *it;
     auto addr = kernel_outputs[j];
     auto type = AnfAlgo::GetOutputInferDataType(kernel, j);
     auto format = kOpFormat_DEFAULT;