Browse Source

add parallel for some CPU ops

tags/v1.2.0-rc1
zhaoting 5 years ago
parent
commit
c62baec9a4
11 changed files with 907 additions and 865 deletions
  1. +13
    -35
      mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc
  2. +1
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h
  3. +257
    -247
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc
  4. +19
    -19
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h
  5. +165
    -153
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc
  6. +0
    -1
      mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h
  7. +17
    -199
      mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc
  8. +300
    -155
      mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h
  9. +38
    -15
      mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc
  10. +68
    -32
      mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc
  11. +29
    -8
      mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc

+ 13
- 35
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.cc View File

@@ -16,7 +16,6 @@
#include "backend/kernel_compiler/cpu/adam_cpu_kernel.h"

#include <cmath>
#include <thread>
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "utils/ms_utils.h"
@@ -25,16 +24,19 @@ namespace mindspore {
namespace kernel {
template <typename T>
void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * (1 - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
if (use_nesterov) {
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
} else {
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
m[i] += (gradient[i] - m[i]) * (1 - beta1);
v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2);
if (use_nesterov) {
var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon);
} else {
var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon);
}
}
}
};
CPUKernelUtils::ParallelFor(task, size);
}

void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) {
@@ -84,31 +86,7 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,

// multithreading
size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1;
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return false;
}
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return false;
}
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon,
gradient, start, end));
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}

LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens);
return true;
}
} // namespace kernel


+ 1
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/adam_cpu_kernel.h View File

@@ -29,7 +29,7 @@ class AdamCPUKernel : public CPUKernel {
~AdamCPUKernel() override = default;
template <typename T>
void LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient,
size_t start, size_t end);
size_t size);
void InitKernel(const CNodePtr &kernel_node) override;

bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,


+ 257
- 247
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.cc View File

@@ -15,7 +15,6 @@
*/
#include <cmath>
#include <string>
#include <thread>
#include <map>
#include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"
@@ -23,227 +22,285 @@
namespace mindspore {
namespace kernel {
template <typename T>
void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] + input2[i];
input1[i] = out[i];
}
void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = input1[i] + input2[i];
input1[i] = out[i];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] + input2[idx[1]];
}
void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] + input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] - input2[idx[1]];
}
void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] - input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] * input2[idx[1]];
}
void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] * input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
out[i] = dividend / divisor;
}
out[i] = dividend / divisor;
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
out[i] = dividend / divisor;
}
out[i] = dividend / divisor;
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto dividend = input1[idx[0]];
auto divisor = input2[idx[1]];
if (divisor == 0) {
if (dividend == 0) {
out[i] = std::numeric_limits<T>::quiet_NaN();
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
}
if (std::numeric_limits<T>::has_infinity) {
out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity();
} else {
out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min();
}
continue;
out[i] = floor(dividend / divisor);
}
out[i] = floor(dividend / divisor);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto x = static_cast<double>(input1[idx[0]]);
auto y = static_cast<double>(input2[idx[1]]);
auto data_div = x / y;
auto data_div_min = data_div < 0.0 ? data_div : 0.0;
auto data_div_max = data_div > 0.0 ? data_div : 0.0;
auto data_div_max_floor = floor(data_div_max);
auto data_div_min_ceil = ceil(data_div_min);
auto data_div_res = data_div_max_floor + data_div_min_ceil;
out[i] = static_cast<T>(x - data_div_res * y);
}
void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto x = static_cast<double>(input1[idx[0]]);
auto y = static_cast<double>(input2[idx[1]]);
auto data_div = x / y;
auto data_div_min = data_div < 0.0 ? data_div : 0.0;
auto data_div_max = data_div > 0.0 ? data_div : 0.0;
auto data_div_max_floor = floor(data_div_max);
auto data_div_min_ceil = ceil(data_div_min);
auto data_div_res = data_div_max_floor + data_div_min_ceil;
out[i] = static_cast<T>(x - data_div_res * y);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto x = static_cast<double>(input1[idx[0]]);
auto y = static_cast<double>(input2[idx[1]]);
out[i] = static_cast<T>(std::pow(x, y));
}
void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
auto x = static_cast<double>(input1[idx[0]]);
auto y = static_cast<double>(input2[idx[1]]);
out[i] = static_cast<T>(std::pow(x, y));
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] < input2[idx[1]];
}
void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] < input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] == input2[idx[1]];
}
void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] == input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] != input2[idx[1]];
}
void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] != input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] && input2[idx[1]];
}
void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] && input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] || input2[idx[1]];
}
void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] || input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
T diff = input1[idx[0]] - input2[idx[1]];
out[i] = diff * diff;
}
void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
T diff = input1[idx[0]] - input2[idx[1]];
out[i] = diff * diff;
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] > input2[idx[1]];
}
void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] > input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] >= input2[idx[1]];
}
void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] >= input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] <= input2[idx[1]];
}
void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = input1[idx[0]] <= input2[idx[1]];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = atan2(input1[idx[0]], input2[idx[1]]);
}
void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
std::vector<size_t> idx;
GenIndex(i, &idx);
out[i] = atan2(input1[idx[0]], input2[idx[1]]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

static const std::map<std::string, OperateType> kArithmeticBinOpTypeMap = {
{prim::kPrimGreater->name(), GREATER},
{prim::kPrimAdd->name(), ADD},
@@ -352,49 +409,25 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input
T *input1 = reinterpret_cast<T *>(inputs[0]->addr);
T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
bool *output = reinterpret_cast<bool *>(outputs[0]->addr);

size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(bool)) : 1;
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
if (operate_type_ == LESS) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Less<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == EQUAL) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == NOTEQUAL) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == GREATER) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == GREATEREQUAL) {
threads.emplace_back(
std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == LESSEQUAL) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::LessEqual<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == LOGICALAND) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalAnd<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == LOGICALOR) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalOr<T>, this, input1, input2, output, start, end));
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
if (operate_type_ == LESS) {
Less<T>(input1, input2, output, lens);
} else if (operate_type_ == EQUAL) {
Equal<T>(input1, input2, output, lens);
} else if (operate_type_ == NOTEQUAL) {
NotEqual<T>(input1, input2, output, lens);
} else if (operate_type_ == GREATER) {
Greater<T>(input1, input2, output, lens);
} else if (operate_type_ == GREATEREQUAL) {
GreaterEqual<T>(input1, input2, output, lens);
} else if (operate_type_ == LESSEQUAL) {
LessEqual<T>(input1, input2, output, lens);
} else if (operate_type_ == LOGICALAND) {
LogicalAnd<T>(input1, input2, output, lens);
} else if (operate_type_ == LOGICALOR) {
LogicalOr<T>(input1, input2, output, lens);
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
}

