| @@ -181,7 +181,7 @@ void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, cons | |||
| T *input2 = reinterpret_cast<T *>(inputs[1]->addr); | |||
| bool *output = reinterpret_cast<bool *>(outputs[0]->addr); | |||
| -  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| +  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(bool)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| @@ -67,6 +67,52 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa | |||
| } | |||
| } | |||
| bool MKLCPUKernel::BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape, | |||
| std::vector<size_t> *dst_shape) { | |||
| MS_EXCEPTION_IF_NULL(src0_shape); | |||
| MS_EXCEPTION_IF_NULL(src1_shape); | |||
| MS_EXCEPTION_IF_NULL(dst_shape); | |||
| bool need_swap = false; | |||
| if (dst_shape->size() == 0) { | |||
| dst_shape->emplace_back(1); | |||
| src0_shape->emplace_back(1); | |||
| src1_shape->emplace_back(1); | |||
| } | |||
| MS_LOG(DEBUG) << "Binary broadcast in: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape; | |||
| if (src0_shape->size() != dst_shape->size()) { | |||
| need_swap = true; | |||
| for (size_t i = src0_shape->size(); i < dst_shape->size(); ++i) { | |||
| src0_shape->insert(src0_shape->begin(), 1); | |||
| } | |||
| } else if (src1_shape->size() != dst_shape->size()) { | |||
| for (size_t i = src1_shape->size(); i < dst_shape->size(); ++i) { | |||
| src1_shape->insert(src1_shape->begin(), 1); | |||
| } | |||
| } | |||
| if (src0_shape->size() == src1_shape->size()) { | |||
| bool visit_src0 = false; | |||
| bool visit_src1 = false; | |||
| for (size_t i = 0; i < src0_shape->size(); ++i) { | |||
| if (src0_shape->at(i) != src1_shape->at(i)) { | |||
| if (src0_shape->at(i) == 1 && !visit_src1) { | |||
| need_swap = true; | |||
| visit_src0 = true; | |||
| } else if (src1_shape->at(i) == 1 && !visit_src0) { | |||
| need_swap = false; | |||
| visit_src1 = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid broadcast! " << *src0_shape << " vs " << *src1_shape; | |||
| } | |||
| } | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid broadcast! src0: " << *src0_shape << " src1: " << *src1_shape | |||
| << " dst: " << *dst_shape; | |||
| } | |||
| MS_LOG(DEBUG) << "Binary broadcast out: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape; | |||
| return need_swap; | |||
| } | |||
| dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const { | |||
| dnnl::memory::format_tag mem_tag; | |||
| auto dim_size = dims.size(); | |||
| @@ -32,6 +32,8 @@ class MKLCPUKernel : public CPUKernel { | |||
| ~MKLCPUKernel() override = default; | |||
| protected: | |||
| bool BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape, | |||
| std::vector<size_t> *dst_shape); | |||
| void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector<size_t> &src_shape, | |||
| const std::vector<size_t> &kernel_size, int stride, std::vector<int> *padding_l, | |||
| std::vector<int> *padding_r); | |||
| @@ -25,49 +25,7 @@ void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||
| if (dst_shape.size() == 0) { | |||
| dst_shape.emplace_back(1); | |||
| src0_shape.emplace_back(1); | |||
| src1_shape.emplace_back(1); | |||
| } | |||
| size_t src0_length = 1; | |||
| size_t src1_length = 1; | |||
| for (size_t i = 0; i < src0_shape.size(); ++i) { | |||
| src0_length = src0_length * src0_shape[i]; | |||
| } | |||
| for (size_t i = 0; i < src1_shape.size(); ++i) { | |||
| src1_length = src1_length * src1_shape[i]; | |||
| } | |||
| if (src1_shape.size() != src0_shape.size()) { | |||
| if (src0_length == 1 && src0_shape.size() != dst_shape.size()) { | |||
| need_swap_ = true; | |||
| for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) { | |||
| src0_shape.emplace_back(1); | |||
| } | |||
| } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) { | |||
| for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) { | |||
| src1_shape.emplace_back(1); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape; | |||
| } | |||
| } else { | |||
| bool visit_src0 = false; | |||
| bool visit_src1 = false; | |||
| for (size_t i = 0; i < src0_shape.size(); ++i) { | |||
| if (src0_shape[i] != src1_shape[i]) { | |||
| if (src0_shape[i] == 1 && !visit_src1) { | |||
| need_swap_ = true; | |||
| visit_src0 = true; | |||
| } else if (src1_shape[i] == 1 && !visit_src0) { | |||
| need_swap_ = false; | |||
| visit_src1 = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape); | |||
| dnnl::memory::desc src0_desc; | |||
| dnnl::memory::desc src1_desc; | |||
| if (need_swap_) { | |||
| @@ -25,49 +25,7 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||
| if (dst_shape.size() == 0) { | |||
| dst_shape.emplace_back(1); | |||
| src0_shape.emplace_back(1); | |||
| src1_shape.emplace_back(1); | |||
| } | |||
| size_t src0_length = 1; | |||
| size_t src1_length = 1; | |||
| for (size_t i = 0; i < src0_shape.size(); ++i) { | |||
| src0_length = src0_length * src0_shape[i]; | |||
| } | |||
| for (size_t i = 0; i < src1_shape.size(); ++i) { | |||
| src1_length = src1_length * src1_shape[i]; | |||
| } | |||
| if (src1_shape.size() != src0_shape.size()) { | |||
| if (src0_length == 1 && src0_shape.size() != dst_shape.size()) { | |||
| need_swap_ = true; | |||
| for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) { | |||
| src0_shape.emplace_back(1); | |||
| } | |||
| } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) { | |||
| for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) { | |||
| src1_shape.emplace_back(1); | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape; | |||
| } | |||
| } else { | |||
| bool visit_src0 = false; | |||
| bool visit_src1 = false; | |||
| for (size_t i = 0; i < src0_shape.size(); ++i) { | |||
| if (src0_shape[i] != src1_shape[i]) { | |||
| if (src0_shape[i] == 1 && !visit_src1) { | |||
| need_swap_ = true; | |||
| visit_src0 = true; | |||
| } else if (src1_shape[i] == 1 && !visit_src0) { | |||
| need_swap_ = false; | |||
| visit_src1 = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape); | |||
| dnnl::memory::desc src0_desc; | |||
| dnnl::memory::desc src1_desc; | |||
| if (need_swap_) { | |||
| @@ -47,6 +47,8 @@ def test_mul(): | |||
| y2 = Tensor(2, mstype.float32) | |||
| x3 = Tensor(2, mstype.float32) | |||
| y3 = Tensor(2, mstype.float32) | |||
| x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32)) | |||
| y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32)) | |||
| mul = Net() | |||
| out = mul(x0, y0).asnumpy() | |||
| exp = x0.asnumpy() * y0.asnumpy() | |||
| @@ -75,3 +77,10 @@ def test_mul(): | |||
| err = np.ones(shape=exp.shape) * 1.0e-5 | |||
| assert np.all(diff < err) | |||
| assert out.shape == exp.shape | |||
| out = mul(x4, y4).asnumpy() | |||
| exp = x4.asnumpy() * y4.asnumpy() | |||
| diff = np.abs(out - exp) | |||
| err = np.ones(shape=exp.shape) * 1.0e-5 | |||
| assert np.all(diff < err) | |||
| assert out.shape == exp.shape | |||
| @@ -45,6 +45,8 @@ def test_tensor_add(): | |||
| y2 = Tensor(2, mstype.float32) | |||
| x3 = Tensor(2, mstype.float32) | |||
| y3 = Tensor(2, mstype.float32) | |||
| x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32)) | |||
| y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32)) | |||
| add = TensorAdd() | |||
| out = add(x0, y0).asnumpy() | |||
| exp = x0.asnumpy() + y0.asnumpy() | |||
| @@ -73,3 +75,10 @@ def test_tensor_add(): | |||
| err = np.ones(shape=exp.shape) * 1.0e-5 | |||
| assert np.all(diff < err) | |||
| assert out.shape == exp.shape | |||
| out = add(x4, y4).asnumpy() | |||
| exp = x4.asnumpy() + y4.asnumpy() | |||
| diff = np.abs(out - exp) | |||
| err = np.ones(shape=exp.shape) * 1.0e-5 | |||
| assert np.all(diff < err) | |||
| assert out.shape == exp.shape | |||