| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_ | |||
| #include "Eigen/Dense" | |||
| #include "Eigen/Core" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| using Eigen::ColMajor; | |||
| using Eigen::Dynamic; | |||
| using Eigen::Lower; | |||
| using Eigen::Map; | |||
| using Eigen::MatrixBase; | |||
| using Eigen::RowMajor; | |||
| using Eigen::UnitLower; | |||
| using Eigen::UnitUpper; | |||
| using Eigen::Upper; | |||
| template <typename T, int Major> | |||
| using Matrix = Eigen::Matrix<T, Dynamic, Dynamic, Major>; | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_EIGEN_COMMON_UTILS_H_ | |||
| @@ -16,13 +16,13 @@ | |||
| #include "backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h" | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h" | |||
| #include "utils/ms_utils.h" | |||
| #include "Eigen/Dense" | |||
| #include "Eigen/LU" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| constexpr size_t kLUInputsNum = 1; | |||
| constexpr size_t kLUaIndex = 0; | |||
| @@ -73,27 +73,44 @@ template <typename T> | |||
| bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_a(a_value, a_row_, a_col_); | |||
| Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_); | |||
| T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(lu_value, lu_row_, lu_col_); | |||
| Map<Matrix<T, RowMajor>> output_lu(lu_value, lu_row_, lu_col_); | |||
| int *pivots_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_pivots( | |||
| pivots_value, pivots_row_, pivots_col_); | |||
| int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_permutation( | |||
| permutation_value, permutation_row_, permutation_col_); | |||
| Map<Matrix<int, RowMajor>> output_permutation(permutation_value, permutation_row_, permutation_col_); | |||
| if (a_row_ == a_col_) { | |||
| // partial_piv_lu | |||
| output_lu = input_a.lu().matrixLU(); | |||
| output_pivots = input_a.lu().permutationP().indices(); | |||
| auto partial_lu = input_a.lu(); | |||
| auto partial_p = partial_lu.permutationP(); | |||
| output_lu.noalias() = partial_lu.matrixLU(); | |||
| output_permutation.noalias() = partial_p.toDenseMatrix(); | |||
| } else { | |||
| // full_piv_lu | |||
| output_lu = input_a.fullPivLu().matrixLU(); | |||
| output_pivots = input_a.fullPivLu().permutationP().indices(); | |||
| auto full_piv_lu = input_a.fullPivLu(); | |||
| auto full_piv_p = full_piv_lu.permutationP(); | |||
| output_lu.noalias() = full_piv_lu.matrixLU(); | |||
| output_permutation.noalias() = full_piv_p.toDenseMatrix(); | |||
| } | |||
| // calculate permutation array from permutation matrix to indicate scipy's pivots. | |||
| for (int i = 0; i < static_cast<int>(output_permutation.rows()); ++i) { | |||
| if (output_permutation(i, i) != 0) { | |||
| pivots_value[i] = i; | |||
| continue; | |||
| } | |||
| for (int j = 0; j < static_cast<int>(output_permutation.cols()); ++j) { | |||
| if (output_permutation(i, j) != 0) { | |||
| pivots_value[i] = j; | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| output_permutation = output_pivots; | |||
| // here, we note that eigen calculate permutation matrix is col major, so transpose it to row major, | |||
| // but permutation array is based on permutation matrix before transposed, which is consistent to scipy and jax. | |||
| output_permutation.transposeInPlace(); | |||
| if (output_lu.RowsAtCompileTime != 0 && output_lu.ColsAtCompileTime != 0 && output_permutation.size() != 0) { | |||
| return true; | |||
| } | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| @@ -62,4 +62,4 @@ MS_REG_CPU_KERNEL_T(LU, | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_CPU_KERNEL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_CPU_KERNEL_H_ | |||
| @@ -16,13 +16,13 @@ | |||
| #include "backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h" | |||
| #include <vector> | |||
| #include <string> | |||
| #include "utils/ms_utils.h" | |||
| #include "backend/kernel_compiler/cpu/eigen/eigen_common_utils.h" | |||
| #include "Eigen/Dense" | |||
| #include "Eigen/LU" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| constexpr size_t kLUInputsNum = 2; | |||
| constexpr size_t kLUaIndex = 0; | |||
| @@ -70,6 +70,7 @@ void LUSolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| out_row_ = output_lu_shape.at(output_lu_shape.size() - kRowIndex); | |||
| out_col_ = output_lu_shape.at(output_lu_shape.size() - kColIndex); | |||
| } | |||
| trans_ = AnfAlgo ::GetNodeAttr<std::string>(kernel_node, TRANS); | |||
| } | |||
| template <typename T> | |||
| @@ -77,23 +78,28 @@ bool LUSolverCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_a(a_value, a_row_, a_col_); | |||
| Map<Matrix<T, RowMajor>> input_a(a_value, a_row_, a_col_); | |||
| T *b_value = reinterpret_cast<T *>(inputs[kLUbIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(b_value, b_row_, b_col_); | |||
| Map<Matrix<T, RowMajor>> input_b(b_value, b_row_, b_col_); | |||
| T *output_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr); | |||
| Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(output_lu_value, out_row_, | |||
| out_col_); | |||
| if (a_row_ == a_col_) { | |||
| // partial_piv_lu | |||
| output_lu = input_a.lu().solve(input_b); | |||
| Map<Matrix<T, RowMajor>> output_lu(output_lu_value, out_row_, out_col_); | |||
| if (trans_ == "N") { | |||
| output_lu.noalias() = input_a.template triangularView<UnitLower>().solve(input_b); | |||
| output_lu.noalias() = input_a.template triangularView<Upper>().solve(output_lu); | |||
| } else if (trans_ == "T") { | |||
| output_lu.noalias() = input_a.template triangularView<Upper>().solve(input_b); | |||
| output_lu.noalias() = input_a.template triangularView<UnitLower>().solve(output_lu); | |||
| } else if (trans_ == "C") { | |||
| MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is not supported C: " << trans_; | |||
| } else { | |||
| // full_piv_lu | |||
| output_lu = input_a.fullPivLu().solve(input_b); | |||
| MS_LOG_EXCEPTION << kernel_name_ << " trans_ flag is invalid: " << trans_; | |||
| } | |||
| if (output_lu.RowsAtCompileTime == 0 || output_lu.ColsAtCompileTime == 0) { | |||
| MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid."; | |||
| } | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -14,10 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_SOLVER_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LUSOLVER_CPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <string> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| @@ -39,6 +40,7 @@ class LUSolverCPUKernel : public CPUKernel { | |||
| size_t b_col_{1}; | |||
| size_t out_row_{1}; | |||
| size_t out_col_{1}; | |||
| std::string trans_{}; | |||
| TypeId dtype_{kNumberTypeFloat32}; | |||
| }; | |||
| @@ -53,4 +55,4 @@ MS_REG_CPU_KERNEL_T( | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LUSOLVER_CPU_KERNEL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_LU_SOLVER_CPU_KERNEL_H_ | |||
| @@ -106,6 +106,8 @@ bool ScatterNdUpdateCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp | |||
| LaunchKernel<float16>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeFloat32) { | |||
| LaunchKernel<float>(inputs, outputs); | |||
| } else if (dtype_ == kNumberTypeInt32) { | |||
| LaunchKernel<int>(inputs, outputs); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_; | |||
| } | |||
| @@ -72,6 +72,13 @@ MS_REG_CPU_KERNEL(TensorScatterUpdate, | |||
| .AddInputAttr(kNumberTypeFloat32) | |||
| .AddOutputAttr(kNumberTypeFloat32), | |||
| ScatterNdUpdateCPUKernel); | |||
| MS_REG_CPU_KERNEL(ScatterNdUpdate, | |||
| KernelAttr() | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddInputAttr(kNumberTypeInt32) | |||
| .AddOutputAttr(kNumberTypeInt32), | |||
| ScatterNdUpdateCPUKernel) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -19,11 +19,14 @@ from mindspore import Tensor | |||
| import mindspore.common.dtype as mstype | |||
| from mindspore.ops import PrimitiveWithInfer | |||
| from mindspore.ops import prim_attr_register | |||
| from scipy.linalg import lu_factor | |||
| from scipy.linalg import lu_solve | |||
| from mindspore._checkparam import Validator as validator | |||
| import mindspore.numpy as mnp | |||
| import scipy as scp | |||
| import numpy as np | |||
| import pytest | |||
| np.random.seed(0) | |||
| context.set_context(mode=context.GRAPH_MODE, device_target='CPU') | |||
| @@ -41,19 +44,14 @@ class LU(PrimitiveWithInfer): | |||
| def __infer__(self, x): | |||
| x_shape = list(x['shape']) | |||
| x_dtype = x['dtype'] | |||
| pivots_shape = [] | |||
| permutation_shape = [] | |||
| ndim = len(x_shape) | |||
| permutation_shape = x_shape | |||
| if ndim == 0: | |||
| pivots_shape = x_shape | |||
| permutation_shape = x_shape | |||
| elif ndim == 1: | |||
| pivots_shape = x_shape[:-1] | |||
| permutation_shape = x_shape[:-1] | |||
| else: | |||
| pivots_shape = x_shape[-2:-1] | |||
| permutation_shape = x_shape[-2:-1] | |||
| output = { | |||
| 'shape': (x_shape, pivots_shape, permutation_shape), | |||
| 'dtype': (x_dtype, mstype.int32, mstype.int32), | |||
| @@ -68,9 +66,10 @@ class LUSolver(PrimitiveWithInfer): | |||
| """ | |||
| @prim_attr_register | |||
| def __init__(self): | |||
| def __init__(self, trans: str): | |||
| super().__init__(name="LUSolver") | |||
| self.init_prim_io_names(inputs=['x', 'b'], outputs=['output']) | |||
| self.trans = validator.check_value_type("trans", trans, [str], self.name) | |||
| def __infer__(self, x, b): | |||
| b_shape = list(b['shape']) | |||
| @@ -92,42 +91,128 @@ class LuNet(nn.Cell): | |||
| return self.lu(a) | |||
| class LUSolverNet(nn.Cell): | |||
| def __init__(self): | |||
| super(LUSolverNet, self).__init__() | |||
| self.lu_solver = LUSolver() | |||
def lu_pivots_to_permutation(pivots, permutation_size: int):
    """Turn a sequence of pivot indices into a permutation of ``range(permutation_size)``.

    Starts from ``arange(permutation_size)`` broadcast over the leading dimensions of
    ``pivots`` and, for each position ``i`` along the last axis of ``pivots``, swaps entry
    ``i`` of the permutation with the entry addressed by ``pivots[..., i]``.

    Args:
        pivots: tensor whose last axis holds the pivot indices; leading axes are batch axes.
        permutation_size: length of the permutation to build.

    Returns:
        Tensor of shape ``pivots.shape[:-1] + (permutation_size,)``.
    """
    leading_shape = pivots.shape[:-1]
    num_pivots = pivots.shape[-1]
    identity = mnp.arange(0, permutation_size)
    permutation = mnp.array(mnp.broadcast_to(identity, leading_shape + (permutation_size,)))
    if permutation_size == 0:
        return permutation

    for position in range(num_pivots):
        target = pivots[..., position]
        # Open-mesh index over the batch axes so that `target` selects one entry per batch item.
        batch_index = mnp.ix_(*(mnp.arange(0, extent) for extent in leading_shape))
        # Read both entries before writing either one: this is a swap.
        current = permutation[..., position]
        other = permutation[batch_index + (target,)]
        permutation[..., position] = other
        permutation[batch_index + (target,)] = current
    return permutation
def _lu_solve_core(in_lu, permutation, b, trans):
    """Solve with a packed LU factorization through the ``LUSolver`` primitive.

    ``b`` is flattened to ``(in_lu.shape[0], prod(b.shape[1:]))``, handed to ``LUSolver`` and
    the result is reshaped back to ``b.shape``.

    Args:
        in_lu: packed LU matrix.
        permutation: row indices applied to the flattened right-hand side when ``trans == 0``.
        b: right-hand side.
        trans: 0, 1 or 2, mapped to the ``LUSolver`` modes "N", "T" and "C" respectively.

    Raises:
        ValueError: if ``trans`` is not 0, 1 or 2.
    """
    num_rows = in_lu.shape[0]
    num_cols = 1
    for extent in b.shape[1:]:
        num_cols *= extent
    rhs = mnp.reshape(b, (num_rows, num_cols))

    if trans == 0:
        # Only the non-transposed solve reorders the rows of the right-hand side.
        rhs = rhs[permutation, :]
        mode = "N"
    elif trans == 1:
        mode = "T"
    elif trans == 2:
        mode = "C"
    else:
        raise ValueError("trans error, it's value must be 0, 1, 2")

    solver = LUSolver(mode)
    return mnp.reshape(solver(in_lu, rhs), b.shape)
| def _check_lu_shape(in_lu, b): | |||
| if len(in_lu.shape) < 2 or in_lu.shape[-1] != in_lu.shape[-2]: | |||
| raise ValueError("last two dimensions of LU decomposition must be equal.") | |||
| if b.shape is None: | |||
| raise ValueError(" LU decomposition input b's rank must >=1.") | |||
| rhs_vector = in_lu.ndim == b.ndim + 1 | |||
| if rhs_vector: | |||
| if b.shape[-1] != in_lu.shape[-1]: | |||
| raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions") | |||
| mnp.expand_dims(b, axis=1) | |||
| else: | |||
| if b.shape[-2] != in_lu.shape[-1]: | |||
| raise ValueError("LU decomposition: lu matrix and b must have same number of dimensions") | |||
| def construct(self, a, b): | |||
| return self.lu_solver(a, b) | |||
def lu_factor(a, overwrite_a=False, check_finite=True):
    """Run ``LuNet`` on ``a`` and return its first two outputs.

    Args:
        a: matrix to factorize.
        overwrite_a: accepted and ignored.
        check_finite: accepted and ignored.

    Returns:
        Tuple ``(lu, pivots)``: the first and second outputs of ``LuNet``; the third output
        is dropped.
    """
    del overwrite_a, check_finite  # present in the signature only; never read
    packed_lu, pivot_indices, _ = LuNet()(a)
    return packed_lu, pivot_indices
| def _match_array(actual, expected, error=0): | |||
| if isinstance(actual, int): | |||
| actual = np.asarray(actual) | |||
| if isinstance(actual, tuple): | |||
| actual = np.asarray(actual) | |||
| if error > 0: | |||
| np.testing.assert_almost_equal(actual, expected, decimal=error) | |||
| else: | |||
| np.testing.assert_equal(actual, expected) | |||
def lu(a, permute_l=False, overwrite_a=False, check_finite=True):
    """Split the packed output of ``LuNet`` into separate lower and upper factors.

    Args:
        a: matrix to factorize; ``a.shape[-2]`` and ``a.shape[-1]`` are its row and column counts.
        permute_l: when true, return ``(matmul(p, l), u)`` instead of ``(p, l, u)``.
        overwrite_a: accepted and ignored.
        check_finite: accepted and ignored.

    Returns:
        ``(p, l, u)``, or ``(matmul(p, l), u)`` when ``permute_l`` is true. ``p`` is the third
        output of ``LuNet`` (presumably the permutation matrix - confirm against the kernel).
    """
    del overwrite_a, check_finite  # present in the signature only; never read
    packed_lu, _, p = LuNet()(a)

    num_rows = a.shape[-2]
    num_cols = a.shape[-1]
    inner = min(num_rows, num_cols)

    # Strictly-lower part of the packed result plus a unit diagonal.
    strict_lower = mnp.tril(packed_lu, -1)[:, :inner]
    l = strict_lower + mnp.eye(num_rows, inner, dtype=a.dtype)
    # Upper part, including the diagonal.
    u = mnp.triu(packed_lu)[:inner, :]

    if not permute_l:
        return p, l, u
    return mnp.matmul(p, l), u
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """Solve a linear system from a packed LU factorization and its permutation array.

    Args:
        lu_and_piv: pair ``(lu, pivots)``; ``pivots`` is used directly as the row permutation
            handed to ``_lu_solve_core``.
        b: right-hand side, either a vector or a matrix.
        trans: 0, 1 or 2; forwarded to ``_lu_solve_core``.
        overwrite_b: accepted and ignored.
        check_finite: accepted and ignored.

    Returns:
        The solution, with the same shape as ``b``.

    Raises:
        ValueError: if the shapes are rejected by ``_check_lu_shape`` or ``trans`` is invalid.
    """
    del overwrite_b, check_finite  # present in the signature only; never read
    m_lu, pivots = lu_and_piv
    # 1. check shape
    _check_lu_shape(m_lu, b)
    # 2. the permutation array has already been calculated by the factorization, just use it.
    permutation = pivots
    # 3. `_lu_solve_core` reshapes its result back to `b.shape`, so a vector right-hand side
    #    already comes back as a vector. Indexing it with [..., 0] would return only its first
    #    element, so the result is returned unchanged.
    return _lu_solve_core(m_lu, permutation, b, trans)
def create_full_rank_matrix(m, n, dtype):
    """Create a random ``m x n`` matrix of full rank.

    Starts from ``np.random.random((m, n))`` and keeps adding ``np.eye(m, n)`` until the
    matrix reaches the largest rank an ``m x n`` matrix can have, ``min(m, n)``.

    Args:
        m: number of rows.
        n: number of columns.
        dtype: NumPy dtype of the result.

    Returns:
        ``numpy.ndarray`` of shape ``(m, n)`` and type ``dtype`` with rank ``min(m, n)``.
    """
    # An m x n matrix can never exceed rank min(m, n); waiting for rank m would loop forever
    # whenever m > n.
    target_rank = min(m, n)
    a = np.random.random((m, n)).astype(dtype)
    a_rank = 0
    while a_rank != target_rank:
        a = (a + np.eye(m, n)).astype(dtype)
        a_rank = np.linalg.matrix_rank(a)
    return a
def create_sym_pos_matrix(m, n, dtype):
    """Create a symmetric ``m x m`` matrix as ``a . a^T``.

    ``a`` is ``np.random.random((m, n)) + np.eye(m, n)`` cast to ``dtype``.

    Args:
        m: number of rows of ``a`` (and size of the result).
        n: number of columns of ``a``.
        dtype: NumPy dtype of ``a``.

    Returns:
        ``numpy.ndarray`` of shape ``(m, m)``.
    """
    noise = np.random.random((m, n))
    factor = (noise + np.eye(m, n)).astype(dtype)
    return np.dot(factor, factor.T)
| @pytest.mark.platform_x86_cpu | |||
| @pytest.mark.parametrize('n', [10, 20]) | |||
| @pytest.mark.parametrize('dtype', [np.float32, np.float64]) | |||
| def test_lu_net(n: int, dtype: Generic): | |||
| def test_square_lu_net(n: int, dtype: Generic): | |||
| """ | |||
| Feature: ALL To ALL | |||
| Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1] | |||
| Expectation: the result match to scipy | |||
| """ | |||
| a = (np.random.random((n, n)) + np.eye(n)).astype(dtype) | |||
| s_lu, _ = lu_factor(a) | |||
| a = create_full_rank_matrix(n, n, dtype) | |||
| s_lu, _ = scp.linalg.lu_factor(a) | |||
| mscp_lu_net = LuNet() | |||
| tensor_a = Tensor(a) | |||
| mscp_lu, _, _ = mscp_lu_net(tensor_a) | |||
| _match_array(mscp_lu.asnumpy(), s_lu, error=4) | |||
| assert np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3) | |||
| @pytest.mark.platform_x86_cpu | |||
| @@ -139,13 +224,24 @@ def test_lu_solver_net(n: int, dtype: Generic): | |||
| Description: test cases for lu_solve test cases for A[N,N]x = b[N,1] | |||
| Expectation: the result match to scipy | |||
| """ | |||
| a = (np.random.random((n, n)) + np.eye(n)).astype(dtype) | |||
| a = create_full_rank_matrix(n, n, dtype) | |||
| b = np.random.random((n, 1)).astype(dtype) | |||
| s_lu, s_piv = lu_factor(a) | |||
| lu_factor_x = (s_lu, s_piv) | |||
| scp_x = lu_solve(lu_factor_x, b) | |||
| mscp_lu_net = LUSolverNet() | |||
| s_lu, s_piv = scp.linalg.lu_factor(a) | |||
| tensor_a = Tensor(a) | |||
| tensor_b = Tensor(b) | |||
| mscp_x = mscp_lu_net(tensor_a, tensor_b) | |||
| _match_array(mscp_x.asnumpy(), scp_x, error=4) | |||
| mscp_lu_net = LuNet() | |||
| mscp_lu, pivots, _ = mscp_lu_net(tensor_a) | |||
| np.allclose(mscp_lu.asnumpy(), s_lu, rtol=1.e-3, atol=1.e-3) | |||
| lu_factor_x = (s_lu, s_piv) | |||
| msc_lu_factor = (mscp_lu, pivots) | |||
| scp_x = scp.linalg.lu_solve(lu_factor_x, b) | |||
| mscp_x = lu_solve(msc_lu_factor, tensor_b) | |||
| real_b = mnp.dot(tensor_a, mscp_x) | |||
| expected_b = np.dot(a, scp_x) | |||
| assert np.allclose(real_b.asnumpy(), expected_b, rtol=1.e-3, atol=1.e-3) | |||
| assert np.allclose(mscp_x.asnumpy(), scp_x, rtol=1.e-3, atol=1.e-3) | |||