|
|
|
@@ -56,9 +56,9 @@ class BroadcastOpGpuKernel : public GpuKernel { |
|
|
|
} |
|
|
|
bool Init(const CNodePtr &kernel_node) override { |
|
|
|
GetOpType(kernel_node); |
|
|
|
auto shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); |
|
|
|
auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
auto shape1 = AnfAlgo::GetInputDeviceShape(kernel_node, 0); |
|
|
|
auto shape2 = AnfAlgo::GetInputDeviceShape(kernel_node, 1); |
|
|
|
auto shape3 = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); |
|
|
|
need_broadcast_ = IsBroadcast(shape1, shape2); |
|
|
|
if (need_broadcast_ && shape1.size() > 7) { |
|
|
|
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7"; |
|
|
|
|