Browse Source

!25010 add lu && lu_solver && cholesky cpu kernel

Merge pull request !25010 from zhuzhongrui/pub_master
tags/v1.6.0
i-robot Gitee 4 years ago
parent
commit
9f6ed1b94d
11 changed files with 765 additions and 84 deletions
  1. +2
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h
  2. +97
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc
  3. +53
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.h
  4. +88
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.cc
  5. +57
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h
  6. +48
    -53
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc
  7. +22
    -23
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h
  8. +100
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.cc
  9. +56
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h
  10. +176
    -0
      tests/st/ops/cpu/test_cholesky_op.py
  11. +66
    -8
      tests/st/ops/cpu/test_lu_op.py

+ 2
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel.h View File

@@ -94,6 +94,8 @@ constexpr char CLOSED[] = "closed";
constexpr char NA_OPTION[] = "na_option";
constexpr char ASCENDING[] = "ascending";
constexpr char PCT[] = "pct";
constexpr char LOWER[] = "lower";
constexpr char CLEAN[] = "clean";

struct ParallelSearchInfo {
double min_cost_time{DBL_MAX};


+ 97
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.cc View File

@@ -0,0 +1,97 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.h"
#include <vector>
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/Cholesky"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 1;
constexpr size_t kInputIndex = 0;
constexpr size_t kOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kDefaultShape = 1;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace

template <typename T>
void CholeskyCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid.";
}
if (shape.size() == kDefaultShape) {
*row = shape.front();
*col = shape.front();
} else {
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
}
if (*row != *col) {
MS_LOG_EXCEPTION << kernel_name_ << "input shape is invalid: " << *row << ", " << *col;
}
return;
}

// Validate the node's input/output arity and cache the attributes and matrix
// dimensions that Launch needs.
template <typename T>
void CholeskyCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  const size_t num_inputs = AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(num_inputs, kInputsNum, kernel_name_);
  const size_t num_outputs = AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(num_outputs, kOutputsNum, kernel_name_);
  // Input and output are both square matrices; record their dimensions.
  const auto in_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputIndex);
  InitMatrixInfo(in_shape, &input_row_, &input_col_);
  const auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
  InitMatrixInfo(out_shape, &output_row_, &output_col_);
  // 'lower' chooses the L vs U factor; 'clean' selects writing a single clean
  // triangular factor instead of the packed LLT matrix (see Launch).
  lower_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, LOWER);
  clean_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, CLEAN);
}

// Compute the Cholesky factorization of the input with Eigen's LLT.
// clean_ == true  : write only the requested triangular factor (L or U).
// clean_ == false : write the packed LLT matrix (both factors in one buffer).
template <typename T>
bool CholeskyCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                  const std::vector<AddressPtr> &outputs) {
  T *input_value = reinterpret_cast<T *>(inputs[kInputIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_row_,
                                                                                      input_col_);
  T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
                                                                                       output_col_);
  Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt;
  llt.compute(input);

  if (clean_) {
    if (lower_) {
      output = llt.matrixL();
    } else {
      output = llt.matrixU();
    }
  } else {
    output = llt.matrixLLT();
  }
  // Bug fix: RowsAtCompileTime/ColsAtCompileTime are compile-time constants
  // (Eigen::Dynamic == -1 for dynamically sized matrices), so the previous
  // check could never fail. Validate the runtime dimensions instead.
  if (output.rows() != 0 && output.cols() != 0) {
    return true;
  }
  // Also fixed: the message said "lu" in the Cholesky kernel.
  MS_LOG_EXCEPTION << kernel_name_ << " output cholesky shape invalid.";
}
} // namespace kernel
} // namespace mindspore

+ 53
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_cpu_kernel.h View File

@@ -0,0 +1,53 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// CPU kernel computing the Cholesky factorization of a square matrix via
// Eigen's LLT (see cholesky_cpu_kernel.cc). T is the element type (float or
// double, per the registrations below).
template <typename T>
class CholeskyCPUKernel : public CPUKernel {
public:
CholeskyCPUKernel() = default;
~CholeskyCPUKernel() override = default;

void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;

private:
// Parses a shape into (row, col); throws if empty or non-square.
void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
// Attribute: return the lower (true) or upper (false) triangular factor.
bool lower_{true};
// Attribute: when true, Launch writes a single clean triangular factor;
// otherwise the packed LLT matrix is written.
bool clean_{false};
// Matrix dimensions cached by InitKernel for use in Launch.
size_t input_row_{1};
size_t input_col_{1};
size_t output_row_{1};
size_t output_col_{1};
TypeId dtype_{kNumberTypeFloat32};
};

