| @@ -16,7 +16,6 @@ | |||
| #include "backend/kernel_compiler/cpu/adam_cpu_kernel.h" | |||
| #include <cmath> | |||
| #include <thread> | |||
| #include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "utils/ms_utils.h" | |||
| @@ -25,16 +24,19 @@ namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| void AdamCPUKernel::LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient, | |||
| size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| m[i] += (gradient[i] - m[i]) * (1 - beta1); | |||
| v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); | |||
| if (use_nesterov) { | |||
| var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon); | |||
| } else { | |||
| var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); | |||
| size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| m[i] += (gradient[i] - m[i]) * (1 - beta1); | |||
| v[i] += (gradient[i] * gradient[i] - v[i]) * (1 - beta2); | |||
| if (use_nesterov) { | |||
| var[i] -= lr * (m[i] * beta1 + (1 - beta1) * gradient[i]) / (std::sqrt(v[i]) + epsilon); | |||
| } else { | |||
| var[i] -= lr * m[i] / (std::sqrt(v[i]) + epsilon); | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| void AdamCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| @@ -84,31 +86,7 @@ bool AdamCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| // multithreading | |||
| size_t lens = inputs[0]->size > 0 ? static_cast<size_t>(inputs[0]->size / sizeof(float)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| std::vector<std::thread> threads; | |||
| if (thread_num < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||
| return false; | |||
| } | |||
| threads.reserve(thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||
| if (once_compute_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return false; | |||
| } | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| threads.emplace_back(std::thread(&AdamCPUKernel::LaunchAdam<float>, this, var, m, v, new_lr, beta1, beta2, epsilon, | |||
| gradient, start, end)); | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| } | |||
| LaunchAdam<float>(var, m, v, new_lr, beta1, beta2, epsilon, gradient, lens); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -29,7 +29,7 @@ class AdamCPUKernel : public CPUKernel { | |||
| ~AdamCPUKernel() override = default; | |||
| template <typename T> | |||
| void LaunchAdam(T *var, T *m, T *v, float lr, float beta1, float beta2, float epsilon, const T *gradient, | |||
| size_t start, size_t end); | |||
| size_t size); | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include <cmath> | |||
| #include <string> | |||
| #include <thread> | |||
| #include <map> | |||
| #include "backend/kernel_compiler/cpu/arithmetic_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| @@ -23,227 +22,285 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = input1[i] + input2[i]; | |||
| input1[i] = out[i]; | |||
| } | |||
| void ArithmeticCPUKernel::AssignAdd(T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = input1[i] + input2[i]; | |||
| input1[i] = out[i]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] + input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::Add(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] + input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] - input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::Sub(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] - input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] * input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::Mul(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] * input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto dividend = input1[idx[0]]; | |||
| auto divisor = input2[idx[1]]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| void ArithmeticCPUKernel::RealDiv(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto dividend = input1[idx[0]]; | |||
| auto divisor = input2[idx[1]]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| out[i] = dividend / divisor; | |||
| } | |||
| out[i] = dividend / divisor; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto dividend = input1[idx[0]]; | |||
| auto divisor = input2[idx[1]]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto dividend = input1[idx[0]]; | |||
| auto divisor = input2[idx[1]]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| out[i] = dividend / divisor; | |||
| } | |||
| out[i] = dividend / divisor; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto dividend = input1[idx[0]]; | |||
| auto divisor = input2[idx[1]]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto dividend = input1[idx[0]]; | |||
| auto divisor = input2[idx[1]]; | |||
| if (divisor == 0) { | |||
| if (dividend == 0) { | |||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| } | |||
| if (std::numeric_limits<T>::has_infinity) { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||
| } else { | |||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||
| } | |||
| continue; | |||
| out[i] = floor(dividend / divisor); | |||
| } | |||
| out[i] = floor(dividend / divisor); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto x = static_cast<double>(input1[idx[0]]); | |||
| auto y = static_cast<double>(input2[idx[1]]); | |||
| auto data_div = x / y; | |||
| auto data_div_min = data_div < 0.0 ? data_div : 0.0; | |||
| auto data_div_max = data_div > 0.0 ? data_div : 0.0; | |||
| auto data_div_max_floor = floor(data_div_max); | |||
| auto data_div_min_ceil = ceil(data_div_min); | |||
| auto data_div_res = data_div_max_floor + data_div_min_ceil; | |||
| out[i] = static_cast<T>(x - data_div_res * y); | |||
| } | |||
| void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto x = static_cast<double>(input1[idx[0]]); | |||
| auto y = static_cast<double>(input2[idx[1]]); | |||
| auto data_div = x / y; | |||
| auto data_div_min = data_div < 0.0 ? data_div : 0.0; | |||
| auto data_div_max = data_div > 0.0 ? data_div : 0.0; | |||
| auto data_div_max_floor = floor(data_div_max); | |||
| auto data_div_min_ceil = ceil(data_div_min); | |||
| auto data_div_res = data_div_max_floor + data_div_min_ceil; | |||
| out[i] = static_cast<T>(x - data_div_res * y); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto x = static_cast<double>(input1[idx[0]]); | |||
| auto y = static_cast<double>(input2[idx[1]]); | |||
| out[i] = static_cast<T>(std::pow(x, y)); | |||
| } | |||
| void ArithmeticCPUKernel::Pow(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| auto x = static_cast<double>(input1[idx[0]]); | |||
| auto y = static_cast<double>(input2[idx[1]]); | |||
| out[i] = static_cast<T>(std::pow(x, y)); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] < input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::Less(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] < input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] == input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::Equal(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] == input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] != input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::NotEqual(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] != input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] && input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::LogicalAnd(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] && input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] || input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::LogicalOr(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] || input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| T diff = input1[idx[0]] - input2[idx[1]]; | |||
| out[i] = diff * diff; | |||
| } | |||
| void ArithmeticCPUKernel::SquaredDifference(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| T diff = input1[idx[0]] - input2[idx[1]]; | |||
| out[i] = diff * diff; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] > input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::Greater(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] > input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] >= input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::GreaterEqual(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] >= input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] <= input2[idx[1]]; | |||
| } | |||
| void ArithmeticCPUKernel::LessEqual(const T *input1, const T *input2, bool *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = input1[idx[0]] <= input2[idx[1]]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = atan2(input1[idx[0]], input2[idx[1]]); | |||
| } | |||
| void ArithmeticCPUKernel::Atan2(const T *input1, const T *input2, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| std::vector<size_t> idx; | |||
| GenIndex(i, &idx); | |||
| out[i] = atan2(input1[idx[0]], input2[idx[1]]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| static const std::map<std::string, OperateType> kArithmeticBinOpTypeMap = { | |||
| {prim::kPrimGreater->name(), GREATER}, | |||
| {prim::kPrimAdd->name(), ADD}, | |||
| @@ -352,49 +409,25 @@ void ArithmeticCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &input | |||
| T *input1 = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *input2 = reinterpret_cast<T *>(inputs[1]->addr); | |||
| bool *output = reinterpret_cast<bool *>(outputs[0]->addr); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(bool)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| std::vector<std::thread> threads; | |||
| if (thread_num < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||
| return; | |||
| } | |||
| threads.reserve(thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||
| if (once_compute_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return; | |||
| } | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| if (operate_type_ == LESS) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Less<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == EQUAL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Equal<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == NOTEQUAL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::NotEqual<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == GREATER) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Greater<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == GREATEREQUAL) { | |||
| threads.emplace_back( | |||
| std::thread(&ArithmeticCPUKernel::GreaterEqual<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == LESSEQUAL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::LessEqual<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == LOGICALAND) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalAnd<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == LOGICALOR) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::LogicalOr<T>, this, input1, input2, output, start, end)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| if (operate_type_ == LESS) { | |||
| Less<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == EQUAL) { | |||
| Equal<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == NOTEQUAL) { | |||
| NotEqual<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == GREATER) { | |||
| Greater<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == GREATEREQUAL) { | |||
| GreaterEqual<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == LESSEQUAL) { | |||
| LessEqual<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == LOGICALAND) { | |||
| LogicalAnd<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == LOGICALOR) { | |||
| LogicalOr<T>(input1, input2, output, lens); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| } | |||
| @@ -409,53 +442,30 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co | |||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| std::vector<std::thread> threads; | |||
| if (thread_num < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||
| return; | |||
| } | |||
| threads.reserve(thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||
| if (once_compute_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return; | |||
| } | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| if (operate_type_ == ADD) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Add<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == SUB) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Sub<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == MUL) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mul<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == REALDIV) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::RealDiv<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == DIV) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == FLOORDIV) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::FloorDiv<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == MOD) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mod<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == POW) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Pow<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == ASSIGNADD) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::AssignAdd<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == ATAN2) { | |||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Atan2<T>, this, input1, input2, output, start, end)); | |||
| } else if (operate_type_ == SQUAREDDIFFERENCE) { | |||
| threads.emplace_back( | |||
| std::thread(&ArithmeticCPUKernel::SquaredDifference<T>, this, input1, input2, output, start, end)); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| if (operate_type_ == ADD) { | |||
| Add<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == SUB) { | |||
| Sub<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == MUL) { | |||
| Mul<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == REALDIV) { | |||
| RealDiv<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == DIV) { | |||
| Div<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == FLOORDIV) { | |||
| FloorDiv<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == MOD) { | |||
| Mod<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == POW) { | |||
| Pow<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == ASSIGNADD) { | |||
| AssignAdd<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == ATAN2) { | |||
| Atan2<T>(input1, input2, output, lens); | |||
| } else if (operate_type_ == SQUAREDDIFFERENCE) { | |||
| SquaredDifference<T>(input1, input2, output, lens); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| @@ -40,43 +40,43 @@ class ArithmeticCPUKernel : public CPUKernel { | |||
| private: | |||
| void GenIndex(size_t num, std::vector<size_t> *tmp); | |||
| template <typename T> | |||
| void Sub(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Sub(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Add(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Add(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Mul(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Mul(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void RealDiv(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void RealDiv(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Div(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void FloorDiv(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Mod(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Pow(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void AssignAdd(T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void AssignAdd(T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Atan2(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void Atan2(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Less(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void Less(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void Equal(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void Equal(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void NotEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void NotEqual(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void SquaredDifference(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||
| void SquaredDifference(const T *input1, const T *input2, T *out, size_t size); | |||
| template <typename T> | |||
| void Greater(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void Greater(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void GreaterEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void GreaterEqual(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void LessEqual(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void LessEqual(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void LogicalAnd(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void LogicalAnd(const T *input1, const T *input2, bool *out, size_t size); | |||
| template <typename T> | |||
| void LogicalOr(const T *input1, const T *input2, bool *out, size_t start, size_t end); | |||
| void LogicalOr(const T *input1, const T *input2, bool *out, size_t size); | |||
| std::vector<size_t> input_shape0_; | |||
| std::vector<size_t> input_shape1_; | |||
| std::vector<size_t> input_element_num0_; | |||
| @@ -24,152 +24,212 @@ namespace mindspore { | |||
| namespace kernel { | |||
| namespace { | |||
| template <typename T> | |||
| void Square(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = in[i] * in[i]; | |||
| } | |||
| void Square(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = in[i] * in[i]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Sign(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| if (in[i] < 0) { | |||
| out[i] = -1; | |||
| } else if (in[i] > 0) { | |||
| out[i] = 1; | |||
| } else { | |||
| out[i] = 0; | |||
| void Sign(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| if (in[i] < 0) { | |||
| out[i] = -1; | |||
| } else if (in[i] > 0) { | |||
| out[i] = 1; | |||
| } else { | |||
| out[i] = 0; | |||
| } | |||
| } | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Neg(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = -in[i]; | |||
| } | |||
| void Neg(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = -in[i]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void LogicalNot(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = !in[i]; | |||
| } | |||
| void LogicalNot(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = !in[i]; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void OnesLike(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(1); | |||
| } | |||
| void OnesLike(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(1); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ZerosLike(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(0); | |||
| } | |||
| void ZerosLike(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(0); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Floor(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(floor(in[i])); | |||
| } | |||
| void Floor(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(floor(in[i])); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Reciprocal(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(1.0 / in[i]); | |||
| } | |||
| void Reciprocal(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(1.0 / in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Gelu(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| T x = in[i]; | |||
| auto double_x = static_cast<T>(x); | |||
| T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x)); | |||
| out[i] = x * ((T)1.0 + tanh_res) / (T)2.0; | |||
| } | |||
| void Gelu(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| T x = in[i]; | |||
| auto double_x = static_cast<T>(x); | |||
| T tanh_res = (T)std::tanh(0.7978845608 * (double_x + 0.044715 * double_x * double_x * double_x)); | |||
| out[i] = x * ((T)1.0 + tanh_res) / (T)2.0; | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Asin(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = asin(in[i]); | |||
| } | |||
| void Asin(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = asin(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void ACos(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = acos(in[i]); | |||
| } | |||
| void ACos(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = acos(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Atan(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = atan(in[i]); | |||
| } | |||
| void Atan(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = atan(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Sin(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = sin(in[i]); | |||
| } | |||
| void Sin(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = sin(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Cos(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = cos(in[i]); | |||
| } | |||
| void Cos(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = cos(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Tan(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = tan(in[i]); | |||
| } | |||
| void Tan(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = tan(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Sinh(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = sinh(in[i]); | |||
| } | |||
| void Sinh(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = sinh(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Cosh(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = cosh(in[i]); | |||
| } | |||
| void Cosh(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = cosh(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Asinh(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = asinh(in[i]); | |||
| } | |||
| void Asinh(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = asinh(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Acosh(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = acosh(in[i]); | |||
| } | |||
| void Acosh(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = acosh(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| template <typename T> | |||
| void Atanh(const T *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = atanh(in[i]); | |||
| } | |||
| void Atanh(const T *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = atanh(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
| } // namespace | |||
| @@ -223,79 +283,31 @@ void ArithmeticSelfCPUKernel::LaunchKernelLogic(const std::vector<AddressPtr> &i | |||
| T *input = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| std::vector<std::thread> threads; | |||
| if (thread_num < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||
| return; | |||
| } | |||
| threads.reserve(thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||
| if (once_compute_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return; | |||
| } | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| if (operate_type_ == LOGICALNOT) { | |||
| threads.emplace_back(std::thread(LogicalNot<T>, input, output, start, end)); | |||
| } | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| } | |||
| LogicalNot<T>(input, output, lens); | |||
| return; | |||
| } | |||
| template <typename T> | |||
| void ArithmeticSelfCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| const std::vector<AddressPtr> &outputs) { | |||
| if (target_dtype_ == kNumberTypeBool) { | |||
| LaunchKernelLogic<T>(inputs, outputs); | |||
| return; | |||
| } | |||
| T *input = reinterpret_cast<T *>(inputs[0]->addr); | |||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| auto max_thread_num = std::thread::hardware_concurrency(); | |||
| size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num; | |||
| MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num; | |||
| std::vector<std::thread> threads; | |||
| if (thread_num < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num; | |||
| return; | |||
| } | |||
| threads.reserve(thread_num); | |||
| size_t start = 0; | |||
| size_t once_compute_size = (lens + thread_num - 1) / thread_num; | |||
| if (once_compute_size < 1) { | |||
| MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size; | |||
| return; | |||
| } | |||
| static const std::map<OperateType, std::function<void(const T *in, T *out, size_t start, size_t end)>> | |||
| kArithmeticOpFuncMap = {{SQUARE, Square<T>}, {SIGN, Sign<T>}, | |||
| {NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>}, | |||
| {ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>}, | |||
| {FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>}, | |||
| {GELU, Gelu<T>}, {SIN, Sin<T>}, | |||
| {COS, Cos<T>}, {TAN, Tan<T>}, | |||
| {ASIN, Asin<T>}, {ACOS, ACos<T>}, | |||
| {ATAN, Atan<T>}, {SINH, Sinh<T>}, | |||
| {COSH, Cosh<T>}, {ASINH, Asinh<T>}, | |||
| {ACOSH, Acosh<T>}, {ATANH, Atanh<T>}}; | |||
| while (start < lens) { | |||
| size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size); | |||
| threads.emplace_back(std::thread(kArithmeticOpFuncMap.at(operate_type_), input, output, start, end)); | |||
| start += once_compute_size; | |||
| } | |||
| for (size_t i = 0; i < threads.size(); ++i) { | |||
| threads[i].join(); | |||
| static const std::map<OperateType, std::function<void(const T *in, T *out, size_t size)>> kArithmeticOpFuncMap = { | |||
| {SQUARE, Square<T>}, {SIGN, Sign<T>}, | |||
| {NEG, Neg<T>}, {LOGICALNOT, LogicalNot<T>}, | |||
| {ONESLIKE, OnesLike<T>}, {ZEROSLIKE, ZerosLike<T>}, | |||
| {FLOOR, Floor<T>}, {RECIPROCAL, Reciprocal<T>}, | |||
| {GELU, Gelu<T>}, {SIN, Sin<T>}, | |||
| {COS, Cos<T>}, {TAN, Tan<T>}, | |||
| {ASIN, Asin<T>}, {ACOS, ACos<T>}, | |||
| {ATAN, Atan<T>}, {SINH, Sinh<T>}, | |||
| {COSH, Cosh<T>}, {ASINH, Asinh<T>}, | |||
| {ACOSH, Acosh<T>}, {ATANH, Atanh<T>}}; | |||
| if (kArithmeticOpFuncMap.find(operate_type_) != kArithmeticOpFuncMap.end()) { | |||
| kArithmeticOpFuncMap.at(operate_type_)(input, output, lens); | |||
| } else { | |||
| MS_LOG(EXCEPTION) << "Not support " << operate_type_; | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| @@ -34,7 +34,6 @@ class ArithmeticSelfCPUKernel : public CPUKernel { | |||
| template <typename T> | |||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| template <typename T> | |||
| void LaunchKernelLogic(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| @@ -16,220 +16,38 @@ | |||
| #include <cmath> | |||
| #include <map> | |||
| #include <string> | |||
| #include <thread> | |||
| #include "backend/kernel_compiler/cpu/cast_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename S, typename T> | |||
| void Cast(const S *in, T *out, size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(in[i]); | |||
| } | |||
| void Cast(const S *in, T *out, size_t size) { | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t i = start; i < end; i++) { | |||
| out[i] = static_cast<T>(in[i]); | |||
| } | |||
| }; | |||
| CPUKernelUtils::ParallelFor(task, size); | |||
| } | |||
// Pre-refactor cast launcher: manually shards [0, lens) across std::thread workers and
// runs the old Cast(in, out, start, end) overload on each shard.
// NOTE(review): superseded — the templated CastCPUKernel<S, T>::Launch in this diff calls
// Cast<S, T>(input, output, lens) directly, so this function is dead once that lands.
template <typename S, typename T>
void LaunchCast(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &outputs) {
  S *input = reinterpret_cast<S *>(inputs[0]->addr);
  T *output = reinterpret_cast<T *>(outputs[0]->addr);
  MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name();
  // Element count from the output buffer; fall back to 1 for a zero-sized tensor.
  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
  auto max_thread_num = std::thread::hardware_concurrency();
  // Heuristic: roughly 128 elements per thread, capped at the hardware thread count.
  size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
  MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;
  std::vector<std::thread> threads;
  if (thread_num < 1) {
    MS_LOG(ERROR) << "Invalid value: thread_num " << thread_num;
    return;
  }
  threads.reserve(thread_num);
  size_t start = 0;
  // Ceiling division so all lens elements are covered by the shards.
  size_t once_compute_size = (lens + thread_num - 1) / thread_num;
  if (once_compute_size < 1) {
    MS_LOG(ERROR) << "Invalid value: once_compute_size " << once_compute_size;
    return;
  }
  while (start < lens) {
    size_t end = (start + once_compute_size) > lens ? lens : (start + once_compute_size);
    threads.emplace_back(std::thread(Cast<S, T>, input, output, start, end));
    start += once_compute_size;
  }
  for (size_t i = 0; i < threads.size(); ++i) {
    threads[i].join();
  }
}
| void CastCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| void CastCPUKernel<S, T>::InitKernel(const CNodePtr &kernel_node) { | |||
| MS_EXCEPTION_IF_NULL(kernel_node); | |||
| source_dtype = AnfAlgo::GetPrevNodeOutputDeviceDataType(kernel_node, 0); | |||
| target_dtype = AnfAlgo::GetOutputInferDataType(kernel_node, 0); | |||
| } | |||
| bool CastCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| using TypePair = | |||
| std::function<void(const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>; | |||
| std::map<TypeId, std::map<TypeId, TypePair>> mode_map; | |||
| mode_map[kNumberTypeBool][kNumberTypeFloat16] = LaunchCast<bool, float16>; | |||
| mode_map[kNumberTypeBool][kNumberTypeFloat32] = LaunchCast<bool, float>; | |||
| mode_map[kNumberTypeBool][kNumberTypeFloat64] = LaunchCast<bool, double>; | |||
| mode_map[kNumberTypeBool][kNumberTypeInt8] = LaunchCast<bool, int8_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeInt16] = LaunchCast<bool, int16_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeInt32] = LaunchCast<bool, int32_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeInt64] = LaunchCast<bool, int64_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeUInt8] = LaunchCast<bool, uint8_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeUInt16] = LaunchCast<bool, uint16_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeUInt32] = LaunchCast<bool, uint32_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeUInt64] = LaunchCast<bool, uint64_t>; | |||
| mode_map[kNumberTypeBool][kNumberTypeBool] = LaunchCast<bool, bool>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeFloat16] = LaunchCast<float16, float16>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeFloat32] = LaunchCast<float16, float>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeFloat64] = LaunchCast<float16, double>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeInt8] = LaunchCast<float16, int8_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeInt16] = LaunchCast<float16, int16_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeInt32] = LaunchCast<float16, int32_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeInt64] = LaunchCast<float16, int64_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeUInt8] = LaunchCast<float16, uint8_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeUInt16] = LaunchCast<float16, uint16_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeUInt32] = LaunchCast<float16, uint32_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeUInt64] = LaunchCast<float16, uint64_t>; | |||
| mode_map[kNumberTypeFloat16][kNumberTypeBool] = LaunchCast<float16, bool>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeFloat16] = LaunchCast<float, float16>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeFloat32] = LaunchCast<float, float>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeFloat64] = LaunchCast<float, double>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeInt8] = LaunchCast<float, int8_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeInt16] = LaunchCast<float, int16_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeInt32] = LaunchCast<float, int32_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeInt64] = LaunchCast<float, int64_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeUInt8] = LaunchCast<float, uint8_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeUInt16] = LaunchCast<float, uint16_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeUInt32] = LaunchCast<float, uint32_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeUInt64] = LaunchCast<float, uint64_t>; | |||
| mode_map[kNumberTypeFloat32][kNumberTypeBool] = LaunchCast<float, bool>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeFloat16] = LaunchCast<double, float16>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeFloat32] = LaunchCast<double, float>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeFloat64] = LaunchCast<double, double>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeInt8] = LaunchCast<double, int8_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeInt16] = LaunchCast<double, int16_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeInt32] = LaunchCast<double, int32_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeInt64] = LaunchCast<double, int64_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeUInt8] = LaunchCast<double, uint8_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeUInt16] = LaunchCast<double, uint16_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeUInt32] = LaunchCast<double, uint32_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeUInt64] = LaunchCast<double, uint64_t>; | |||
| mode_map[kNumberTypeFloat64][kNumberTypeBool] = LaunchCast<double, bool>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeFloat16] = LaunchCast<int8_t, float16>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeFloat32] = LaunchCast<int8_t, float>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeFloat64] = LaunchCast<int8_t, double>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeInt8] = LaunchCast<int8_t, int8_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeInt16] = LaunchCast<int8_t, int16_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeInt32] = LaunchCast<int8_t, int32_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeInt64] = LaunchCast<int8_t, int64_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeUInt8] = LaunchCast<int8_t, uint8_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeUInt16] = LaunchCast<int8_t, uint16_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeUInt32] = LaunchCast<int8_t, uint32_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeUInt64] = LaunchCast<int8_t, uint64_t>; | |||
| mode_map[kNumberTypeInt8][kNumberTypeBool] = LaunchCast<int8_t, bool>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeFloat16] = LaunchCast<int16_t, float16>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeFloat32] = LaunchCast<int16_t, float>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeFloat64] = LaunchCast<int16_t, double>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeInt8] = LaunchCast<int16_t, int8_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeInt16] = LaunchCast<int16_t, int16_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeInt32] = LaunchCast<int16_t, int32_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeInt64] = LaunchCast<int16_t, int64_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeUInt8] = LaunchCast<int16_t, uint8_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeUInt16] = LaunchCast<int16_t, uint16_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeUInt32] = LaunchCast<int16_t, uint32_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeUInt64] = LaunchCast<int16_t, uint64_t>; | |||
| mode_map[kNumberTypeInt16][kNumberTypeBool] = LaunchCast<int16_t, bool>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeFloat16] = LaunchCast<int32_t, float16>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeFloat32] = LaunchCast<int32_t, float>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeFloat64] = LaunchCast<int32_t, double>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeInt8] = LaunchCast<int32_t, int8_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeInt16] = LaunchCast<int32_t, int16_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeInt32] = LaunchCast<int32_t, int32_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeInt64] = LaunchCast<int32_t, int64_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeUInt8] = LaunchCast<int32_t, uint8_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeUInt16] = LaunchCast<int32_t, uint16_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeUInt32] = LaunchCast<int32_t, uint32_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeUInt64] = LaunchCast<int32_t, uint64_t>; | |||
| mode_map[kNumberTypeInt32][kNumberTypeBool] = LaunchCast<int32_t, bool>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeFloat16] = LaunchCast<int64_t, float16>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeFloat32] = LaunchCast<int64_t, float>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeFloat64] = LaunchCast<int64_t, double>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeInt8] = LaunchCast<int64_t, int8_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeInt16] = LaunchCast<int64_t, int16_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeInt32] = LaunchCast<int64_t, int32_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeInt64] = LaunchCast<int64_t, int64_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeUInt8] = LaunchCast<int64_t, uint8_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeUInt16] = LaunchCast<int64_t, uint16_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeUInt32] = LaunchCast<int64_t, uint32_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeUInt64] = LaunchCast<int64_t, uint64_t>; | |||
| mode_map[kNumberTypeInt64][kNumberTypeBool] = LaunchCast<int64_t, bool>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeFloat16] = LaunchCast<uint8_t, float16>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeFloat32] = LaunchCast<uint8_t, float>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeFloat64] = LaunchCast<uint8_t, double>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeInt8] = LaunchCast<uint8_t, int8_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeInt16] = LaunchCast<uint8_t, int16_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeInt32] = LaunchCast<uint8_t, int32_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeInt64] = LaunchCast<uint8_t, int64_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeUInt8] = LaunchCast<uint8_t, uint8_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeUInt16] = LaunchCast<uint8_t, uint16_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeUInt32] = LaunchCast<uint8_t, uint32_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeUInt64] = LaunchCast<uint8_t, uint64_t>; | |||
| mode_map[kNumberTypeUInt8][kNumberTypeBool] = LaunchCast<uint8_t, bool>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeFloat16] = LaunchCast<uint16_t, float16>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeFloat32] = LaunchCast<uint16_t, float>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeFloat64] = LaunchCast<uint16_t, double>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeInt8] = LaunchCast<uint16_t, int8_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeInt16] = LaunchCast<uint16_t, int16_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeInt32] = LaunchCast<uint16_t, int32_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeInt64] = LaunchCast<uint16_t, int64_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeUInt8] = LaunchCast<uint16_t, uint8_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeUInt16] = LaunchCast<uint16_t, uint16_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeUInt32] = LaunchCast<uint16_t, uint32_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeUInt64] = LaunchCast<uint16_t, uint64_t>; | |||
| mode_map[kNumberTypeUInt16][kNumberTypeBool] = LaunchCast<uint16_t, bool>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeFloat16] = LaunchCast<uint32_t, float16>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeFloat32] = LaunchCast<uint32_t, float>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeFloat64] = LaunchCast<uint32_t, double>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeInt8] = LaunchCast<uint32_t, int8_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeInt16] = LaunchCast<uint32_t, int16_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeInt32] = LaunchCast<uint32_t, int32_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeInt64] = LaunchCast<uint32_t, int64_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeUInt8] = LaunchCast<uint32_t, uint8_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeUInt16] = LaunchCast<uint32_t, uint16_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeUInt32] = LaunchCast<uint32_t, uint32_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeUInt64] = LaunchCast<uint32_t, uint64_t>; | |||
| mode_map[kNumberTypeUInt32][kNumberTypeBool] = LaunchCast<uint32_t, bool>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeFloat16] = LaunchCast<uint64_t, float16>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeFloat32] = LaunchCast<uint64_t, float>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeFloat64] = LaunchCast<uint64_t, double>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeInt8] = LaunchCast<uint64_t, int8_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeInt16] = LaunchCast<uint64_t, int16_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeInt32] = LaunchCast<uint64_t, int32_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeInt64] = LaunchCast<uint64_t, int64_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeUInt8] = LaunchCast<uint64_t, uint8_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeUInt16] = LaunchCast<uint64_t, uint16_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeUInt32] = LaunchCast<uint64_t, uint32_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeUInt64] = LaunchCast<uint64_t, uint64_t>; | |||
| mode_map[kNumberTypeUInt64][kNumberTypeBool] = LaunchCast<uint64_t, bool>; | |||
| template <typename S, typename T> | |||
| bool CastCPUKernel<S, T>::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| S *input = reinterpret_cast<S *>(inputs[0]->addr); | |||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | |||
| MS_LOG(DEBUG) << "Type source: " << typeid(S).name() << "; target: " << typeid(T).name(); | |||
| mode_map[source_dtype][target_dtype](inputs, outputs); | |||
| size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1; | |||
| Cast<S, T>(input, output, lens); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||
| @@ -23,6 +23,7 @@ | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| template <typename S, typename T> | |||
| class CastCPUKernel : public CPUKernel { | |||
| public: | |||
| CastCPUKernel() = default; | |||
| @@ -38,161 +39,305 @@ class CastCPUKernel : public CPUKernel { | |||
| TypeId target_dtype{kTypeUnknown}; | |||
| }; | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel, | |||
| bool, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel, | |||
| bool, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel, | |||
| bool, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| bool, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| bool, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| bool, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| bool, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| bool, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, | |||
| bool, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, | |||
| bool, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, | |||
| bool, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| bool, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, float16, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, float16, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, float16, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| float16, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt16), | |||
| CastCPUKernel, float16, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt32), | |||
| CastCPUKernel, float16, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeInt64), | |||
| CastCPUKernel, float16, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), | |||
| CastCPUKernel, float16, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt16), | |||
| CastCPUKernel, float16, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt32), | |||
| CastCPUKernel, float16, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt64), | |||
| CastCPUKernel, float16, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| float16, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, float, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, float, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, float, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| float, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt16), | |||
| CastCPUKernel, float, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt32), | |||
| CastCPUKernel, float, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeInt64), | |||
| CastCPUKernel, float, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), | |||
| CastCPUKernel, float, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt16), | |||
| CastCPUKernel, float, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt32), | |||
| CastCPUKernel, float, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt64), | |||
| CastCPUKernel, float, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| float, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, double, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, double, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, double, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| double, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt16), | |||
| CastCPUKernel, double, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt32), | |||
| CastCPUKernel, double, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeInt64), | |||
| CastCPUKernel, double, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), | |||
| CastCPUKernel, double, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt16), | |||
| CastCPUKernel, double, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt32), | |||
| CastCPUKernel, double, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt64), | |||
| CastCPUKernel, double, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| double, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16), CastCPUKernel, | |||
| int8_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32), CastCPUKernel, | |||
| int8_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64), CastCPUKernel, | |||
| int8_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| int8_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| int8_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| int8_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| int8_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| int8_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, | |||
| int8_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, | |||
| int8_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, | |||
| int8_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| int8_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, int16_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, int16_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, int16_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| int16_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| int16_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| int16_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| int16_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| int16_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, | |||
| int16_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, | |||
| int16_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, | |||
| int16_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| int16_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, int32_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, int32_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, int32_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| int32_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| int32_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| int32_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| int32_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| int32_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, | |||
| int32_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, | |||
| int32_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, | |||
| int32_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| int32_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, int64_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, int64_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, int64_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| int64_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| int64_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| int64_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| int64_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| int64_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, | |||
| int64_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, | |||
| int64_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, | |||
| int64_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| int64_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, uint8_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, uint8_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, uint8_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| uint8_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| uint8_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| uint8_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| uint8_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| uint8_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt16), CastCPUKernel, | |||
| uint8_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt32), CastCPUKernel, | |||
| uint8_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt64), CastCPUKernel, | |||
| uint8_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| uint8_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, uint16_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, uint16_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, uint16_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| uint16_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| uint16_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| uint16_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| uint16_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| uint16_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), | |||
| CastCPUKernel, uint16_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt32), | |||
| CastCPUKernel, uint16_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt64), | |||
| CastCPUKernel, uint16_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| uint16_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, uint32_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, uint32_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, uint32_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| uint32_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| uint32_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| uint32_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| uint32_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| uint32_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt16), | |||
| CastCPUKernel, uint32_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), | |||
| CastCPUKernel, uint32_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt64), | |||
| CastCPUKernel, uint32_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| uint32_t, bool); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16), | |||
| CastCPUKernel, uint64_t, float16); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32), | |||
| CastCPUKernel, uint64_t, float); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64), | |||
| CastCPUKernel, uint64_t, double); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt8), CastCPUKernel, | |||
| uint64_t, int8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt16), CastCPUKernel, | |||
| uint64_t, int16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt32), CastCPUKernel, | |||
| uint64_t, int32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeInt64), CastCPUKernel, | |||
| uint64_t, int64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt8), CastCPUKernel, | |||
| uint64_t, uint8_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt16), | |||
| CastCPUKernel, uint64_t, uint16_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt32), | |||
| CastCPUKernel, uint64_t, uint32_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), | |||
| CastCPUKernel, uint64_t, uint64_t); | |||
| MS_REG_CPU_KERNEL_T_S(Cast, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool), CastCPUKernel, | |||
| uint64_t, bool); | |||
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -14,8 +14,11 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include <cmath> | |||
| #include "backend/kernel_compiler/cpu/layer_norm_cpu_kernel.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "common/thread_pool.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -72,23 +75,43 @@ void LayerNormCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, con | |||
| auto y = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto mean = reinterpret_cast<T *>(outputs[1]->addr); | |||
| auto var = reinterpret_cast<T *>(outputs[2]->addr); | |||
| for (size_t i = 0; i < block_num_; ++i) { | |||
| T sum = (T)0.0; | |||
| T square_sum = (T)0.0; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| sum += x[j]; | |||
| square_sum += x[j] * x[j]; | |||
| } | |||
| T block_mean = sum / block_size_; | |||
| T block_var = square_sum / block_size_ - block_mean * block_mean; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast<double>(block_var) + eps_) * gamma[param_shift] + | |||
| beta[param_shift]; | |||
| size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| if (block_num_ < thread_num) { | |||
| thread_num = block_num_; | |||
| } | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num); ++c) { | |||
| if (c * thread_num + start >= block_num_) { | |||
| continue; | |||
| } | |||
| size_t i = c * thread_num + start; | |||
| T sum = (T)0.0; | |||
| T square_sum = (T)0.0; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| sum += x[j]; | |||
| square_sum += x[j] * x[j]; | |||
| } | |||
| T block_mean = sum / block_size_; | |||
| T block_var = square_sum / block_size_ - block_mean * block_mean; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| y[j] = (x[j] - block_mean) / (T)std::sqrt(static_cast<double>(block_var) + eps_) * gamma[param_shift] + | |||
| beta[param_shift]; | |||
| } | |||
| mean[i] = block_mean; | |||
| var[i] = block_var; | |||
| } | |||
| mean[i] = block_mean; | |||
| var[i] = block_var; | |||
| }; | |||
| for (size_t i = 0; i < thread_num; ++i) { | |||
| auto block = [&, i]() { | |||
| task(i, i + 1); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(block); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| } | |||
| void LayerNormCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| @@ -15,7 +15,9 @@ | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/layer_norm_grad_cpu_kernel.h" | |||
| #include "backend/kernel_compiler/common_utils.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "common/thread_pool.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| @@ -73,41 +75,75 @@ void LayerNormGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||
| auto dx = reinterpret_cast<T *>(outputs[0]->addr); | |||
| auto dg = reinterpret_cast<T *>(outputs[1]->addr); | |||
| auto db = reinterpret_cast<T *>(outputs[2]->addr); | |||
| for (size_t i = 0; i < param_num_; ++i) { | |||
| T dgamma = (T)0.0; | |||
| T dbeta = (T)0.0; | |||
| for (size_t j = i; j < param_size_ * param_num_; j += param_num_) { | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]); | |||
| dbeta += dy[j]; | |||
| } | |||
| dg[i] = dgamma; | |||
| db[i] = dbeta; | |||
| } | |||
| for (size_t i = 0; i < block_num_; ++i) { | |||
| T sum1 = (T)0.0; | |||
| T sum2 = (T)0.0; | |||
| T sum3 = (T)0.0; | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| auto dxm = x[j] - mean[norm_shift]; | |||
| auto dyg = dy[j] * gamma[param_shift]; | |||
| sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -1.5); | |||
| sum2 += dyg; | |||
| sum3 += (T)(-2.0) * dxm; | |||
| size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| auto thread_num1 = param_num_ < thread_num ? param_num_ : thread_num; | |||
| std::vector<common::Task> tasks1; | |||
| tasks1.reserve(thread_num1); | |||
| auto thread_num2 = block_num_ < thread_num ? block_num_ : thread_num; | |||
| std::vector<common::Task> tasks2; | |||
| tasks2.reserve(thread_num2); | |||
| auto task1 = [&](size_t start, size_t end) { | |||
| for (size_t c = 0; c < ceil(static_cast<double>(param_num_) / thread_num1); ++c) { | |||
| if (c * thread_num1 + start >= param_num_) { | |||
| continue; | |||
| } | |||
| size_t param_index = c * thread_num1 + start; | |||
| T dgamma = (T)0.0; | |||
| T dbeta = (T)0.0; | |||
| for (size_t j = param_index; j < param_size_ * param_num_; j += param_num_) { | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| dgamma += dy[j] * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5) * (x[j] - mean[norm_shift]); | |||
| dbeta += dy[j]; | |||
| } | |||
| dg[param_index] = dgamma; | |||
| db[param_index] = dbeta; | |||
| } | |||
| for (size_t j = i * block_size_; j < (i + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5); | |||
| auto dx1 = dy[j] * gamma[param_shift] * var_sqrt; | |||
| auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]); | |||
| auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_); | |||
| dx[j] = dx1 + dx2 + dx3; | |||
| }; | |||
| auto task2 = [&](size_t start, size_t end) { | |||
| for (size_t c = 0; c < ceil(static_cast<double>(block_num_) / thread_num2); ++c) { | |||
| if (c * thread_num2 + start >= block_num_) { | |||
| continue; | |||
| } | |||
| size_t block_index = c * thread_num2 + start; | |||
| T sum1 = (T)0.0; | |||
| T sum2 = (T)0.0; | |||
| T sum3 = (T)0.0; | |||
| for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| auto dxm = x[j] - mean[norm_shift]; | |||
| auto dyg = dy[j] * gamma[param_shift]; | |||
| sum1 += (T)(-0.5) * dyg * dxm * (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -1.5); | |||
| sum2 += dyg; | |||
| sum3 += (T)(-2.0) * dxm; | |||
| } | |||
| for (size_t j = block_index * block_size_; j < (block_index + 1) * block_size_; ++j) { | |||
| auto param_shift = j % param_num_; | |||
| auto norm_shift = static_cast<int>(j / block_size_); | |||
| auto var_sqrt = (T)std::pow(static_cast<double>(var[norm_shift]) + eps_, -0.5); | |||
| auto dx1 = dy[j] * gamma[param_shift] * var_sqrt; | |||
| auto dx2 = sum1 * (T)2.0 / block_size_ * (x[j] - mean[norm_shift]); | |||
| auto dx3 = ((T)(-1.0) * var_sqrt * sum2 + ((T)1.0 / block_size_) * sum1 * sum3) * ((T)1.0 / block_size_); | |||
| dx[j] = dx1 + dx2 + dx3; | |||
| } | |||
| } | |||
| }; | |||
| for (size_t i = 0; i < thread_num1; ++i) { | |||
| auto block = [&, i]() { | |||
| task1(i, i + 1); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks1.emplace_back(block); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks1); | |||
| for (size_t i = 0; i < thread_num2; ++i) { | |||
| auto block = [&, i]() { | |||
| task2(i, i + 1); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks2.emplace_back(block); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks2); | |||
| } | |||
| void LayerNormGradCPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| @@ -16,6 +16,7 @@ | |||
| #include "backend/kernel_compiler/cpu/unsorted_segment_sum_cpu_kernel.h" | |||
| #include <string> | |||
| #include <cmath> | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
| #include "common/thread_pool.h" | |||
| @@ -78,17 +79,37 @@ bool UnsortedSegmentSumCPUKernel::LaunchKernel(const std::vector<AddressPtr> &in | |||
| MS_LOG(ERROR) << "Output buff memset fail. ret:" << ret; | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < unit_num_; ++i) { | |||
| size_t j = i / input_dim1_; | |||
| size_t k = i % input_dim1_; | |||
| size_t thread_num = common::ThreadPool::GetInstance().GetSyncRunThreadNum(); | |||
| if (unit_num_ < thread_num) { | |||
| thread_num = unit_num_; | |||
| } | |||
| std::vector<common::Task> tasks; | |||
| tasks.reserve(thread_num); | |||
| auto task = [&](size_t start, size_t end) { | |||
| for (size_t c = 0; c < ceil(static_cast<double>(unit_num_) / thread_num); ++c) { | |||
| if (c * thread_num + start >= unit_num_) { | |||
| continue; | |||
| } | |||
| size_t i = c * thread_num + start; | |||
| size_t j = i / input_dim1_; | |||
| size_t k = i % input_dim1_; | |||
| T index = indices_addr[j]; | |||
| if (index < 0 || index >= SizeToInt(output_dim0_)) { | |||
| continue; | |||
| T index = indices_addr[j]; | |||
| if (index < 0 || index >= SizeToInt(output_dim0_)) { | |||
| continue; | |||
| } | |||
| size_t output_index = index * output_dim1_ + k; | |||
| output_addr[output_index] += input_addr[i]; | |||
| } | |||
| size_t output_index = index * output_dim1_ + k; | |||
| output_addr[output_index] += input_addr[i]; | |||
| }; | |||
| for (size_t t = 0; t < thread_num; ++t) { | |||
| auto block = [&, t]() { | |||
| task(t, t + 1); | |||
| return common::SUCCESS; | |||
| }; | |||
| tasks.emplace_back(block); | |||
| } | |||
| common::ThreadPool::GetInstance().SyncRun(tasks); | |||
| return true; | |||
| } | |||
| } // namespace kernel | |||