| @@ -23,19 +23,20 @@ namespace mindspore { | |||||
| namespace kernel { | namespace kernel { | ||||
| namespace { | namespace { | ||||
| template <typename T> | template <typename T> | ||||
| void LookUpTableTask(const float *input_addr, const T *indices_addr, const float *output_max_addr, float *output_addr, | |||||
| size_t indices_lens, size_t outer_dim_size, T offset, size_t first_dim_size) { | |||||
| size_t lens = outer_dim_size * sizeof(float); | |||||
| void LookUpTableTask(const float *input_addr, const T *indices_addr, float *output_addr, size_t indices_lens, | |||||
| size_t outer_dim_size, T offset, size_t first_dim_size) { | |||||
| auto type_size = sizeof(float); | |||||
| size_t lens = outer_dim_size * type_size; | |||||
| for (size_t i = 0; i < indices_lens; ++i) { | for (size_t i = 0; i < indices_lens; ++i) { | ||||
| T index = indices_addr[i] - offset; | T index = indices_addr[i] - offset; | ||||
| if (index >= 0 && index < SizeToInt(first_dim_size)) { | if (index >= 0 && index < SizeToInt(first_dim_size)) { | ||||
| size_t pos = index * outer_dim_size; | size_t pos = index * outer_dim_size; | ||||
| auto ret = memcpy_s(output_addr, output_max_addr - output_addr, input_addr + pos, lens); | |||||
| auto ret = memcpy_s(output_addr, (indices_lens - i) * lens, input_addr + pos, lens); | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; | ||||
| } | } | ||||
| } else { | } else { | ||||
| auto ret = memset_s(output_addr, output_max_addr - output_addr, 0, lens); | |||||
| auto ret = memset_s(output_addr, (indices_lens - i) * lens, 0, lens); | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; | MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; | ||||
| } | } | ||||
| @@ -82,7 +83,7 @@ void EmbeddingLookUpCPUKernel::LaunchKernel(const std::vector<kernel::AddressPtr | |||||
| break; | break; | ||||
| } | } | ||||
| MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; | MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; | ||||
| threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset, output_addr + outputs[0]->size, | |||||
| threads[i] = std::thread(LookUpTableTask<T>, input_addr, indices_addr + task_offset, | |||||
| output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_, | output_addr + task_offset * outer_dim_size_, task_proc_lens, outer_dim_size_, offset_, | ||||
| first_dim_size_); | first_dim_size_); | ||||
| task_offset += task_proc_lens; | task_offset += task_proc_lens; | ||||