diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h index 5c7927a321..cde22769ea 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h @@ -56,9 +56,9 @@ class BroadcastOpGpuKernel : public GpuKernel { } bool Init(const CNodePtr &kernel_node) override { GetOpType(kernel_node); - auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); - auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); - auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); + auto shape1 = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + auto shape2 = AnfAlgo::GetInputDeviceShape(kernel_node, 1); + auto shape3 = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); need_broadcast_ = IsBroadcast(shape1, shape2); if (need_broadcast_ && shape1.size() > 7) { MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";