| @@ -3,6 +3,20 @@ include_directories(${CMAKE_SOURCE_DIR}/mindspore/core) | |||
| include_directories(${CMAKE_CURRENT_SOURCE_DIR}) | |||
| include_directories(${CMAKE_BINARY_DIR}) | |||
| if(ENABLE_CPU) | |||
| if("${X86_64_SIMD}" STREQUAL "sse") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -msse4.2") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -msse4.2") | |||
| endif() | |||
| if("${X86_64_SIMD}" STREQUAL "avx") | |||
| add_compile_definitions(ENABLE_SSE) | |||
| add_compile_definitions(ENABLE_AVX) | |||
| set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -mavx2") | |||
| set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mavx2") | |||
| endif() | |||
| endif() | |||
| if(ENABLE_ACL) | |||
| set(ASCEND_PATH /usr/local/Ascend) | |||
| include_directories(${ASCEND_PATH}/acllib/include) | |||
| @@ -29,7 +43,7 @@ if(ENABLE_GPU) | |||
| find_package(CUDA REQUIRED) | |||
| find_package(Threads) | |||
| if(${CUDA_VERSION} VERSION_LESS ${MS_REQUIRE_CUDA_VERSION}) | |||
| message(FATAL_ERROR "The minimum CUDA version ${MS_REQUIRE_CUDA_VERSION} is required, \ | |||
| message(FATAL_ERROR "The minimum CUDA version ${MS_REQUIRE_CUDA_VERSION} is required, \ | |||
| but only CUDA ${CUDA_VERSION} found.") | |||
| endif() | |||
| enable_language(CUDA) | |||
| @@ -22,21 +22,16 @@ void BiasAddCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| bias_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| if (input_shape_.size() == 4) { | |||
| data_shape_ = 4; | |||
| } else if (input_shape_.size() == 2) { | |||
| data_shape_ = 2; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "bias add input data format should be NCHW or NC"; | |||
| } | |||
| if (input_shape_.size() != 2 && input_shape_.size() != 4) { | |||
| MS_LOG(EXCEPTION) << "bias add input shape nchw or nc"; | |||
| data_shape_ = input_shape_.size(); | |||
| if (input_shape_.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Input tensor's rank must be at least 2 for 'BiasAdd' Op, but input tensor's rank is " | |||
| << input_shape_.size(); | |||
| } | |||
| if (bias_shape_.size() != 1) { | |||
| MS_LOG(EXCEPTION) << "bias shape invalid"; | |||
| MS_LOG(EXCEPTION) << "Bias's rank must be 1 for 'BiasAdd' Op, but bias' rank is" << bias_shape_.size(); | |||
| } | |||
| if (input_shape_[1] != bias_shape_[0]) { | |||
| MS_LOG(EXCEPTION) << "bias shape not match"; | |||
| MS_LOG(EXCEPTION) << "Bias shape not match, bias shape must be equal to C channel's shape"; | |||
| } | |||
| } | |||
| @@ -50,22 +45,36 @@ bool BiasAddCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const std:: | |||
| auto bias_addr = reinterpret_cast<float *>(inputs[1]->addr); | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| if (data_shape_ == 4) { | |||
| size_t h_size = input_shape_[3]; | |||
| size_t c_size = input_shape_[2] * h_size; | |||
| size_t n_size = input_shape_[1] * c_size; | |||
| size_t hw_size = input_shape_[2] * input_shape_[3]; | |||
| size_t n_offset = 0; | |||
| if (input_shape_.size() > 2) { | |||
| size_t hw_size = 1; | |||
| for (size_t i = 2; i < input_shape_.size(); ++i) { | |||
| hw_size *= input_shape_[i]; | |||
| } | |||
| size_t c_size = input_shape_[1]; | |||
| for (size_t n = 0; n < input_shape_[0]; ++n) { | |||
| size_t c_offset = 0; | |||
| for (size_t c = 0; c < input_shape_[1]; ++c) { | |||
| for (size_t hw = 0; hw < hw_size; ++hw) { | |||
| size_t offset = n_offset + c_offset + hw; | |||
| output_addr[offset] = src_addr[offset] + bias_addr[c]; | |||
| for (size_t c = 0; c < c_size; ++c) { | |||
| size_t offset = n * c_size * hw_size + c * hw_size; | |||
| size_t hw = 0; | |||
| #ifdef ENABLE_AVX | |||
| constexpr size_t C8NUM = 8; | |||
| size_t hw8 = hw_size / C8NUM * C8NUM; | |||
| const float *in_ptr = src_addr + offset; | |||
| float *out_ptr = output_addr + offset; | |||
| for (; hw < hw8; hw += C8NUM) { | |||
| __m256 src_r1 = _mm256_loadu_ps(in_ptr); | |||
| __m256 bias_r2 = _mm256_set1_ps(bias_addr[c]); | |||
| __m256 dst_r3 = _mm256_add_ps(src_r1, bias_r2); | |||
| _mm256_storeu_ps(out_ptr, dst_r3); | |||
| in_ptr += C8NUM; | |||
| out_ptr += C8NUM; | |||
| } | |||
| #endif | |||
| for (; hw < hw_size; ++hw) { | |||
| output_addr[offset + hw] = src_addr[offset + hw] + bias_addr[c]; | |||
| } | |||
| c_offset += c_size; | |||
| } | |||
| n_offset += n_size; | |||
| } | |||
| } else { | |||
| size_t n_offset = 0; | |||
| @@ -33,7 +33,7 @@ class BiasAddCPUKernel : public CPUKernel { | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| private: | |||
| uint8_t data_shape_{0}; | |||
| size_t data_shape_{0}; | |||
| std::vector<size_t> input_shape_; | |||
| std::vector<size_t> bias_shape_; | |||
| }; | |||
| @@ -21,8 +21,9 @@ namespace kernel { | |||
| void BiasAddGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| if (input_shape_.size() != 4 && input_shape_.size() != 2) { | |||
| MS_LOG(EXCEPTION) << "input data format not support"; | |||
| if (input_shape_.size() < 2) { | |||
| MS_LOG(EXCEPTION) << "Input tensor's rank must be at least 2 for 'BiasAddGrad' Op, but input tensor's rank is " | |||
| << input_shape_.size(); | |||
| } | |||
| } | |||
| @@ -34,23 +35,21 @@ bool BiasAddGradCPUKernel::Launch(const std::vector<AddressPtr> &inputs, const s | |||
| auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); | |||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | |||
| if (input_shape_.size() == 4) { | |||
| size_t h_size = input_shape_[3]; | |||
| size_t c_size = h_size * input_shape_[2]; | |||
| size_t n_size = c_size * input_shape_[1]; | |||
| size_t hw_size = input_shape_[2] * input_shape_[3]; | |||
| size_t c_offset = 0; | |||
| for (size_t c = 0; c < input_shape_[1]; ++c) { | |||
| if (input_shape_.size() > 2) { | |||
| size_t hw_size = 1; | |||
| for (size_t i = 2; i < input_shape_.size(); ++i) { | |||
| hw_size *= input_shape_[i]; | |||
| } | |||
| size_t c_size = input_shape_[1]; | |||
| for (size_t c = 0; c < c_size; ++c) { | |||
| output_addr[c] = 0; | |||
| size_t n_offset = 0; | |||
| for (size_t n = 0; n < input_shape_[0]; ++n) { | |||
| size_t offset = n * c_size * hw_size + c * hw_size; | |||
| for (size_t hw = 0; hw < hw_size; ++hw) { | |||
| size_t offset = c_offset + n_offset + hw; | |||
| output_addr[c] += input_addr[offset]; | |||
| output_addr[c] += input_addr[offset + hw]; | |||
| } | |||
| n_offset += n_size; | |||
| } | |||
| c_offset += c_size; | |||
| } | |||
| } else if (input_shape_.size() == 2) { | |||
| for (size_t c = 0; c < input_shape_[1]; ++c) { | |||
| @@ -35,7 +35,7 @@ class Net(nn.Cell): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add1(): | |||
| def test_bias_add4d(): | |||
| x = np.ones([2, 3, 4, 4]).astype(np.float32) | |||
| b = np.array([1, 1, 1]).astype(np.float32) | |||
| bias_add = Net() | |||
| @@ -48,7 +48,7 @@ def test_bias_add1(): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add2(): | |||
| def test_bias_add2d(): | |||
| x = np.ones([2, 3]).astype(np.float32) | |||
| b = np.array([1, 1, 1]).astype(np.float32) | |||
| bias_add = Net() | |||
| @@ -56,3 +56,52 @@ def test_bias_add2(): | |||
| expect_output = np.ones([2, 3]).astype(np.float32) * 2 | |||
| print(output) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add3d(): | |||
| x = np.ones([2, 3, 4]).astype(np.float32) | |||
| b = np.array([1, 1, 1]).astype(np.float32) | |||
| bias_add = Net() | |||
| output = bias_add(Tensor(x), Tensor(b)) | |||
| expect_output = np.ones([2, 3, 4]).astype(np.float32) * 2 | |||
| print(output) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add5d(): | |||
| x = np.ones([2, 5, 4, 4, 4]).astype(np.float32) | |||
| b = np.array([1, 1, 1, 1, 1]).astype(np.float32) | |||
| bias_add = Net() | |||
| output = bias_add(Tensor(x), Tensor(b)) | |||
| expect_output = np.ones([2, 5, 4, 4, 4]).astype(np.float32) * 2 | |||
| print(output) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add6d(): | |||
| x = np.ones([2, 4, 4, 4, 4, 1]).astype(np.float32) | |||
| b = np.array([1, 1, 1, 1]).astype(np.float32) | |||
| bias_add = Net() | |||
| output = bias_add(Tensor(x), Tensor(b)) | |||
| expect_output = np.ones([2, 4, 4, 4, 4, 1]).astype(np.float32) * 2 | |||
| print(output) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add7d(): | |||
| x = np.ones([2, 4, 4, 4, 4, 1, 2]).astype(np.float32) | |||
| b = np.array([1, 1, 1, 1]).astype(np.float32) | |||
| bias_add = Net() | |||
| output = bias_add(Tensor(x), Tensor(b)) | |||
| expect_output = np.ones([2, 4, 4, 4, 4, 1, 2]).astype(np.float32) * 2 | |||
| print(output) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add execute failed, please check current code commit" | |||
| @@ -35,7 +35,7 @@ class Net(nn.Cell): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add_grad1(): | |||
| def test_bias_add_grad2d(): | |||
| dout = np.ones([2, 3]).astype(np.float32) | |||
| bias_add_grad = Net() | |||
| output = bias_add_grad(Tensor(dout)) | |||
| @@ -47,10 +47,32 @@ def test_bias_add_grad1(): | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add_grad2(): | |||
| def test_bias_add_grad4d(): | |||
| dout = np.ones([2, 3, 4, 4]).astype(np.float32) | |||
| bias_add_grad = Net() | |||
| output = bias_add_grad(Tensor(dout)) | |||
| expect_output = np.array([32., 32., 32.]).astype(np.float32) | |||
| print(output.asnumpy()) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add_grad execute failed, please check current code commit" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add_grad5d(): | |||
| dout = np.ones([2, 3, 4, 4, 2]).astype(np.float32) | |||
| bias_add_grad = Net() | |||
| output = bias_add_grad(Tensor(dout)) | |||
| expect_output = np.array([64., 64., 64.]).astype(np.float32) | |||
| print(output.asnumpy()) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add_grad execute failed, please check current code commit" | |||
| @pytest.mark.level0 | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| def test_bias_add_grad7d(): | |||
| dout = np.ones([2, 3, 4, 4, 2, 1, 10]).astype(np.float32) | |||
| bias_add_grad = Net() | |||
| output = bias_add_grad(Tensor(dout)) | |||
| expect_output = np.array([640., 640., 640.]).astype(np.float32) | |||
| print(output.asnumpy()) | |||
| assert np.all(output.asnumpy() == expect_output), "bias_add_grad execute failed, please check current code commit" | |||