MS_REG_CPU_KERNEL_T(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CholeskyCPUKernel, float)

MS_REG_CPU_KERNEL_T(Cholesky, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CholeskyCPUKernel, double)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_CPU_KERNEL_H_

+ 88
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.cc View File

@@ -0,0 +1,88 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h"
#include "Eigen/Dense"
#include "Eigen/Cholesky"
namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kInputsNum = 2;
constexpr size_t kInputAIndex = 0;
constexpr size_t kInputBIndex = 1;
constexpr size_t kOutputsNum = 1;
constexpr size_t kOutputIndex = 0;
constexpr size_t kDefaultShape = 1;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace

template <typename T>
void CholeskySolverCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid.";
}
if (shape.size() == kDefaultShape) {
*row = shape.front();
} else {
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
}
return;
}

// Validate input/output arity and cache the dimensions of A, b and the output
// for Launch.
template <typename T>
void CholeskySolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  const size_t num_inputs = AnfAlgo::GetInputTensorNum(kernel_node);
  CHECK_KERNEL_INPUTS_NUM(num_inputs, kInputsNum, kernel_name_);
  const size_t num_outputs = AnfAlgo::GetOutputTensorNum(kernel_node);
  CHECK_KERNEL_OUTPUTS_NUM(num_outputs, kOutputsNum, kernel_name_);
  const auto a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputAIndex);
  InitMatrixInfo(a_shape, &input_a_row_, &input_a_col_);
  const auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kInputBIndex);
  InitMatrixInfo(b_shape, &input_b_row_, &input_b_col_);
  const auto out_shape = AnfAlgo::GetOutputInferShape(kernel_node, kOutputIndex);
  InitMatrixInfo(out_shape, &output_row_, &output_col_);
}