@@ -409,53 +442,30 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co
T *output = reinterpret_cast<T *>(outputs[0]->addr);

size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
if (operate_type_ == ADD) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Add<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == SUB) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Sub<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == MUL) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == REALDIV) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::RealDiv<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == DIV) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == FLOORDIV) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::FloorDiv<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == MOD) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mod<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == POW) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ASSIGNADD) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == ATAN2) {
threads.emplace_back(std::thread(&ArithmeticCPUKernel::Atan2<T>, this, input1, input2, output, start, end));
} else if (operate_type_ == SQUAREDDIFFERENCE) {
threads.emplace_back(
std::thread(&ArithmeticCPUKernel::SquaredDifference<T>, this, input1, input2, output, start, end));
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
if (operate_type_ == ADD) {
Add<T>(input1, input2, output, lens);
} else if (operate_type_ == SUB) {
Sub<T>(input1, input2, output, lens);
} else if (operate_type_ == MUL) {
Mul<T>(input1, input2, output, lens);
} else if (operate_type_ == REALDIV) {
RealDiv<T>(input1, input2, output, lens);
} else if (operate_type_ == DIV) {
Div<T>(input1, input2, output, lens);
} else if (operate_type_ == FLOORDIV) {
FloorDiv<T>(input1, input2, output, lens);
} else if (operate_type_ == MOD) {
Mod<T>(input1, input2, output, lens);
} else if (operate_type_ == POW) {
Pow<T>(input1, input2, output, lens);
} else if (operate_type_ == ASSIGNADD) {
AssignAdd<T>(input1, input2, output, lens);
} else if (operate_type_ == ATAN2) {
Atan2<T>(input1, input2, output, lens);
} else if (operate_type_ == SQUAREDDIFFERENCE) {
SquaredDifference<T>(input1, input2, output, lens);
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
}
} // namespace kernel


+ 19
- 19
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h View File

