|
|
|
@@ -25,14 +25,27 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); |
|
|
|
std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); |
|
|
|
std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); |
|
|
|
if (dst_shape.size() == 0) { |
|
|
|
dst_shape.emplace_back(1); |
|
|
|
src0_shape.emplace_back(1); |
|
|
|
src1_shape.emplace_back(1); |
|
|
|
} |
|
|
|
size_t src0_length = 1; |
|
|
|
size_t src1_length = 1; |
|
|
|
for (size_t i = 0; i < src0_shape.size(); ++i) { |
|
|
|
src0_length = src0_length * src0_shape[i]; |
|
|
|
} |
|
|
|
for (size_t i = 0; i < src1_shape.size(); ++i) { |
|
|
|
src1_length = src1_length * src1_shape[i]; |
|
|
|
} |
|
|
|
if (src1_shape.size() != src0_shape.size()) { |
|
|
|
if (src0_shape.size() == 0) { |
|
|
|
if (src0_length == 1 && src0_shape.size() != dst_shape.size()) { |
|
|
|
need_swap_ = true; |
|
|
|
for (size_t i = 0; i < src1_shape.size(); ++i) { |
|
|
|
for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) { |
|
|
|
src0_shape.emplace_back(1); |
|
|
|
} |
|
|
|
} else if (src1_shape.size() == 0) { |
|
|
|
for (size_t i = 0; i < src0_shape.size(); ++i) { |
|
|
|
} else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) { |
|
|
|
for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) { |
|
|
|
src1_shape.emplace_back(1); |
|
|
|
} |
|
|
|
} else { |
|
|
|
|