// Solve A * x = b using Eigen's LLT (Cholesky) factorization of the first
// input. NOTE: the kernel factorizes `input` itself, so callers pass the
// original matrix A, not a precomputed Cholesky factor.
template <typename T>
bool CholeskySolverCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                        const std::vector<AddressPtr> &outputs) {
  T *input_value = reinterpret_cast<T *>(inputs[kInputAIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input(input_value, input_a_row_,
                                                                                      input_a_col_);

  T *input_b_value = reinterpret_cast<T *>(inputs[kInputBIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(input_b_value, input_b_row_,
                                                                                        input_b_col_);

  T *output_value = reinterpret_cast<T *>(outputs[kOutputIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output(output_value, output_row_,
                                                                                       output_col_);
  Eigen::LLT<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> llt(input);

  output = llt.solve(input_b);

  // Bug fix: RowsAtCompileTime/ColsAtCompileTime are compile-time constants
  // (Eigen::Dynamic), so the previous check was vacuous. Check runtime sizes.
  if (output.rows() != 0 && output.cols() != 0) {
    return true;
  }
  MS_LOG_EXCEPTION << kernel_name_ << " output shape invalid.";
}
} // namespace kernel
} // namespace mindspore

+ 57
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/cholesky_solve_cpu_kernel.h View File

@@ -0,0 +1,57 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Bug fix: this header reused the exact include-guard macro of
// cholesky_cpu_kernel.h (..._CPU_CHOLESKY_CPU_KERNEL_H_), so any translation
// unit including both headers would silently drop the contents of whichever
// was included second. Use a guard unique to the solver header.
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_SOLVE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_SOLVE_CPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// CPU kernel solving A x = b via Eigen's LLT (Cholesky) factorization of A
// (see cholesky_solve_cpu_kernel.cc).
template <typename T>
class CholeskySolverCPUKernel : public CPUKernel {
 public:
  CholeskySolverCPUKernel() = default;
  ~CholeskySolverCPUKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  // Parses a shape into (row, col); a rank-1 shape is an n x 1 vector.
  void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);

  // Matrix dimensions cached by InitKernel.
  size_t input_a_row_{1};
  size_t input_a_col_{1};
  size_t input_b_row_{1};
  size_t input_b_col_{1};
  size_t output_row_{1};
  size_t output_col_{1};
  TypeId dtype_{kNumberTypeFloat32};
};

MS_REG_CPU_KERNEL_T(
  CholeskySolver,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  CholeskySolverCPUKernel, float)
MS_REG_CPU_KERNEL_T(
  CholeskySolver,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  CholeskySolverCPUKernel, double)
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CHOLESKY_SOLVE_CPU_KERNEL_H_

+ 48
- 53
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.cc View File

@@ -24,9 +24,8 @@ namespace mindspore {
namespace kernel {

namespace {
constexpr size_t kLUInputsNum = 2;
constexpr size_t kLUInputsNum = 1;
constexpr size_t kLUaIndex = 0;
constexpr size_t kLUbIndex = 1;
constexpr size_t kLUOutputsNum = 3;
constexpr size_t kLuIndex = 0;
constexpr size_t kPivotsIndex = 1;
@@ -36,8 +35,23 @@ constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace

void LUCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_LOG(INFO) << "init lu kernel";
template <typename T>
void LUCPUKernel<T>::InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col) {
if (shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << "shape is invalid.";
}
if (shape.size() == kLUDefaultShape) {
*row = shape.front();
*col = 1;
} else {
*row = shape.at(shape.size() - kRowIndex);
*col = shape.at(shape.size() - kColIndex);
}
return;
}

template <typename T>
void LUCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
@@ -45,64 +59,45 @@ void LUCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CHECK_KERNEL_INPUTS_NUM(input_num, kLUInputsNum, kernel_name_);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kLUOutputsNum, kernel_name_);
auto a_input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kLUaIndex);
auto b_input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kLUbIndex);
if (a_input_shape.empty() || b_input_shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " input a or b matrix shape invalid.";
}
if (a_input_shape.size() == kLUDefaultShape) {
a_row_ = a_input_shape.front();
} else {
a_row_ = a_input_shape.at(a_input_shape.size() - kRowIndex);
a_col_ = a_input_shape.at(a_input_shape.size() - kColIndex);
}
if (b_input_shape.size() == kLUDefaultShape) {
b_row_ = b_input_shape.front();
} else {
b_row_ = b_input_shape.at(b_input_shape.size() - kRowIndex);
b_col_ = b_input_shape.at(b_input_shape.size() - kColIndex);
}
auto output_lu_shape = AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex);
if (output_lu_shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
}
if (output_lu_shape.size() == kLUDefaultShape) {
out_row_ = output_lu_shape.front();
} else {
out_row_ = output_lu_shape.at(output_lu_shape.size() - kRowIndex);
out_col_ = output_lu_shape.at(output_lu_shape.size() - kColIndex);
}
auto a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kLUaIndex);
InitMatrixInfo(a_shape, &a_row_, &a_col_);
auto lu_shape = AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex);
InitMatrixInfo(lu_shape, &lu_row_, &lu_col_);
auto pivots_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPivotsIndex);
InitMatrixInfo(pivots_shape, &pivots_row_, &pivots_col_);
auto permutation_shape = AnfAlgo::GetOutputInferShape(kernel_node, kPermutationIndex);
InitMatrixInfo(permutation_shape, &permutation_row_, &permutation_col_);
}

template <typename T>
void LUCPUKernel::LaunchLu(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &outputs) {
bool LUCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_a(a_value, a_row_, a_col_);

T *b_value = reinterpret_cast<T *>(inputs[kLUbIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(b_value, b_row_, b_col_);
T *output_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(output_lu_value, out_row_,
out_col_);
auto lu_x = input_a.lu();
output_lu = lu_x.solve(input_b);
if (output_lu.RowsAtCompileTime == 0 || output_lu.ColsAtCompileTime == 0) {
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
}
return;
}
T *lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(lu_value, lu_row_, lu_col_);
int *pivots_value = reinterpret_cast<int *>(outputs[kPivotsIndex]->addr);
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_pivots(
pivots_value, pivots_row_, pivots_col_);
int *permutation_value = reinterpret_cast<int *>(outputs[kPermutationIndex]->addr);
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_permutation(
permutation_value, permutation_row_, permutation_col_);

bool LUCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
if (dtype_ == kNumberTypeFloat32) {
LaunchLu<float>(inputs, outputs);
} else if (dtype_ == kNumberTypeFloat64) {
LaunchLu<double>(inputs, outputs);
if (a_row_ == a_col_) {
// partial_piv_lu
output_lu = input_a.lu().matrixLU();
output_pivots = input_a.lu().permutationP().indices();
} else {
MS_LOG_EXCEPTION << kernel_name_ << " unsupported " << dtype_;
// full_piv_lu
output_lu = input_a.fullPivLu().matrixLU();
output_pivots = input_a.fullPivLu().permutationP().indices();
}
output_permutation = output_pivots;
if (output_lu.RowsAtCompileTime != 0 && output_lu.ColsAtCompileTime != 0 && output_permutation.size() != 0) {
return true;
}
return true;
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
}
} // namespace kernel
} // namespace mindspore