@@ -40,43 +40,43 @@ class ArithmeticCPUKernel : public CPUKernel {
private:
void GenIndex(size_t num, std::vector<size_t> *tmp);
template <typename T>
void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Sub(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Add(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Add(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Mul(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end);
void RealDiv(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Div(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Div(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end);
void FloorDiv(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Mod(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Pow(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end);
void AssignAdd(T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end);
void Atan2(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void Less(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void Equal(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void NotEqual(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end);
void SquaredDifference(const T *input1, const T *input2, T *out, size_t size);
template <typename T>
void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void Greater(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void GreaterEqual(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void LessEqual(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void LogicalAnd(const T *input1, const T *input2, bool *out, size_t size);
template <typename T>
void LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end);
void LogicalOr(const T *input1, const T *input2, bool *out, size_t size);
std::vector<size_t> input_shape0_;
std::vector<size_t> input_shape1_;
std::vector<size_t> input_element_num0_;


+ 165
- 153
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.cc View File

@@ -24,152 +24,212 @@ namespace mindspore {
namespace kernel {
namespace {
template <typename T>
void Square(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = in[i] * in[i];
}
void Square(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = in[i] * in[i];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Sign(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
if (in[i] < 0) {
out[i] = -1;
} else if (in[i] > 0) {
out[i] = 1;
} else {
out[i] = 0;
void Sign(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
if (in[i] < 0) {
out[i] = -1;
} else if (in[i] > 0) {
out[i] = 1;
} else {
out[i] = 0;
}
}
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Neg(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = -in[i];
}
void Neg(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = -in[i];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void LogicalNot(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = !in[i];
}
void LogicalNot(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = !in[i];
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void OnesLike(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(1);
}
void OnesLike(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(1);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ZerosLike(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(0);
}
void ZerosLike(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(0);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Floor(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(floor(in[i]));
}
void Floor(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(floor(in[i]));
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Reciprocal(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(1.0 / in[i]);
}
void Reciprocal(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(1.0 / in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Gelu(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T x = in[i];
auto double_x = static_cast<T>(x);
T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x));
out[i] = x * ((T)1.0 + tanh_res) / (T)2.0;
}
void Gelu(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
T x = in[i];
auto double_x = static_cast<T>(x);
T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x));
out[i] = x * ((T)1.0 + tanh_res) / (T)2.0;
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Asin(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = asin(in[i]);
}
void Asin(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = asin(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void ACos(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = acos(in[i]);
}
void ACos(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = acos(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Atan(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = atan(in[i]);
}
void Atan(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = atan(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Sin(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = sin(in[i]);
}
void Sin(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = sin(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Cos(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = cos(in[i]);
}
void Cos(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = cos(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Tan(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = tan(in[i]);
}
void Tan(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = tan(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Sinh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = sinh(in[i]);
}
void Sinh(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = sinh(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Cosh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = cosh(in[i]);
}
void Cosh(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = cosh(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Asinh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = asinh(in[i]);
}
void Asinh(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = asinh(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Acosh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = acosh(in[i]);
}
void Acosh(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = acosh(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename T>
void Atanh(const T *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = atanh(in[i]);
}
void Atanh(const T *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = atanh(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}
} // namespace

@@ -223,79 +283,31 @@ void ArithmeticSelfCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &i
T *input = reinterpret_cast<T *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;

auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
if (operate_type_ == LOGICALNOT) {
threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end));
}
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
}
LogicalNot<T>(input, output, lens);
return;
}

template <typename T>
void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
if (target_dtype_ == kNumberTypeBool) {
LaunchKernelLogic<T>(inputs, outputs);
return;
}
T *input = reinterpret_cast<T *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;

auto max_thread_num = std::thread::hardware_concurrency();
size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
std::vector<std::thread> threads;
if (thread_num < 1) {
MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
return;
}
threads.reserve(thread_num);
size_t start = 0;
size_t once_compute_size = (lens + thread_num - 1) / thread_num;
if (once_compute_size < 1) {
MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
return;
}
static const std::map<OperateType, std::function<void(const T *in, T *out, size_t start, size_t end)>>
kArithmeticOpFuncMap = {{SQUARE, Square<T>}, {SIGN, Sign<T>},
{NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>},
{ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>},
{FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>},
{GELU, Gelu<T>}, {SIN, Sin<T>},
{COS, Cos<T>}, {TAN, Tan<T>},
{ASIN, Asin<T>}, {ACOS, ACos<T>},
{ATAN, Atan<T>}, {SINH, Sinh<T>},
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};

while (start < lens) {
size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end));
start += once_compute_size;
}
for (size_t i = 0; i < threads.size(); ++i) {
threads[i].join();
static const std::map<OperateType, std::function<void(const T *in, T *out, size_t size)>> kArithmeticOpFuncMap = {
{SQUARE, Square<T>}, {SIGN, Sign<T>},
{NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>},
{ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>},
{FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>},
{GELU, Gelu<T>}, {SIN, Sin<T>},
{COS, Cos<T>}, {TAN, Tan<T>},
{ASIN, Asin<T>}, {ACOS, ACos<T>},
{ATAN, Atan<T>}, {SINH, Sinh<T>},
{COSH, Cosh<T>}, {ASINH, Asinh<T>},
{ACOSH, Acosh<T>}, {ATANH, Atanh<T>}};
if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) {
kArithmeticOpFuncMap.at(operate_type_)(input, output, lens);
} else {
MS_LOG(EXCEPTION) << "Not support " << operate_type_;
}
}
} // namespace kernel


+ 0
- 1
mindspore/ccsrc/backend/kernel_compiler/cpu/arithmetic_self_cpu_kernel.h View File

@@ -34,7 +34,6 @@ class ArithmeticSelfCPUKernel : public CPUKernel {

template <typename T>
void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

template <typename T>
void LaunchKernelLogic(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);



+ 17
- 199
mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.cc View File

@@ -16,220 +16,38 @@
#include <cmath>
#include <map>
#include <string>
#include <thread>
#include "backend/kernel_compiler/cpu/cast_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace mindspore {
namespace kernel {
template <typename S, typename T>
void Cast(const S *in, T *out, size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(in[i]);
}
void Cast(const S *in, T *out, size_t size) {
auto task = [&](size_t start, size_t end) {
for (size_t i = start; i < end; i++) {
out[i] = static_cast<T>(in[i]);
}
};
CPUKernelUtils::ParallelFor(task, size);
}

template <typename S, typename T>
void LaunchCast(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
  // Casts inputs[0] (type S) into outputs[0] (type T) element-wise, splitting
  // the work across threads in chunks of at least 128 elements.
  S *input = reinterpret_cast<S *>(inputs[0]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);
  MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();

  // Element count; a zero-sized output buffer is treated as one element.
  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
  size_t max_thread_num = std::thread::hardware_concurrency();
  if (max_thread_num == 0) {
    // hardware_concurrency() may legally return 0 when the value is not
    // computable; the previous code then derived thread_num == 0 and returned
    // without writing any output. Fall back to a single thread instead.
    max_thread_num = 1;
  }
  size_t thread_num = lens < 128 * max_thread_num ? static_cast<size_t>(std::ceil(lens / 128.0)) : max_thread_num;
  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
  if (thread_num < 1) {
    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
    return;
  }
  std::vector<std::thread> threads;
  threads.reserve(thread_num);
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  if (once_compute_size < 1) {
    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
    return;
  }
  for (size_t start = 0; start < lens; start += once_compute_size) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    // Each worker casts its own [start, end) slice.
    threads.emplace_back([input, output, start, end]() {
      for (size_t i = start; i < end; i++) {
        output[i] = static_cast<T>(input[i]);
      }
    });
  }
  for (auto &t : threads) {
    t.join();
  }
}

template <typename S, typename T>
void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) {
  // Records the source/target dtypes of the node for logging/diagnostics;
  // the actual cast types are fixed by the template parameters.
  MS_EXCEPTION_IF_NULL(kernel_node);
  source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0);
  target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
}

bool CastCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
using TypePair =
std::function<void(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
std::map<TypeId, std::map<TypeId, TypePair>> mode_map;
mode_map[kNumberTypeBool][kNumberTypeFloat16] = LaunchCast<bool, float16>;
mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast<bool, float>;
mode_map[kNumberTypeBool][kNumberTypeFloat64] = LaunchCast<bool, double>;
mode_map[kNumberTypeBool][kNumberTypeInt8] = LaunchCast<bool, int8_t>;
mode_map[kNumberTypeBool][kNumberTypeInt16] = LaunchCast<bool, int16_t>;
mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast<bool, int32_t>;
mode_map[kNumberTypeBool][kNumberTypeInt64] = LaunchCast<bool, int64_t>;
mode_map[kNumberTypeBool][kNumberTypeUInt8] = LaunchCast<bool, uint8_t>;
mode_map[kNumberTypeBool][kNumberTypeUInt16] = LaunchCast<bool, uint16_t>;
mode_map[kNumberTypeBool][kNumberTypeUInt32] = LaunchCast<bool, uint32_t>;
mode_map[kNumberTypeBool][kNumberTypeUInt64] = LaunchCast<bool, uint64_t>;
mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast<bool, bool>;

mode_map[kNumberTypeFloat16][kNumberTypeFloat16] = LaunchCast<float16, float16>;
mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast<float16, float>;
mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast<float16, double>;
mode_map[kNumberTypeFloat16][kNumberTypeInt8] = LaunchCast<float16, int8_t>;
mode_map[kNumberTypeFloat16][kNumberTypeInt16] = LaunchCast<float16, int16_t>;
mode_map[kNumberTypeFloat16][kNumberTypeInt32] = LaunchCast<float16, int32_t>;
mode_map[kNumberTypeFloat16][kNumberTypeInt64] = LaunchCast<float16, int64_t>;
mode_map[kNumberTypeFloat16][kNumberTypeUInt8] = LaunchCast<float16, uint8_t>;
mode_map[kNumberTypeFloat16][kNumberTypeUInt16] = LaunchCast<float16, uint16_t>;
mode_map[kNumberTypeFloat16][kNumberTypeUInt32] = LaunchCast<float16, uint32_t>;
mode_map[kNumberTypeFloat16][kNumberTypeUInt64] = LaunchCast<float16, uint64_t>;
mode_map[kNumberTypeFloat16][kNumberTypeBool] = LaunchCast<float16, bool>;

mode_map[kNumberTypeFloat32][kNumberTypeFloat16] = LaunchCast<float, float16>;
mode_map[kNumberTypeFloat32][kNumberTypeFloat32] = LaunchCast<float, float>;
mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast<float, double>;
mode_map[kNumberTypeFloat32][kNumberTypeInt8] = LaunchCast<float, int8_t>;
mode_map[kNumberTypeFloat32][kNumberTypeInt16] = LaunchCast<float, int16_t>;
mode_map[kNumberTypeFloat32][kNumberTypeInt32] = LaunchCast<float, int32_t>;
mode_map[kNumberTypeFloat32][kNumberTypeInt64] = LaunchCast<float, int64_t>;
mode_map[kNumberTypeFloat32][kNumberTypeUInt8] = LaunchCast<float, uint8_t>;
mode_map[kNumberTypeFloat32][kNumberTypeUInt16] = LaunchCast<float, uint16_t>;
mode_map[kNumberTypeFloat32][kNumberTypeUInt32] = LaunchCast<float, uint32_t>;
mode_map[kNumberTypeFloat32][kNumberTypeUInt64] = LaunchCast<float, uint64_t>;
mode_map[kNumberTypeFloat32][kNumberTypeBool] = LaunchCast<float, bool>;

mode_map[kNumberTypeFloat64][kNumberTypeFloat16] = LaunchCast<double, float16>;
mode_map[kNumberTypeFloat64][kNumberTypeFloat32] = LaunchCast<double, float>;
mode_map[kNumberTypeFloat64][kNumberTypeFloat64] = LaunchCast<double, double>;
mode_map[kNumberTypeFloat64][kNumberTypeInt8] = LaunchCast<double, int8_t>;
mode_map[kNumberTypeFloat64][kNumberTypeInt16] = LaunchCast<double, int16_t>;
mode_map[kNumberTypeFloat64][kNumberTypeInt32] = LaunchCast<double, int32_t>;
mode_map[kNumberTypeFloat64][kNumberTypeInt64] = LaunchCast<double, int64_t>;
mode_map[kNumberTypeFloat64][kNumberTypeUInt8] = LaunchCast<double, uint8_t>;
mode_map[kNumberTypeFloat64][kNumberTypeUInt16] = LaunchCast<double, uint16_t>;
mode_map[kNumberTypeFloat64][kNumberTypeUInt32] = LaunchCast<double, uint32_t>;
mode_map[kNumberTypeFloat64][kNumberTypeUInt64] = LaunchCast<double, uint64_t>;
mode_map[kNumberTypeFloat64][kNumberTypeBool] = LaunchCast<double, bool>;

mode_map[kNumberTypeInt8][kNumberTypeFloat16] = LaunchCast<int8_t, float16>;
mode_map[kNumberTypeInt8][kNumberTypeFloat32] = LaunchCast<int8_t, float>;
mode_map[kNumberTypeInt8][kNumberTypeFloat64] = LaunchCast<int8_t, double>;
mode_map[kNumberTypeInt8][kNumberTypeInt8] = LaunchCast<int8_t, int8_t>;
mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast<int8_t, int16_t>;
mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast<int8_t, int32_t>;
mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast<int8_t, int64_t>;
mode_map[kNumberTypeInt8][kNumberTypeUInt8] = LaunchCast<int8_t, uint8_t>;
mode_map[kNumberTypeInt8][kNumberTypeUInt16] = LaunchCast<int8_t, uint16_t>;
mode_map[kNumberTypeInt8][kNumberTypeUInt32] = LaunchCast<int8_t, uint32_t>;
mode_map[kNumberTypeInt8][kNumberTypeUInt64] = LaunchCast<int8_t, uint64_t>;
mode_map[kNumberTypeInt8][kNumberTypeBool] = LaunchCast<int8_t, bool>;

mode_map[kNumberTypeInt16][kNumberTypeFloat16] = LaunchCast<int16_t, float16>;
mode_map[kNumberTypeInt16][kNumberTypeFloat32] = LaunchCast<int16_t, float>;
mode_map[kNumberTypeInt16][kNumberTypeFloat64] = LaunchCast<int16_t, double>;
mode_map[kNumberTypeInt16][kNumberTypeInt8] = LaunchCast<int16_t, int8_t>;
mode_map[kNumberTypeInt16][kNumberTypeInt16] = LaunchCast<int16_t, int16_t>;
mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast<int16_t, int32_t>;
mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast<int16_t, int64_t>;
mode_map[kNumberTypeInt16][kNumberTypeUInt8] = LaunchCast<int16_t, uint8_t>;
mode_map[kNumberTypeInt16][kNumberTypeUInt16] = LaunchCast<int16_t, uint16_t>;
mode_map[kNumberTypeInt16][kNumberTypeUInt32] = LaunchCast<int16_t, uint32_t>;
mode_map[kNumberTypeInt16][kNumberTypeUInt64] = LaunchCast<int16_t, uint64_t>;
mode_map[kNumberTypeInt16][kNumberTypeBool] = LaunchCast<int16_t, bool>;

mode_map[kNumberTypeInt32][kNumberTypeFloat16] = LaunchCast<int32_t, float16>;
mode_map[kNumberTypeInt32][kNumberTypeFloat32] = LaunchCast<int32_t, float>;
mode_map[kNumberTypeInt32][kNumberTypeFloat64] = LaunchCast<int32_t, double>;
mode_map[kNumberTypeInt32][kNumberTypeInt8] = LaunchCast<int32_t, int8_t>;
mode_map[kNumberTypeInt32][kNumberTypeInt16] = LaunchCast<int32_t, int16_t>;
mode_map[kNumberTypeInt32][kNumberTypeInt32] = LaunchCast<int32_t, int32_t>;
mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast<int32_t, int64_t>;
mode_map[kNumberTypeInt32][kNumberTypeUInt8] = LaunchCast<int32_t, uint8_t>;
mode_map[kNumberTypeInt32][kNumberTypeUInt16] = LaunchCast<int32_t, uint16_t>;
mode_map[kNumberTypeInt32][kNumberTypeUInt32] = LaunchCast<int32_t, uint32_t>;
mode_map[kNumberTypeInt32][kNumberTypeUInt64] = LaunchCast<int32_t, uint64_t>;
mode_map[kNumberTypeInt32][kNumberTypeBool] = LaunchCast<int32_t, bool>;

mode_map[kNumberTypeInt64][kNumberTypeFloat16] = LaunchCast<int64_t, float16>;
mode_map[kNumberTypeInt64][kNumberTypeFloat32] = LaunchCast<int64_t, float>;
mode_map[kNumberTypeInt64][kNumberTypeFloat64] = LaunchCast<int64_t, double>;
mode_map[kNumberTypeInt64][kNumberTypeInt8] = LaunchCast<int64_t, int8_t>;
mode_map[kNumberTypeInt64][kNumberTypeInt16] = LaunchCast<int64_t, int16_t>;
mode_map[kNumberTypeInt64][kNumberTypeInt32] = LaunchCast<int64_t, int32_t>;
mode_map[kNumberTypeInt64][kNumberTypeInt64] = LaunchCast<int64_t, int64_t>;
mode_map[kNumberTypeInt64][kNumberTypeUInt8] = LaunchCast<int64_t, uint8_t>;
mode_map[kNumberTypeInt64][kNumberTypeUInt16] = LaunchCast<int64_t, uint16_t>;
mode_map[kNumberTypeInt64][kNumberTypeUInt32] = LaunchCast<int64_t, uint32_t>;
mode_map[kNumberTypeInt64][kNumberTypeUInt64] = LaunchCast<int64_t, uint64_t>;
mode_map[kNumberTypeInt64][kNumberTypeBool] = LaunchCast<int64_t, bool>;

mode_map[kNumberTypeUInt8][kNumberTypeFloat16] = LaunchCast<uint8_t, float16>;
mode_map[kNumberTypeUInt8][kNumberTypeFloat32] = LaunchCast<uint8_t, float>;
mode_map[kNumberTypeUInt8][kNumberTypeFloat64] = LaunchCast<uint8_t, double>;
mode_map[kNumberTypeUInt8][kNumberTypeInt8] = LaunchCast<uint8_t, int8_t>;
mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast<uint8_t, int16_t>;
mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast<uint8_t, int32_t>;
mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast<uint8_t, int64_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt8] = LaunchCast<uint8_t, uint8_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast<uint8_t, uint16_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast<uint8_t, uint32_t>;
mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast<uint8_t, uint64_t>;
mode_map[kNumberTypeUInt8][kNumberTypeBool] = LaunchCast<uint8_t, bool>;

mode_map[kNumberTypeUInt16][kNumberTypeFloat16] = LaunchCast<uint16_t, float16>;
mode_map[kNumberTypeUInt16][kNumberTypeFloat32] = LaunchCast<uint16_t, float>;
mode_map[kNumberTypeUInt16][kNumberTypeFloat64] = LaunchCast<uint16_t, double>;
mode_map[kNumberTypeUInt16][kNumberTypeInt8] = LaunchCast<uint16_t, int8_t>;
mode_map[kNumberTypeUInt16][kNumberTypeInt16] = LaunchCast<uint16_t, int16_t>;
mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast<uint16_t, int32_t>;
mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast<uint16_t, int64_t>;
mode_map[kNumberTypeUInt16][kNumberTypeUInt8] = LaunchCast<uint16_t, uint8_t>;
mode_map[kNumberTypeUInt16][kNumberTypeUInt16] = LaunchCast<uint16_t, uint16_t>;
mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast<uint16_t, uint32_t>;
mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast<uint16_t, uint64_t>;
mode_map[kNumberTypeUInt16][kNumberTypeBool] = LaunchCast<uint16_t, bool>;

mode_map[kNumberTypeUInt32][kNumberTypeFloat16] = LaunchCast<uint32_t, float16>;
mode_map[kNumberTypeUInt32][kNumberTypeFloat32] = LaunchCast<uint32_t, float>;
mode_map[kNumberTypeUInt32][kNumberTypeFloat64] = LaunchCast<uint32_t, double>;
mode_map[kNumberTypeUInt32][kNumberTypeInt8] = LaunchCast<uint32_t, int8_t>;
mode_map[kNumberTypeUInt32][kNumberTypeInt16] = LaunchCast<uint32_t, int16_t>;
mode_map[kNumberTypeUInt32][kNumberTypeInt32] = LaunchCast<uint32_t, int32_t>;
mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast<uint32_t, int64_t>;
mode_map[kNumberTypeUInt32][kNumberTypeUInt8] = LaunchCast<uint32_t, uint8_t>;
mode_map[kNumberTypeUInt32][kNumberTypeUInt16] = LaunchCast<uint32_t, uint16_t>;
mode_map[kNumberTypeUInt32][kNumberTypeUInt32] = LaunchCast<uint32_t, uint32_t>;
mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast<uint32_t, uint64_t>;
mode_map[kNumberTypeUInt32][kNumberTypeBool] = LaunchCast<uint32_t, bool>;

mode_map[kNumberTypeUInt64][kNumberTypeFloat16] = LaunchCast<uint64_t, float16>;
mode_map[kNumberTypeUInt64][kNumberTypeFloat32] = LaunchCast<uint64_t, float>;
mode_map[kNumberTypeUInt64][kNumberTypeFloat64] = LaunchCast<uint64_t, double>;
mode_map[kNumberTypeUInt64][kNumberTypeInt8] = LaunchCast<uint64_t, int8_t>;
mode_map[kNumberTypeUInt64][kNumberTypeInt16] = LaunchCast<uint64_t, int16_t>;
mode_map[kNumberTypeUInt64][kNumberTypeInt32] = LaunchCast<uint64_t, int32_t>;
mode_map[kNumberTypeUInt64][kNumberTypeInt64] = LaunchCast<uint64_t, int64_t>;
mode_map[kNumberTypeUInt64][kNumberTypeUInt8] = LaunchCast<uint64_t, uint8_t>;
mode_map[kNumberTypeUInt64][kNumberTypeUInt16] = LaunchCast<uint64_t, uint16_t>;
mode_map[kNumberTypeUInt64][kNumberTypeUInt32] = LaunchCast<uint64_t, uint32_t>;
mode_map[kNumberTypeUInt64][kNumberTypeUInt64] = LaunchCast<uint64_t, uint64_t>;
mode_map[kNumberTypeUInt64][kNumberTypeBool] = LaunchCast<uint64_t, bool>;
template <typename S, typename T>
bool CastCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
S *input = reinterpret_cast<S *>(inputs[0]->addr);
T *output = reinterpret_cast<T *>(outputs[0]->addr);
MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();

mode_map[source_dtype][target_dtype](inputs, outputs);
size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
Cast<S, T>(input, output, lens);
return true;
}
} // namespace kernel


+ 300
- 155
mindspore/ccsrc/backend/kernel_compiler/cpu/cast_cpu_kernel.h View File

@@ -23,6 +23,7 @@

namespace mindspore {
namespace kernel {
template <typename S, typename T>
class CastCPUKernel : public CPUKernel {
public:
CastCPUKernel() = default;
@@ -38,161 +39,305 @@ class CastCPUKernel : public CPUKernel {
TypeId target_dtype{kTypeUnknown};
};

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel);

MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel);
MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
bool, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
bool, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
bool, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
bool, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
bool, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
bool, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
bool, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
bool, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
bool, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
bool, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
bool, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
bool, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, float16, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, float16, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, float16, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
float16, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16),
CastCPUKernel, float16, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32),
CastCPUKernel, float16, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64),
CastCPUKernel, float16, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8),
CastCPUKernel, float16, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, float16, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, float16, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, float16, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
float16, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, float, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, float, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, float, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
float, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16),
CastCPUKernel, float, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32),
CastCPUKernel, float, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64),
CastCPUKernel, float, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8),
CastCPUKernel, float, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, float, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, float, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, float, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
float, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, double, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, double, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, double, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
double, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16),
CastCPUKernel, double, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32),
CastCPUKernel, double, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64),
CastCPUKernel, double, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8),
CastCPUKernel, double, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16),
CastCPUKernel, double, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32),
CastCPUKernel, double, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64),
CastCPUKernel, double, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
double, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel,
int8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel,
int8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel,
int8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int8_t, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, int16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, int16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, int16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int16_t, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, int32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, int32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, int32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
int32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
int32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
int32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
int32_t, bool);

MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
CastCPUKernel, int64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
CastCPUKernel, int64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
CastCPUKernel, int64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
int64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
int64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
int64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
int64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
int64_t, uint8_t);
// Registration table for the Cast CPU kernel (tail of the table; earlier
// source-type groups sit above this chunk). Each MS_REG_CPU_KERNEL_T_S call
// registers CastCPUKernel for one (source dtype, destination dtype) pair; the
// framework selects the instantiation whose KernelAttr input/output type
// attributes match the Cast node being dispatched.

// int64_t source casts (remaining destinations: unsigned integers and bool).
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
                      int64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
                      int64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
                      int64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
                      int64_t, bool);

// uint8_t source casts: every float/int/uint width plus bool.
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16),
                      CastCPUKernel, uint8_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32),
                      CastCPUKernel, uint8_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64),
                      CastCPUKernel, uint8_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
                      uint8_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
                      uint8_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
                      uint8_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
                      uint8_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
                      uint8_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel,
                      uint8_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel,
                      uint8_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel,
                      uint8_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
                      uint8_t, bool);

// uint16_t source casts.
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16),
                      CastCPUKernel, uint16_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32),
                      CastCPUKernel, uint16_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64),
                      CastCPUKernel, uint16_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
                      uint16_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
                      uint16_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
                      uint16_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
                      uint16_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
                      uint16_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
                      CastCPUKernel, uint16_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32),
                      CastCPUKernel, uint16_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64),
                      CastCPUKernel, uint16_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
                      uint16_t, bool);

// uint32_t source casts.
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
                      CastCPUKernel, uint32_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
                      CastCPUKernel, uint32_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
                      CastCPUKernel, uint32_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
                      uint32_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
                      uint32_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
                      uint32_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
                      uint32_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
                      uint32_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16),
                      CastCPUKernel, uint32_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      CastCPUKernel, uint32_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64),
                      CastCPUKernel, uint32_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
                      uint32_t, bool);

