diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc index 22d0f1132d..08a9fc9f9e 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.cc @@ -77,6 +77,24 @@ bool GPUKernelRuntime::Init() { } namespace { + +std::vector CheckRealOutput(const std::string &node_name, const size_t &output_size) { + // define a vector containing real output number + std::vector real_outputs; + // P.FusedBatchNorm is used for training; P.BatchNorm is used for inference + // can add the filter list for more operators here.... + if (node_name == "FusedBatchNorm") { + MS_LOG(INFO) << "loading node named FusedBatchNorm."; + real_outputs.insert(real_outputs.end(), {0, 3, 4}); + } else { + // by default, TensorLoader will load all outputs + for (size_t j = 0; j < output_size; ++j) { + real_outputs.push_back(j); + } + } + return real_outputs; +} + void LoadKernelData(Debugger *debugger, const CNodePtr &kernel, const std::vector &kernel_inputs, const std::vector &kernel_workspaces, @@ -125,7 +143,13 @@ void LoadKernelData(Debugger *debugger, const CNodePtr &kernel, // get outputs auto output_size = AnfAlgo::GetOutputTensorNum(kernel); - for (size_t j = 0; j < output_size; ++j) { + auto node_name = AnfAlgo::GetCNodeName(kernel); + + std::vector real_outputs; + real_outputs = CheckRealOutput(node_name, output_size); + + for (std::vector::iterator it = real_outputs.begin(); it != real_outputs.end(); ++it) { + auto j = *it; auto addr = kernel_outputs[j]; auto type = AnfAlgo::GetOutputInferDataType(kernel, j); auto format = kOpFormat_DEFAULT;