+ 22
- 23
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_cpu_kernel.h View File

@@ -23,6 +23,7 @@

namespace mindspore {
namespace kernel {
template <typename T>
class LUCPUKernel : public CPUKernel {
public:
LUCPUKernel() = default;
@@ -32,34 +33,32 @@ class LUCPUKernel : public CPUKernel {
const std::vector<AddressPtr> &outputs) override;

private:
template <typename T>
void LaunchLu(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs);

void InitMatrixInfo(const std::vector<size_t> &shape, size_t *row, size_t *col);
size_t a_row_{1};
size_t a_col_{1};
size_t b_row_{1};
size_t b_col_{1};
size_t out_row_{1};
size_t out_col_{1};
size_t lu_row_{1};
size_t lu_col_{1};
size_t pivots_row_{1};
size_t pivots_col_{1};
size_t permutation_row_{1};
size_t permutation_col_{1};
TypeId dtype_{kNumberTypeFloat32};
};

MS_REG_CPU_KERNEL(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUCPUKernel);
MS_REG_CPU_KERNEL(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUCPUKernel);
MS_REG_CPU_KERNEL_T(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUCPUKernel, float);
MS_REG_CPU_KERNEL_T(LU,
KernelAttr()
.AddInputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeFloat64)
.AddOutputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeInt32),
LUCPUKernel, double);
} // namespace kernel
} // namespace mindspore



+ 100
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.cc View File

@@ -0,0 +1,100 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h"
#include <vector>
#include "utils/ms_utils.h"
#include "Eigen/Dense"
#include "Eigen/LU"

namespace mindspore {
namespace kernel {

namespace {
constexpr size_t kLUInputsNum = 2;
constexpr size_t kLUaIndex = 0;
constexpr size_t kLUbIndex = 1;
constexpr size_t kLUOutputsNum = 1;
constexpr size_t kLuIndex = 0;
constexpr size_t kLUDefaultShape = 1;
constexpr size_t kRowIndex = 2;
constexpr size_t kColIndex = 1;
} // namespace

template <typename T>
void LUSolverCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = AnfAlgo::GetCNodeName(kernel_node);
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
CHECK_KERNEL_INPUTS_NUM(input_num, kLUInputsNum, kernel_name_);
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
CHECK_KERNEL_OUTPUTS_NUM(output_num, kLUOutputsNum, kernel_name_);
auto a_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kLUaIndex);
auto b_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, kLUbIndex);
if (a_shape.empty() || b_shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " input a or b matrix shape invalid.";
}
if (a_shape.size() == kLUDefaultShape) {
a_row_ = a_shape.front();
} else {
a_row_ = a_shape.at(a_shape.size() - kRowIndex);
a_col_ = a_shape.at(a_shape.size() - kColIndex);
}
if (b_shape.size() == kLUDefaultShape) {
b_row_ = b_shape.front();
} else {
b_row_ = b_shape.at(b_shape.size() - kRowIndex);
b_col_ = b_shape.at(b_shape.size() - kColIndex);
}
auto output_lu_shape = AnfAlgo::GetOutputInferShape(kernel_node, kLuIndex);
if (output_lu_shape.empty()) {
MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
}
if (output_lu_shape.size() == kLUDefaultShape) {
out_row_ = output_lu_shape.front();
} else {
out_row_ = output_lu_shape.at(output_lu_shape.size() - kRowIndex);
out_col_ = output_lu_shape.at(output_lu_shape.size() - kColIndex);
}
}