// uint64_t source casts.
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16),
                      CastCPUKernel, uint64_t, float16);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32),
                      CastCPUKernel, uint64_t, float);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64),
                      CastCPUKernel, uint64_t, double);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel,
                      uint64_t, int8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel,
                      uint64_t, int16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel,
                      uint64_t, int32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel,
                      uint64_t, int64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel,
                      uint64_t, uint8_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16),
                      CastCPUKernel, uint64_t, uint16_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32),
                      CastCPUKernel, uint64_t, uint32_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
                      CastCPUKernel, uint64_t, uint64_t);
MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel,
                      uint64_t, bool);
}  // namespace kernel
}  // namespace mindspore



+ 38
- 15
mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_cpu_kernel.cc View File

@@ -14,8 +14,11 @@
* limitations under the License.
*/

#include <cmath>
#include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
@@ -72,23 +75,43 @@ void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con
auto y = reinterpret_cast<T *>(outputs[0]->addr);
auto mean = reinterpret_cast<T *>(outputs[1]->addr);
auto var = reinterpret_cast<T *>(outputs[2]->addr);
for (size_t i = 0; i < block_num_; ++i) {
T sum = (T)0.0;
T square_sum = (T)0.0;
for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
sum += x[j];
square_sum += x[j] * x[j];
}
T block_mean = sum / block_size_;
T block_var = square_sum / block_size_ - block_mean * block_mean;
for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
auto param_shift = j % param_num_;
y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast<double>(block_var) + eps_) * gamma[param_shift] +
beta[param_shift];
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (block_num_ < thread_num) {
thread_num = block_num_;
}
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
auto task = [&](size_t start, size_t end) {
for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num); ++c) {
if (c * thread_num + start >= block_num_) {
continue;
}
size_t i = c * thread_num + start;
T sum = (T)0.0;
T square_sum = (T)0.0;
for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
sum += x[j];
square_sum += x[j] * x[j];
}
T block_mean = sum / block_size_;
T block_var = square_sum / block_size_ - block_mean * block_mean;
for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
auto param_shift = j % param_num_;
y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast<double>(block_var) + eps_) * gamma[param_shift] +
beta[param_shift];
}
mean[i] = block_mean;
var[i] = block_var;
}
mean[i] = block_mean;
var[i] = block_var;
};
for (size_t i = 0; i < thread_num; ++i) {
auto block = [&, i]() {
task(i, i + 1);
return common::SUCCESS;
};
tasks.emplace_back(block);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
}

