Merge pull request !28923 from zhuzhongrui/pub_masterfeature/build-system-rewrite
| @@ -38,19 +38,30 @@ constexpr int kZeroThreshold = INT32_MIN; | |||
| template <typename T> | |||
| void LUCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) { | |||
| constexpr size_t lu_min_dim = 1; | |||
| constexpr size_t lu_max_dim = 3; | |||
| if (shape.size() < lu_min_dim || shape.size() > lu_max_dim) { | |||
| if (shape.size() <= lu_min_dim) { | |||
| MS_LOG_EXCEPTION << kernel_name_ << "shape is " << shape.size() << " which is invalid."; | |||
| } | |||
| if (shape.size() == lu_max_dim) { | |||
| batch_ = shape.front(); | |||
| *row = shape.at(lu_min_dim); | |||
| *col = shape.at(lu_max_dim - 1); | |||
| return; | |||
| constexpr size_t lu_reverse_row_dim = 2; | |||
| *row = shape.at(shape.size() - lu_reverse_row_dim); | |||
| *col = shape.at(shape.size() - 1); | |||
| batch_size_ = lu_min_dim; | |||
| for (int batch = 0; batch < static_cast<int>(shape.size() - lu_reverse_row_dim); ++batch) { | |||
| batch_size_ *= shape.at(batch); | |||
| } | |||
| } | |||
| template <typename T> | |||
| void LUCPUKernel<T>::InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) { | |||
| constexpr size_t pivot_min_dim = 1; | |||
| if (shape.size() < pivot_min_dim) { | |||
| MS_LOG_EXCEPTION << kernel_name_ << "pivots shape is " << shape.size() << " which is invalid."; | |||
| } | |||
| *row = 1; | |||
| if (shape.size() == pivot_min_dim) { | |||
| *col = shape.front(); | |||
| } else { | |||
| *col = shape.back(); | |||
| } | |||
| batch_ = 1; | |||
| *row = shape.front(); | |||
| *col = shape.at(lu_min_dim); | |||
| } | |||
| template <typename T> | |||
| @@ -66,10 +77,10 @@ void LUCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| InitMatrixInfo(a_shape, &a_row_, &a_col_); | |||
| auto lu_shape = AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex); | |||
| InitMatrixInfo(lu_shape, &lu_row_, &lu_col_); | |||
| auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex); | |||
| InitMatrixInfo(pivots_shape, &pivots_row_, &pivots_col_); | |||
| auto permutation_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPermutationIndex); | |||
| InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_); | |||
| auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex); | |||
| InitPivotVecInfo(pivots_shape, &pivots_row_, &pivots_col_); | |||
| } | |||
| template <typename T> | |||
| @@ -124,7 +135,7 @@ bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| int *batch_permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr); | |||
| T *lu_ori_wk = reinterpret_cast<T *>(workspace[kLuIndex]->addr); | |||
| T *lu_trans_wk = reinterpret_cast<T *>(workspace[kPivotsIndex]->addr); | |||
| for (size_t batch = 0; batch < batch_; ++batch) { | |||
| for (size_t batch = 0; batch < batch_size_; ++batch) { | |||
| T *a_value = batch_a_value + batch * a_row_ * a_col_; | |||
| T *lu_value = batch_lu_value + batch * lu_row_ * lu_col_; | |||
| // pivots permutation value | |||
| @@ -34,11 +34,12 @@ class LUCPUKernel : public CPUKernel { | |||
| private: | |||
| void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col); | |||
| void InitPivotVecInfo(const std::vector<size_t> &shape, size_t *row, size_t *col); | |||
| void InitInputOutputSize(const CNodePtr &kernel_node) override; | |||
| T GetPermutatedValue(const T *lu_value, const std::vector<int> &per_value, size_t i, size_t j); | |||
| bool UpdateMajorPermutation(T *lu_value, std::vector<int> *const per_value, int *pivots, size_t k, size_t rows); | |||
| void SetPermutatedValue(T *lu_value, const std::vector<int> &per_value, size_t i, size_t j, const T &value); | |||
| size_t batch_{1}; | |||
| size_t batch_size_{1}; | |||
| size_t a_row_{1}; | |||
| size_t a_col_{1}; | |||
| size_t lu_row_{1}; | |||
| @@ -68,7 +68,7 @@ class LUGpuKernel : public GpuKernel { | |||
| "malloc input shape workspace failed"); | |||
| CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, | |||
| cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_ * m_ * n_ * unit_size_, | |||
| cudaMemcpyAsync(batch_output_addr, batch_input_addr, batch_size_ * m_ * n_ * unit_size_, | |||
| cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)), | |||
| "cudaMemcpyAsync failed in LUGpuKernel::Launch."); | |||
| @@ -87,7 +87,7 @@ class LUGpuKernel : public GpuKernel { | |||
| } | |||
| // 5. malloc device working space of getrf | |||
| d_work_ = reinterpret_cast<T *>(device::gpu::GPUMemoryAllocator::GetInstance().AllocTensorMem(unit_size_ * lwork_)); | |||
| for (size_t batch = 0; batch < batch_; ++batch) { | |||
| for (size_t batch = 0; batch < batch_size_; ++batch) { | |||
| T *output_addr = batch_output_addr + batch * m_ * n_; | |||
| int *permutation_addr = batch_permutation_addr + batch * k_ * k_; | |||
| int *piv_output_addr = batch_piv_output_addr + batch * k_; | |||
| @@ -177,18 +177,15 @@ class LUGpuKernel : public GpuKernel { | |||
| private: | |||
| bool InitInputSize(const std::vector<size_t> &in_shape) { | |||
| constexpr size_t lu_min_dim = 1; | |||
| constexpr size_t lu_max_dim = 3; | |||
| if (in_shape.size() < lu_min_dim || in_shape.size() > lu_max_dim) { | |||
| MS_LOG_EXCEPTION << kernel_name_ << "shape is " << in_shape.size() << " which is invalid."; | |||
| if (in_shape.size() <= lu_min_dim) { | |||
| MS_LOG_EXCEPTION << kernel_name_ << " input shape is " << in_shape.size() << " which is invalid."; | |||
| } | |||
| if (in_shape.size() == lu_max_dim) { | |||
| batch_ = in_shape.front(); | |||
| lu_row_ = in_shape.at(lu_min_dim); | |||
| lu_col_ = in_shape.at(lu_max_dim - 1); | |||
| } else { | |||
| batch_ = 1; | |||
| lu_row_ = in_shape.front(); | |||
| lu_col_ = in_shape.at(lu_min_dim); | |||
| constexpr size_t lu_reverse_row_dim = 2; | |||
| lu_row_ = in_shape.at(in_shape.size() - lu_reverse_row_dim); | |||
| lu_col_ = in_shape.at(in_shape.size() - 1); | |||
| batch_size_ = lu_min_dim; | |||
| for (int batch = 0; batch < static_cast<int>(in_shape.size() - lu_reverse_row_dim); ++batch) { | |||
| batch_size_ *= in_shape.at(batch); | |||
| } | |||
| // set matrix row or col to be lead dimension | |||
| m_ = SizeToInt(lu_row_); | |||
| @@ -201,16 +198,16 @@ class LUGpuKernel : public GpuKernel { | |||
| } | |||
| void InitSizeLists() override { | |||
| size_t input_size = batch_ * lu_row_ * lu_col_ * unit_size_; | |||
| size_t input_size = batch_size_ * lu_row_ * lu_col_ * unit_size_; | |||
| input_size_list_.push_back(input_size); | |||
| size_t output_size = batch_ * lu_row_ * lu_col_ * unit_size_; | |||
| size_t output_size = batch_size_ * lu_row_ * lu_col_ * unit_size_; | |||
| size_t output_piv_size = 0; | |||
| if (pivot_on_) { | |||
| output_piv_size = batch_ * k_ * sizeof(int); | |||
| output_piv_size = batch_size_ * k_ * sizeof(int); | |||
| } | |||
| size_t output_permutation_size = batch_ * k_ * k_ * sizeof(int); | |||
| size_t output_permutation_size = batch_size_ * k_ * k_ * sizeof(int); | |||
| output_size_list_.resize(kDim3); | |||
| output_size_list_[kDim0] = output_size; | |||
| output_size_list_[kDim1] = output_piv_size; | |||
| @@ -229,7 +226,7 @@ class LUGpuKernel : public GpuKernel { | |||
| } | |||
| size_t unit_size_{sizeof(T)}; | |||
| size_t batch_{1}; | |||
| size_t batch_size_{1}; | |||
| size_t lu_row_{0}; | |||
| size_t lu_col_{0}; | |||
| size_t k_{0}; | |||
| @@ -300,15 +300,9 @@ class LU(PrimitiveWithInfer): | |||
| def __infer__(self, x): | |||
| x_shape = list(x['shape']) | |||
| x_dtype = x['dtype'] | |||
| ndim = len(x_shape) | |||
| if ndim in (1, 2): | |||
| k_shape = min(x_shape[0], x_shape[1]) | |||
| permutation_shape = (k_shape, k_shape) | |||
| pivots_shape = (1, k_shape) | |||
| else: | |||
| k_shape = min(x_shape[1], x_shape[2]) | |||
| permutation_shape = (x_shape[0], k_shape, k_shape) | |||
| pivots_shape = (x_shape[0], 1, k_shape) | |||
| k_shape = min(x_shape[-1], x_shape[-2]) | |||
| permutation_shape = x_shape[:-2] + [k_shape, k_shape] | |||
| pivots_shape = x_shape[:-2] + [k_shape] | |||
| output = { | |||
| 'shape': (x_shape, pivots_shape, permutation_shape), | |||
| 'dtype': (x_dtype, mstype.int32, mstype.int32), | |||
| @@ -243,9 +243,9 @@ def test_lu(shape: (int, int), dtype): | |||
| @pytest.mark.platform_x86_gpu_training | |||
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.env_onecard | |||
| @pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5)]) | |||
| @pytest.mark.parametrize('shape', [(3, 4, 4), (3, 4, 5), (2, 3, 4, 5)]) | |||
| @pytest.mark.parametrize('dtype', [onp.float32, onp.float64]) | |||
| def test_batch_lu(shape: (int, int, int), dtype): | |||
| def test_batch_lu(shape, dtype): | |||
| """ | |||
| Feature: ALL To ALL | |||
| Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] | |||
| @@ -255,13 +255,18 @@ def test_batch_lu(shape: (int, int, int), dtype): | |||
| b_s_p = list() | |||
| b_s_l = list() | |||
| b_s_u = list() | |||
| for a in b_a: | |||
| tmp = onp.zeros(b_a.shape[:-2]) | |||
| for index, _ in onp.ndenumerate(tmp): | |||
| a = b_a[index] | |||
| s_p, s_l, s_u = osp.linalg.lu(a) | |||
| b_s_p.append(s_p) | |||
| b_s_l.append(s_l) | |||
| b_s_u.append(s_u) | |||
| tensor_b_a = Tensor(onp.array(b_a)) | |||
| b_m_p, b_m_l, b_m_u = msp.linalg.lu(tensor_b_a) | |||
| b_s_p = onp.asarray(b_s_p).reshape(b_m_p.shape) | |||
| b_s_l = onp.asarray(b_s_l).reshape(b_m_l.shape) | |||
| b_s_u = onp.asarray(b_s_u).reshape(b_m_u.shape) | |||
| rtol = 1.e-5 | |||
| atol = 1.e-5 | |||
| assert onp.allclose(b_m_p.asnumpy(), b_s_p, rtol=rtol, atol=atol) | |||