| @@ -21,17 +21,14 @@ namespace mindspore { | |||
| namespace kernel { | |||
| void GatherV2CPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||
| CheckParam(kernel_node); | |||
| input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); | |||
| output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); | |||
| axis_ = AnfAlgo::GetNodeAttr<int>(kernel_node, AXIS); | |||
| if (axis_ < 0) { | |||
| axis_ = axis_ + SizeToInt(input_shape_.size()); | |||
| } | |||
| axis_ += 4 - input_shape_.size(); | |||
| CPUKernelUtils::ExpandDimsTo4(&input_shape_); | |||
| CPUKernelUtils::ExpandDimsTo4(&output_shape_); | |||
| } | |||
| @@ -44,7 +41,6 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| size_t dim0 = input_shape_[0]; | |||
| size_t dim1 = input_shape_[1]; | |||
| size_t dim2 = input_shape_[2]; | |||
| if (axis_ == 3) { | |||
| for (size_t i = 0; i < dim0; ++i) { | |||
| for (size_t j = 0; j < dim1; ++j) { | |||
| @@ -66,7 +62,6 @@ bool GatherV2CPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, | |||
| } else if (axis_ == 0) { | |||
| CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -75,34 +70,43 @@ void GatherV2CPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> & | |||
| auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); | |||
| auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr); | |||
| size_t elem_num = inputs[1]->size / 4; | |||
| size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); | |||
| for (size_t i = 0; i < elem_num; ++i) { | |||
| size_t index = IntToSize(indices_addr[i]); | |||
| size_t pos = 0; | |||
| if (axis_ == 3) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); | |||
| } else if (axis_ == 2) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); | |||
| } else if (axis_ == 1) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); | |||
| } else if (axis_ == 0) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); | |||
| if (indices_addr[i] < 0) { | |||
| MS_LOG(EXCEPTION) << "The indices value is less than 0."; | |||
| } | |||
| size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); | |||
| auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "memcpy failed."; | |||
| size_t index = IntToSize(indices_addr[i]); | |||
| if (index >= input_shape_[IntToSize(axis_)]) { | |||
| auto ret = memset_s(*output_addr, *buff_size, 0., num * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "memset failed."; | |||
| } | |||
| } else { | |||
| size_t pos = 0; | |||
| if (axis_ == 3) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); | |||
| } else if (axis_ == 2) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); | |||
| } else if (axis_ == 1) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); | |||
| } else if (axis_ == 0) { | |||
| pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); | |||
| } | |||
| auto ret = memcpy_s(*output_addr, *buff_size, input_addr + pos, num * sizeof(float)); | |||
| if (ret != EOK) { | |||
| MS_LOG(EXCEPTION) << "memcpy failed."; | |||
| } | |||
| } | |||
| *output_addr += num; | |||
| *buff_size -= num * sizeof(float); | |||
| } | |||
| } | |||
| } // namespace kernel | |||
| void GatherV2CPUKernel::CheckParam(const CNodePtr &kernel_node) { | |||
| auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); | |||
| if (input_shape.size() > 4) { | |||
| MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but GatherV2CPUKernel olny support 4d or lower."; | |||
| } | |||
| size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); | |||
| if (input_num != 2) { | |||
| MS_LOG(EXCEPTION) << "Argument number is " << input_num << ", but GatherV2CPUKernel needs 2."; | |||