void LayerNormCPUKernel::CheckParam(const CNodePtr &kernel_node) {


+ 68
- 32
mindspore/ccsrc/backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.cc View File

@@ -15,7 +15,9 @@
*/

#include "backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.h"
#include "backend/kernel_compiler/common_utils.h"
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"

namespace mindspore {
namespace kernel {
@@ -73,41 +75,75 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs,
auto dx = reinterpret_cast<T *>(outputs[0]->addr);
auto dg = reinterpret_cast<T *>(outputs[1]->addr);
auto db = reinterpret_cast<T *>(outputs[2]->addr);

for (size_t i = 0; i < param_num_; ++i) {
T dgamma = (T)0.0;
T dbeta = (T)0.0;
for (size_t j = i; j < param_size_ * param_num_; j += param_num_) {
auto norm_shift = static_cast<int>(j / block_size_);
dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]);
dbeta += dy[j];
}
dg[i] = dgamma;
db[i] = dbeta;
}
for (size_t i = 0; i < block_num_; ++i) {
T sum1 = (T)0.0;
T sum2 = (T)0.0;
T sum3 = (T)0.0;
for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
auto param_shift = j % param_num_;
auto norm_shift = static_cast<int>(j / block_size_);
auto dxm = x[j] - mean[norm_shift];
auto dyg = dy[j] * gamma[param_shift];
sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -1.5);
sum2 += dyg;
sum3 += (T)(-2.0) * dxm;
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
auto thread_num1 = param_num_ < thread_num ? param_num_ : thread_num;
std::vector<common::Task> tasks1;
tasks1.reserve(thread_num1);
auto thread_num2 = block_num_ < thread_num ? block_num_ : thread_num;
std::vector<common::Task> tasks2;
tasks2.reserve(thread_num2);
auto task1 = [&](size_t start, size_t end) {
for (size_t c = 0; c < ceil(static_cast<double>(param_num_) / thread_num1); ++c) {
if (c * thread_num1 + start >= param_num_) {
continue;
}
size_t param_index = c * thread_num1 + start;
T dgamma = (T)0.0;
T dbeta = (T)0.0;
for (size_t j = param_index; j < param_size_ * param_num_; j += param_num_) {
auto norm_shift = static_cast<int>(j / block_size_);
dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]);
dbeta += dy[j];
}
dg[param_index] = dgamma;
db[param_index] = dbeta;
}
for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) {
auto param_shift = j % param_num_;
auto norm_shift = static_cast<int>(j / block_size_);
auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5);
auto dx1 = dy[j] * gamma[param_shift] * var_sqrt;
auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]);
auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_);
dx[j] = dx1 + dx2 + dx3;
};
auto task2 = [&](size_t start, size_t end) {
for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num2); ++c) {
if (c * thread_num2 + start >= block_num_) {
continue;
}
size_t block_index = c * thread_num2 + start;
T sum1 = (T)0.0;
T sum2 = (T)0.0;
T sum3 = (T)0.0;
for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) {
auto param_shift = j % param_num_;
auto norm_shift = static_cast<int>(j / block_size_);
auto dxm = x[j] - mean[norm_shift];
auto dyg = dy[j] * gamma[param_shift];
sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -1.5);
sum2 += dyg;
sum3 += (T)(-2.0) * dxm;
}
for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) {
auto param_shift = j % param_num_;
auto norm_shift = static_cast<int>(j / block_size_);
auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5);
auto dx1 = dy[j] * gamma[param_shift] * var_sqrt;
auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]);
auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_);
dx[j] = dx1 + dx2 + dx3;
}
}
};
for (size_t i = 0; i < thread_num1; ++i) {
auto block = [&, i]() {
task1(i, i + 1);
return common::SUCCESS;
};
tasks1.emplace_back(block);
}
common::ThreadPool::GetInstance().SyncRun(tasks1);
for (size_t i = 0; i < thread_num2; ++i) {
auto block = [&, i]() {
task2(i, i + 1);
return common::SUCCESS;
};
tasks2.emplace_back(block);
}
common::ThreadPool::GetInstance().SyncRun(tasks2);
}

