| @@ -14,7 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "plugin/device/cpu/kernel/eigen/solve_triangular_cpu_kernel.h" | |||
| #include "plugin/device/cpu/kernel/eigen/matrix_triangular_solve_cpu_kernel.h" | |||
| #include <Eigen/Dense> | |||
| #include <vector> | |||
| #include <string> | |||
| @@ -38,7 +38,7 @@ constexpr auto kAMatrixDimNum = 2; | |||
| constexpr size_t kRowIndex = 2; | |||
| constexpr size_t kColIndex = 1; | |||
| template <typename T> | |||
| void SolveTriangularCpuKernelMod<T>::InitShape(const CNodePtr &kernel_node) { | |||
| void MatrixTriangularSolveCpuKernelMod<T>::InitShape(const CNodePtr &kernel_node) { | |||
| auto a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| // Since the shape check is done in frontend, we can suppose that the shape of a, b here is valid. | |||
| @@ -59,20 +59,30 @@ void SolveTriangularCpuKernelMod<T>::InitShape(const CNodePtr &kernel_node) { | |||
| } | |||
| template <typename T> | |||
| void SolveTriangularCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| void MatrixTriangularSolveCpuKernelMod<T>::InitKernel(const CNodePtr &kernel_node) { | |||
| kernel_name_ = AnfAlgo::GetCNodeName(kernel_node); | |||
| InitShape(kernel_node); | |||
| lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER); | |||
| unit_diagonal_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL); | |||
| const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS); | |||
| if (trans == "N") { | |||
| trans_ = false; | |||
| } else if (trans == "T") { | |||
| trans_ = true; | |||
| } else if (trans == "C") { | |||
| trans_ = true; | |||
| if (AnfAlgo::HasNodeAttr(ADJOINT, kernel_node)) { | |||
| // MatrixTriangularSolve attribute | |||
| trans_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, ADJOINT); | |||
| if (AnfAlgo::HasNodeAttr(TRANS, kernel_node)) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ | |||
| << "', the attribute 'adjoint' and 'trans' could not exist at the same time."; | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Trans should be in [N, T, C], but got [" << trans << "]."; | |||
| lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER); | |||
| unit_diagonal_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, UNIT_DIAGONAL); | |||
| const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, TRANS); | |||
| if (trans == "N") { | |||
| trans_ = false; | |||
| } else if (trans == "T") { | |||
| trans_ = true; | |||
| } else if (trans == "C") { | |||
| trans_ = true; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', 'trans' should be in ['N', 'T', 'C'], but got [" << trans | |||
| << "]."; | |||
| } | |||
| } | |||
| } | |||
| @@ -96,9 +106,9 @@ inline void solve(const MatrixBase<Derived_a> &a, const MatrixBase<Derived_b> &b | |||
| } | |||
| template <typename T> | |||
| bool SolveTriangularCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> & /* workspace */, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| bool MatrixTriangularSolveCpuKernelMod<T>::Launch(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> & /* workspace */, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| CHECK_KERNEL_INPUTS_NUM(inputs.size(), kSolveTriangularInputsNum, kernel_name_); | |||
| CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kSolveTriangularOutputsNum, kernel_name_); | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include "plugin/device/cpu/kernel/cpu_kernel.h" | |||
| @@ -24,10 +24,10 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| class SolveTriangularCpuKernelMod : public NativeCpuKernelMod { | |||
| class MatrixTriangularSolveCpuKernelMod : public NativeCpuKernelMod { | |||
| public: | |||
| SolveTriangularCpuKernelMod() = default; | |||
| ~SolveTriangularCpuKernelMod() override = default; | |||
| MatrixTriangularSolveCpuKernelMod() = default; | |||
| ~MatrixTriangularSolveCpuKernelMod() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| @@ -47,12 +47,20 @@ class SolveTriangularCpuKernelMod : public NativeCpuKernelMod { | |||
| MS_REG_CPU_KERNEL_T( | |||
| SolveTriangular, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SolveTriangularCpuKernelMod, float) | |||
| MatrixTriangularSolveCpuKernelMod, float) | |||
| MS_REG_CPU_KERNEL_T( | |||
| SolveTriangular, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| SolveTriangularCpuKernelMod, double) | |||
| MatrixTriangularSolveCpuKernelMod, double) | |||
| MS_REG_CPU_KERNEL_T( | |||
| MatrixTriangularSolve, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| MatrixTriangularSolveCpuKernelMod, float) | |||
| MS_REG_CPU_KERNEL_T( | |||
| MatrixTriangularSolve, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| MatrixTriangularSolveCpuKernelMod, double) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_SOLVE_TRIANGULAR_CPU_KERNEL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN_MATRIX_TRIANGULAR_SOLVE_CPU_KERNEL_H_ | |||
| @@ -14,17 +14,17 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "plugin/device/gpu/kernel/math/solve_triangular_gpu_kernel.h" | |||
| #include "plugin/device/gpu/kernel/math/matrix_triangular_solve_gpu_kernel.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| SolveTriangular, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| SolveTriangularGpuKernelMod, float) | |||
| MatrixTriangularSolveGpuKernelMod, float) | |||
| MS_REG_GPU_KERNEL_ONE( | |||
| SolveTriangular, | |||
| KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| SolveTriangularGpuKernelMod, double) | |||
| MatrixTriangularSolveGpuKernelMod, double) | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_ | |||
| #include <cublas_v2.h> | |||
| #include <cuda_runtime_api.h> | |||
| #include <type_traits> | |||
| @@ -39,10 +39,10 @@ constexpr size_t kIndexBBuffer = 2; | |||
| constexpr size_t kIndexBTransposeShape = 3; | |||
| constexpr size_t kIndexBTransposeAxis = 4; | |||
| template <typename T> | |||
| class SolveTriangularGpuKernelMod : public NativeGpuKernelMod { | |||
| class MatrixTriangularSolveGpuKernelMod : public NativeGpuKernelMod { | |||
| public: | |||
| SolveTriangularGpuKernelMod() = default; | |||
| ~SolveTriangularGpuKernelMod() = default; | |||
| MatrixTriangularSolveGpuKernelMod() = default; | |||
| ~MatrixTriangularSolveGpuKernelMod() = default; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs, void *stream_ptr) override { | |||
| @@ -159,32 +159,33 @@ class SolveTriangularGpuKernelMod : public NativeGpuKernelMod { | |||
| lda_ = SizeToInt(m_); | |||
| ldb_ = SizeToInt(m_); | |||
| const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans"); | |||
| // converting row major to col major is the same as reverting the trans flag | |||
| if (trans == "N") { | |||
| trans_ = CUBLAS_OP_T; | |||
| } else if (trans == "T") { | |||
| trans_ = CUBLAS_OP_N; | |||
| } else if (trans == "C") { | |||
| trans_ = CUBLAS_OP_N; | |||
| if (AnfAlgo::HasNodeAttr("adjoint", kernel_node)) { | |||
| // MatrixTriangularSolve attribute | |||
| bool trans = AnfAlgo::GetNodeAttr<bool>(kernel_node, "adjoint"); | |||
| // converting row major to col major is the same as reverting the trans flag | |||
| trans_ = trans ? CUBLAS_OP_N : CUBLAS_OP_T; | |||
| if (AnfAlgo::HasNodeAttr("trans", kernel_node)) { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ | |||
| << "', the attribute 'adjoint' and 'trans' could not exist at the same time."; | |||
| } | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "]."; | |||
| } | |||
| bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower"); | |||
| // reverting the trans flag by default, so also flip the lower flag | |||
| lower = !lower; | |||
| if (lower) { | |||
| uplo_ = CUBLAS_FILL_MODE_LOWER; | |||
| } else { | |||
| uplo_ = CUBLAS_FILL_MODE_UPPER; | |||
| } | |||
| bool unit_diagonal = AnfAlgo::GetNodeAttr<bool>(kernel_node, "unit_diagonal"); | |||
| if (unit_diagonal) { | |||
| unit_diagonal_ = CUBLAS_DIAG_UNIT; | |||
| } else { | |||
| unit_diagonal_ = CUBLAS_DIAG_NON_UNIT; | |||
| bool lower = AnfAlgo::GetNodeAttr<bool>(kernel_node, "lower"); | |||
| // reverting the trans flag by default, so also flip the lower flag | |||
| lower = !lower; | |||
| uplo_ = lower ? CUBLAS_FILL_MODE_LOWER : CUBLAS_FILL_MODE_UPPER; | |||
| bool unit_diagonal = AnfAlgo::GetNodeAttr<bool>(kernel_node, "unit_diagonal"); | |||
| unit_diagonal_ = unit_diagonal ? CUBLAS_DIAG_UNIT : CUBLAS_DIAG_NON_UNIT; | |||
| const std::string trans = AnfAlgo::GetNodeAttr<std::string>(kernel_node, "trans"); | |||
| // converting row major to col major is the same as reverting the trans flag | |||
| if (trans == "N") { | |||
| trans_ = CUBLAS_OP_T; | |||
| } else if (trans == "T") { | |||
| trans_ = CUBLAS_OP_N; | |||
| } else if (trans == "C") { | |||
| trans_ = CUBLAS_OP_N; | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', trans should be in [N, T, C], but got [" << trans << "]."; | |||
| } | |||
| } | |||
| InitSizeLists(); | |||
| @@ -263,4 +264,4 @@ class SolveTriangularGpuKernelMod : public NativeGpuKernelMod { | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_TRSM_SOLVE_GPU_KERNEL_H_ | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATRIX_TRIANGULAR_SOLVE_GPU_KERNEL_H_ | |||