| @@ -29,8 +29,7 @@ void EluGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| } | } | ||||
| } | } | ||||
| bool EluGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| const std::vector<kernel::AddressPtr> & /*workspace*/, | |||||
| bool EluGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||||
| const std::vector<kernel::AddressPtr> &outputs) { | const std::vector<kernel::AddressPtr> &outputs) { | ||||
| if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { | if (dtype_ == kNumberTypeFloat32 || dtype_ == kNumberTypeFloat) { | ||||
| LaunchKernel<float>(inputs, outputs); | LaunchKernel<float>(inputs, outputs); | ||||
| @@ -43,7 +42,8 @@ bool EluGradCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||||
| } | } | ||||
| template <typename T> | template <typename T> | ||||
| void EluGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) { | |||||
| void EluGradCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, | |||||
| const std::vector<AddressPtr> &outputs) const { | |||||
| T *input0 = reinterpret_cast<T *>(inputs[0]->addr); | T *input0 = reinterpret_cast<T *>(inputs[0]->addr); | ||||
| T *input1 = reinterpret_cast<T *>(inputs[1]->addr); | T *input1 = reinterpret_cast<T *>(inputs[1]->addr); | ||||
| T *output = reinterpret_cast<T *>(outputs[0]->addr); | T *output = reinterpret_cast<T *>(outputs[0]->addr); | ||||
| @@ -32,7 +32,7 @@ class EluGradCPUKernel : public CPUKernel { | |||||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | ||||
| const std::vector<AddressPtr> &outputs) override; | const std::vector<AddressPtr> &outputs) override; | ||||
| template <typename T> | template <typename T> | ||||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||||
| void LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs) const; | |||||
| private: | private: | ||||
| TypeId dtype_{kTypeUnknown}; | TypeId dtype_{kTypeUnknown}; | ||||
| @@ -43,9 +43,12 @@ bool SelectCPUKernel<T>::Launch(const std::vector<AddressPtr> &inputs, const std | |||||
| auto input_x = reinterpret_cast<T *>(inputs[1]->addr); | auto input_x = reinterpret_cast<T *>(inputs[1]->addr); | ||||
| auto input_y = reinterpret_cast<T *>(inputs[2]->addr); | auto input_y = reinterpret_cast<T *>(inputs[2]->addr); | ||||
| auto output = reinterpret_cast<T *>(outputs[0]->addr); | auto output = reinterpret_cast<T *>(outputs[0]->addr); | ||||
| for (size_t pos = 0; pos < element_num_; pos++) { | |||||
| output[pos] = input_cond[pos] ? input_x[pos] : input_y[pos]; | |||||
| } | |||||
| auto task = [=](const size_t start, const size_t end) { | |||||
| for (size_t pos = start; pos < end; pos++) { | |||||
| output[pos] = input_cond[pos] ? input_x[pos] : input_y[pos]; | |||||
| } | |||||
| }; | |||||
| CPUKernelUtils::ParallelFor(task, element_num_); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace kernel | } // namespace kernel | ||||