void LayerNormGradCPUKernel::CheckParam(const CNodePtr &kernel_node) {


+ 29
- 8
mindspore/ccsrc/backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.cc View File

@@ -16,6 +16,7 @@

#include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h"
#include <string>
#include <cmath>
#include "runtime/device/cpu/cpu_device_address.h"
#include "common/thread_pool.h"

@@ -78,17 +79,37 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in
MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret;
return false;
}
for (size_t i = 0; i < unit_num_; ++i) {
size_t j = i / input_dim1_;
size_t k = i % input_dim1_;
size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum();
if (unit_num_ < thread_num) {
thread_num = unit_num_;
}
std::vector<common::Task> tasks;
tasks.reserve(thread_num);
auto task = [&](size_t start, size_t end) {
for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) {
if (c * thread_num + start >= unit_num_) {
continue;
}
size_t i = c * thread_num + start;
size_t j = i / input_dim1_;
size_t k = i % input_dim1_;

T index = indices_addr[j];
if (index < 0 || index >= SizeToInt(output_dim0_)) {
continue;
T index = indices_addr[j];
if (index < 0 || index >= SizeToInt(output_dim0_)) {
continue;
}
size_t output_index = index * output_dim1_ + k;
output_addr[output_index] += input_addr[i];
}
size_t output_index = index * output_dim1_ + k;
output_addr[output_index] += input_addr[i];
};
for (size_t t = 0; t < thread_num; ++t) {
auto block = [&, t]() {
task(t, t + 1);
return common::SUCCESS;
};
tasks.emplace_back(block);
}
common::ThreadPool::GetInstance().SyncRun(tasks);
return true;
}
} // namespace kernel


Loading…
Cancel
Save