// Solve A x = b with Eigen LU: partial pivoting when A is square (faster),
// full pivoting otherwise (handles rectangular / rank-deficient systems).
template <typename T>
bool LUSolverCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                  const std::vector<kernel::AddressPtr> &,
                                  const std::vector<kernel::AddressPtr> &outputs) {
  T *a_value = reinterpret_cast<T *>(inputs[kLUaIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_a(a_value, a_row_, a_col_);

  T *b_value = reinterpret_cast<T *>(inputs[kLUbIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> input_b(b_value, b_row_, b_col_);
  T *output_lu_value = reinterpret_cast<T *>(outputs[kLuIndex]->addr);
  Eigen::Map<Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> output_lu(output_lu_value, out_row_,
                                                                                          out_col_);
  if (a_row_ == a_col_) {
    // partial_piv_lu
    output_lu = input_a.lu().solve(input_b);
  } else {
    // full_piv_lu
    output_lu = input_a.fullPivLu().solve(input_b);
  }
  // Bug fix: RowsAtCompileTime/ColsAtCompileTime are compile-time constants
  // (Eigen::Dynamic), so the previous check could never trigger. Validate the
  // runtime dimensions instead.
  if (output_lu.rows() == 0 || output_lu.cols() == 0) {
    MS_LOG_EXCEPTION << kernel_name_ << " output lu shape invalid.";
  }
  return true;
}
} // namespace kernel
} // namespace mindspore

+ 56
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/eigen/lu_solve_cpu_kernel.h View File

@@ -0,0 +1,56 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Bug fix: the #ifndef tested ..._EIGEN3_LU_SOLVER_CPU_KERNEL_H_ while the
// #define (and closing comment) used ..._EIGEN3_LUSOLVER_CPU_KERNEL_H_, so
// the macro the guard checks was never defined and the guard had no effect.
// Use one consistent macro everywhere.
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_SOLVER_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_SOLVER_CPU_KERNEL_H_

#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// CPU kernel solving A x = b via Eigen LU factorization (partial pivoting for
// square A, full pivoting otherwise); see lu_solve_cpu_kernel.cc.
template <typename T>
class LUSolverCPUKernel : public CPUKernel {
 public:
  LUSolverCPUKernel() = default;
  ~LUSolverCPUKernel() override = default;
  void InitKernel(const CNodePtr &kernel_node) override;
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  // Matrix dimensions cached by InitKernel.
  size_t a_row_{1};
  size_t a_col_{1};
  size_t b_row_{1};
  size_t b_col_{1};
  size_t out_row_{1};
  size_t out_col_{1};
  TypeId dtype_{kNumberTypeFloat32};
};

MS_REG_CPU_KERNEL_T(
  LUSolver,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  LUSolverCPUKernel, float);
MS_REG_CPU_KERNEL_T(
  LUSolver,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  LUSolverCPUKernel, double);
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_EIGEN3_LU_SOLVER_CPU_KERNEL_H_

+ 176
- 0
tests/st/ops/cpu/test_cholesky_op.py View File

@@ -0,0 +1,176 @@
import mindspore.context as context
import mindspore.nn as nn
from mindspore.ops import PrimitiveWithInfer
from mindspore.ops import prim_attr_register
from mindspore._checkparam import Validator as validator
from mindspore import Tensor
import numpy as np
import scipy as scp

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")


class Cholesky(PrimitiveWithInfer):
    """
    Inner API for Cholesky base class.
    """

    @prim_attr_register
    def __init__(self, lower=False, clean=False, split_dim=0):
        super().__init__(name="Cholesky")
        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
        self.clean = validator.check_value_type("clean", clean, [bool], self.clean)
        self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def infer_shape(self, x_shape):
        # With split_dim == 0 the factor has the same shape as the input.
        if self.split_dim == 0:
            return x_shape
        # Otherwise the matrix is chunked into square blocks of order split_dim
        # along the row axis (last partial chunk rounds the batch count up).
        height, width = x_shape[0], x_shape[1]
        if height <= self.split_dim:
            return [1, height, width]
        batch = height // self.split_dim
        if height != batch * self.split_dim:
            batch += 1
        return [batch, self.split_dim, self.split_dim]

    def __infer__(self, x):
        return {
            'shape': tuple(x['shape']),
            'dtype': x['dtype'],
            'value': None,
        }


class CholeskySolver(PrimitiveWithInfer):
    """
    Inner API for CholeskySolver class.
    """

    @prim_attr_register
    def __init__(self, lower=False, split_dim=0):
        super().__init__(name="CholeskySolver")
        self.lower = validator.check_value_type("lower", lower, [bool], self.lower)
        self.split_dim = validator.check_value_type("split_dim", split_dim, [int], self.split_dim)
        self.init_prim_io_names(inputs=['x'], outputs=['y'])

    def __infer__(self, x, b):
        # The solution of A x = b takes b's shape and x's dtype.
        return {
            'shape': tuple(b['shape']),
            'dtype': x['dtype'],
            'value': None,
        }


class CholeskyNet(nn.Cell):
    """Thin Cell wrapper running the Cholesky primitive on its input."""

    def __init__(self, lower=False, clean=False, split_dim=0):
        super(CholeskyNet, self).__init__()
        self.cholesky = Cholesky(lower, clean, split_dim)

    def construct(self, x):
        return self.cholesky(x)


class CholeskySolverNet(nn.Cell):
    """Thin Cell wrapper running the CholeskySolver primitive on (c, b)."""

    def __init__(self, lower=False, split_dim=0):
        super(CholeskySolverNet, self).__init__()
        self.cholesky_solver = CholeskySolver(lower, split_dim)

    def construct(self, c, b):
        return self.cholesky_solver(c, b)


def cho_factor(a, lower=False, overwrite_a=False, check_finite=True):
    """Compute the Cholesky decomposition of a matrix, to use in cho_solve.

    Mirrors scipy.linalg.cho_factor: returns the factor matrix together with
    the `lower` flag. `overwrite_a` and `check_finite` are accepted for API
    compatibility and ignored here.
    (Fix: docstring previously started with the typo "ompute".)
    """
    cholesky_net = CholeskyNet(lower=lower, clean=False)
    c = cholesky_net(a)
    return c, lower


def cholesky(a, lower=False, overwrite_a=False, check_finite=True):
    """Compute the Cholesky decomposition of a matrix.

    Mirrors scipy.linalg.cholesky; `overwrite_a` and `check_finite` are
    accepted for API compatibility and ignored here.
    """
    net = CholeskyNet(lower=lower, clean=True)
    return net(a)


def cho_solve(c_and_lower, b, overwrite_b=False, check_finite=True):
    """Solve the linear equations A x = b, given the Cholesky factorization of A.

    Parameters
    ----------
    c_and_lower: (c, lower) tuple, (array, bool)
        Cholesky factorization of a, as given by cho_factor
    b : array
        Right-hand side
    overwrite_b : bool, optional
        Accepted for scipy API compatibility; unused here.
    check_finite : bool, optional
        Accepted for scipy API compatibility; unused here.

    Returns
    -------
    x : array
        The solution to the system A x = b

    See also
    --------
    cho_factor : Cholesky factorization of a matrix
    """
    c, lower = c_and_lower
    solver = CholeskySolverNet(lower=lower)
    return solver(c, b)


def test_cholesky():
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky [N,N]
    Expectation: the result match scipy cholesky
    """
    # Bug fix: the original matrix had a[0][2] == -6 while a[2][0] == -16, so
    # it was not symmetric. Both scipy (lower=True) and the CPU kernel read
    # only the lower triangle, which masked the typo; use the symmetric
    # positive-definite matrix the test intended.
    a = np.array([[4, 12, -16], [12, 37, -43], [-16, -43, 98]], dtype=np.float32)
    tensor_a = Tensor(a)
    scp_c_1, _ = scp.linalg.cho_factor(a, lower=True)
    mscp_c_1, _ = cho_factor(tensor_a, lower=True)

    scp_c_2 = scp.linalg.cholesky(a, lower=True)
    mscp_c_2 = cholesky(tensor_a, lower=True)
    assert np.allclose(scp_c_1, mscp_c_1.asnumpy())
    assert np.allclose(scp_c_2, mscp_c_2.asnumpy())


def test_cholesky_solver():
    """
    Feature: ALL TO ALL
    Description: test cases for cholesky solver [N,N]
    Expectation: the result match scipy cholesky_solve
    """
    a = np.array([[9, 3, 1, 5], [3, 7, 5, 1], [1, 5, 9, 2], [5, 1, 2, 6]], dtype=np.float32)
    b = np.array([1, 1, 1, 1], dtype=np.float32)
    tensor_a = Tensor(a)
    tensor_b = Tensor(b)

    expected_c, lower = scp.linalg.cho_factor(a, lower=True)
    expected_x = scp.linalg.cho_solve((expected_c, lower), b)

    mscp_c, mscp_lower = cho_factor(tensor_a, lower=True)
    # NOTE(review): tensor_a (the original matrix), not the factor mscp_c, is
    # passed on purpose — presumably the CPU CholeskySolver kernel factorizes
    # its first input internally, unlike scipy's cho_solve which takes the
    # factor. Confirm against the kernel implementation.
    mscp_x = cho_solve((tensor_a, mscp_lower), tensor_b)

    assert np.allclose(expected_c, mscp_c.asnumpy())
    assert np.allclose(expected_x, mscp_x.asnumpy())

+ 66
- 8
tests/st/ops/cpu/test_lu_op.py View File

@@ -36,28 +36,69 @@ class LU(PrimitiveWithInfer):
@prim_attr_register
def __init__(self):
super().__init__(name="LU")
self.init_prim_io_names(inputs=['x', 'b'], outputs=['lu', 'pivots', 'permutation'])
self.init_prim_io_names(inputs=['x'], outputs=['lu', 'pivots', 'permutation'])

def __infer__(self, x, b):
b_shape = list(b['shape'])
def __infer__(self, x):
x_shape = list(x['shape'])
x_dtype = x['dtype']
pivots_shape = []
permutation_shape = []
ndim = len(x_shape)
if ndim == 0:
pivots_shape = x_shape
permutation_shape = x_shape
elif ndim == 1:
pivots_shape = x_shape[:-1]
permutation_shape = x_shape[:-1]
else:
pivots_shape = x_shape[-2:-1]
permutation_shape = x_shape[-2:-1]

output = {
'shape': (b_shape, pivots_shape, permutation_shape),
'shape': (x_shape, pivots_shape, permutation_shape),
'dtype': (x_dtype, mstype.int32, mstype.int32),
'value': None
}
return output


class LUSolver(PrimitiveWithInfer):
    """
    LUSolver for Ax = b
    """

    @prim_attr_register
    def __init__(self):
        super().__init__(name="LUSolver")
        self.init_prim_io_names(inputs=['x', 'b'], outputs=['output'])

    def __infer__(self, x, b):
        # The solution of A x = b takes b's shape and x's dtype.
        return {
            'shape': tuple(b['shape']),
            'dtype': x['dtype'],
            'value': None,
        }


class LuNet(nn.Cell):
    """Thin Cell wrapper running the LU primitive on its input."""

    def __init__(self):
        super(LuNet, self).__init__()
        self.lu = LU()

    def construct(self, a):
        return self.lu(a)


class LUSolverNet(nn.Cell):
def __init__(self):
super(LUSolverNet, self).__init__()
self.lu_solver = LUSolver()

def construct(self, a, b):
return self.lu(a, b)
return self.lu_solver(a, b)


def _match_array(actual, expected, error=0):
@@ -75,7 +116,24 @@ def _match_array(actual, expected, error=0):
@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_net(n: int, dtype: Generic):
def test_lu_net(n: int, dtype: Generic):
"""
Feature: ALL To ALL
Description: test cases for lu decomposition test cases for A[N,N]x = b[N,1]
Expectation: the result match to scipy
"""
a = (np.random.random((n, n)) + np.eye(n)).astype(dtype)
s_lu, _ = lu_factor(a)
mscp_lu_net = LuNet()
tensor_a = Tensor(a)
mscp_lu, _, _ = mscp_lu_net(tensor_a)
_match_array(mscp_lu.asnumpy(), s_lu, error=4)


@pytest.mark.platform_x86_cpu
@pytest.mark.parametrize('n', [10, 20])
@pytest.mark.parametrize('dtype', [np.float32, np.float64])
def test_lu_solver_net(n: int, dtype: Generic):
"""
Feature: ALL To ALL
Description: test cases for lu_solve test cases for A[N,N]x = b[N,1]
@@ -86,8 +144,8 @@ def test_net(n: int, dtype: Generic):
s_lu, s_piv = lu_factor(a)
lu_factor_x = (s_lu, s_piv)
scp_x = lu_solve(lu_factor_x, b)
mscp_lu_net = LuNet()
mscp_lu_net = LUSolverNet()
tensor_a = Tensor(a)
tensor_b = Tensor(b)
mscp_x, _, _ = mscp_lu_net(tensor_a, tensor_b)
mscp_x = mscp_lu_net(tensor_a, tensor_b)
_match_array(mscp_x.asnumpy(), scp_x, error=4)

Loading…
Cancel
Save