|
|
|
@@ -28,30 +28,31 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); |
|
|
|
input_lens_ = 1; |
|
|
|
for (auto shape : input_shape_) { |
|
|
|
MS_LOG(DEBUG) << "input shape: " << shape; |
|
|
|
MS_LOG(INFO) << "input shape: " << shape; |
|
|
|
input_lens_ = input_lens_ * shape; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "input lens: " << input_lens_; |
|
|
|
MS_LOG(INFO) << "input lens: " << input_lens_; |
|
|
|
|
|
|
|
indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); |
|
|
|
indices_lens_ = 1; |
|
|
|
for (auto shape : indices_shape_) { |
|
|
|
MS_LOG(DEBUG) << "indice shape: " << shape; |
|
|
|
MS_LOG(INFO) << "indice shape: " << shape; |
|
|
|
indices_lens_ = indices_lens_ * shape; |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "indice lens: " << indices_lens_; |
|
|
|
MS_LOG(INFO) << "indice lens: " << indices_lens_; |
|
|
|
|
|
|
|
output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); |
|
|
|
for (auto shape : output_shape_) { |
|
|
|
MS_LOG(DEBUG) << "output shape: " << shape; |
|
|
|
MS_LOG(INFO) << "output shape: " << shape; |
|
|
|
} |
|
|
|
auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0); |
|
|
|
MS_LOG(DEBUG) << "output type: " << output_type; |
|
|
|
MS_LOG(INFO) << "output type: " << output_type; |
|
|
|
|
|
|
|
axis_ = 4 - input_shape_.size(); |
|
|
|
MS_LOG(DEBUG) << "axis_: " << axis_; |
|
|
|
MS_LOG(INFO) << "axis_: " << axis_; |
|
|
|
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag"); |
|
|
|
MS_LOG(DEBUG) << "reduce_scatter_flag: " << reduce_scatter_flag_; |
|
|
|
MS_LOG(INFO) << "reduce_scatter_flag: " << reduce_scatter_flag_; |
|
|
|
#ifdef ENABLE_MPI |
|
|
|
if (reduce_scatter_flag_) { |
|
|
|
size_t gatherv2_out_lens = 1; |
|
|
|
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { |
|
|
|
@@ -66,7 +67,7 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
} |
|
|
|
} |
|
|
|
gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float); |
|
|
|
MS_LOG(DEBUG) << "gatherv2 out lens: " << gatherv2_out_lens_; |
|
|
|
MS_LOG(INFO) << "gatherv2 out lens: " << gatherv2_out_lens_; |
|
|
|
gather_v2_out_ = malloc(gatherv2_out_lens_); |
|
|
|
if (gather_v2_out_ == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; |
|
|
|
@@ -77,10 +78,15 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { |
|
|
|
} |
|
|
|
|
|
|
|
split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num"); |
|
|
|
MS_LOG(DEBUG) << "split_num: " << split_num_; |
|
|
|
MS_LOG(INFO) << "split_num: " << split_num_; |
|
|
|
} |
|
|
|
#else |
|
|
|
if (reduce_scatter_flag_) { |
|
|
|
MS_LOG(EXCEPTION) << "Not Enable MPI, please build version with -M on when set reduce_scatter_flag true"; |
|
|
|
} |
|
|
|
#endif |
|
|
|
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset"); |
|
|
|
MS_LOG(DEBUG) << "offset: " << offset_; |
|
|
|
MS_LOG(INFO) << "offset: " << offset_; |
|
|
|
CPUKernelUtils::ExpandDimsTo4(&input_shape_); |
|
|
|
CPUKernelUtils::ExpandDimsTo4(&output_shape_); |
|
|
|
} |
|
|
|
@@ -97,13 +103,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); |
|
|
|
MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size; |
|
|
|
float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr; |
|
|
|
if (!reduce_scatter_flag_) { |
|
|
|
auto ret = memset_s(gather_out_addr, outputs[0]->size, 0, outputs[0]->size); |
|
|
|
if (ret != 0) { |
|
|
|
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset out buff failed"; |
|
|
|
} |
|
|
|
} |
|
|
|
MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr; |
|
|
|
|
|
|
|
size_t dim0 = input_shape_[0]; |
|
|
|
size_t dim1 = input_shape_[1]; |
|
|
|
size_t dim2 = input_shape_[2]; |
|
|
|
@@ -130,6 +131,7 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
LookUpTable(inputs, 0, 0, 0, &gather_out_addr); |
|
|
|
} |
|
|
|
|
|
|
|
#ifdef ENABLE_MPI |
|
|
|
if (reduce_scatter_flag_) { |
|
|
|
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); |
|
|
|
size_t reduce_scatter_out_lens = one_split_lens / 8; |
|
|
|
@@ -140,6 +142,8 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
one_split_lens / 8, "sum"); |
|
|
|
} |
|
|
|
} |
|
|
|
#endif |
|
|
|
|
|
|
|
#if defined(_WIN32) || defined(_WIN64) |
|
|
|
auto end_time = std::chrono::steady_clock::now(); |
|
|
|
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time; |
|
|
|
@@ -153,67 +157,82 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
// memcpy_task: copies one fixed-size chunk per list entry, for entries [start, end). |
// For every i in that range it calls memcpy_s with (*mem_dest_addr_list)[i] as destination and |
// (*mem_src_addr_list)[i] as source, passing `lens` for both size arguments. |
// NOTE(review): `lens` looks like a byte count (the launch site further down passes a value computed as |
// num * sizeof(float)) - confirm. Neither list is checked for having at least `end` entries. |
// NOTE(review): this text is a rendered diff; this helper appears to be on the removed side - confirm. |
void memcpy_task(std::vector<float *> *mem_dest_addr_list, std::vector<float *> *mem_src_addr_list, size_t start, |
|
|
|
size_t end, size_t lens) { |
|
|
|
for (size_t i = start; i < end; i++) { |
|
|
|
auto ret = memcpy_s((*mem_dest_addr_list)[i], lens, (*mem_src_addr_list)[i], lens); |
|
|
|
if (ret != EOK) { |
|
|
|
// NOTE(review): "memery" is presumably a typo for "memory"; left as-is because it is a runtime string. |
MS_LOG(EXCEPTION) << "memery copy failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
// Redundant in a void function; kept as in the original. |
return; |
|
|
|
} |
|
|
|
|
|
|
|
void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, |
|
|
|
size_t dim2, float **output_addr) { |
|
|
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); |
|
|
|
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr); |
|
|
|
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); |
|
|
|
// LookUpTable_task: handles `indices_lens` consecutive entries of `indices_addr`. |
// For each entry it subtracts `offset`; if the result is >= 0 and smaller than input_shape[axis], it computes |
// a source position with CPUKernelUtils::CalcOffset and, when pos + num <= input_lens, copies num floats |
// from input_addr + pos to output_addr. In every other case the num-float output slot is zero-filled. |
// output_addr is advanced by num after each entry. |
// Params: dim0/dim1/dim2 are the fixed coordinates passed to CalcOffset for the axes before `axis`; |
// input_shape is taken by value; input_lens is compared against pos + num (element counts, not bytes). |
// NOTE(review): this span is rendered from a diff without +/- markers. Rows that reference the `_`-suffixed |
// names (indices_lens_, offset_, axis_, input_shape_, input_lens_), the mem_*_addr_list vectors, or |
// `*output_addr` appear to be the previous in-class implementation interleaved with this function - confirm |
// against the real file before relying on any single row. |
void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num, |
|
|
|
size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, std::vector<size_t> input_shape, |
|
|
|
size_t input_lens) { |
|
|
|
// Size in bytes of one output slot (num floats); used for every memcpy_s/memset_s below. |
size_t lens = num * sizeof(float); |
|
|
|
// NOTE(review): the two vectors below are only used by rows that look like the pre-change code - confirm. |
std::vector<float *> mem_dest_addr_list; |
|
|
|
std::vector<float *> mem_src_addr_list; |
|
|
|
// NOTE(review): this loop header and the next row use class members; presumably the pre-change rows. |
for (size_t i = 0; i < indices_lens_; ++i) { |
|
|
|
int indices = indices_addr[i] - offset_; |
|
|
|
// Walk the indices_lens entries assigned to this task. |
for (size_t i = 0; i < indices_lens; ++i) { |
|
|
|
// Shift the raw index by `offset`; a negative result goes to the zero-fill branch at the bottom. |
int indices = indices_addr[i] - offset; |
|
|
|
if (indices >= 0) { |
|
|
|
size_t index = IntToSize(indices); |
|
|
|
if (index < input_shape_[axis_]) { |
|
|
|
// Only an index inside the extent of the lookup axis is copied; others are zero-filled. |
if (index < input_shape[axis]) { |
|
|
|
// Stays 0 when `axis` is none of 0..3 (no branch below assigns it). |
size_t pos = 0; |
|
|
|
if (axis_ == 3) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, dim2, index); |
|
|
|
} else if (axis_ == 2) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, dim1, index, 0); |
|
|
|
} else if (axis_ == 1) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape_, dim0, index, 0, 0); |
|
|
|
} else if (axis_ == 0) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape_, index, 0, 0, 0); |
|
|
|
// Put `index` in the coordinate slot selected by `axis`; the coordinates after it are passed as 0. |
if (axis == 3) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, dim2, index); |
|
|
|
} else if (axis == 2) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape, dim0, dim1, index, 0); |
|
|
|
} else if (axis == 1) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape, dim0, index, 0, 0); |
|
|
|
} else if (axis == 0) { |
|
|
|
pos = CPUKernelUtils::CalcOffset(input_shape, index, 0, 0, 0); |
|
|
|
} |
|
|
|
|
|
|
|
if (pos + num <= input_lens_) { |
|
|
|
mem_dest_addr_list.push_back(*output_addr); |
|
|
|
mem_src_addr_list.push_back(input_addr + pos); |
|
|
|
// Copy only when the whole num-element run starting at pos fits inside input_lens elements. |
if (pos + num <= input_lens) { |
|
|
|
auto ret = memcpy_s(output_addr, lens, input_addr + pos, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; |
|
|
|
} |
|
|
|
|
|
|
|
// The run would extend past input_lens: zero-fill the slot instead. |
} else { |
|
|
|
auto ret = memset_s(output_addr, lens, 0, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
// Index is at or beyond the extent of the lookup axis: zero-fill the slot. |
} else { |
|
|
|
auto ret = memset_s(output_addr, lens, 0, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
// Index is negative after subtracting the offset: zero-fill the slot. |
} else { |
|
|
|
auto ret = memset_s(output_addr, lens, 0, lens); |
|
|
|
if (ret != EOK) { |
|
|
|
MS_LOG(EXCEPTION) << "LookUpTable task memset failed."; |
|
|
|
} |
|
|
|
} |
|
|
|
// NOTE(review): the dereferencing form matches a float** parameter, i.e. presumably the pre-change row. |
*output_addr += num; |
|
|
|
// Move on to the output slot of the next index. |
output_addr += num; |
|
|
|
} |
|
|
|
|
|
|
|
} |
|
|
|
// LookUpTable: splits the indices_lens_ entries of inputs[1] into slices and runs one std::thread per |
// slice (at most thread_num = 8), each executing LookUpTable_task on its part of the indices and of the |
// output buffer. All started threads are joined, then *output_addr is advanced by num * indices_lens_. |
// inputs[0]->addr is reinterpreted as float* (the table) and inputs[1]->addr as int* (the indices). |
// dim0/dim1/dim2 are forwarded unchanged to the tasks. |
// NOTE(review): this span is rendered from a diff without +/- markers. Rows that mention memcpy_lens, |
// start, ones_copy_lens, memcpy_task or mem_*_addr_list appear to be the previous implementation |
// interleaved with the current one - confirm against the real file. |
void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, |
|
|
|
size_t dim2, float **output_addr) { |
|
|
|
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); |
|
|
|
auto indices_addr = reinterpret_cast<int *>(inputs[1]->addr); |
|
|
|
// Number of floats written per index; also the stride between consecutive output slots. |
size_t num = CPUKernelUtils::GetElementNumOnAxis(input_shape_, axis_); |
|
|
|
// Start of the output region; each task writes from task_out_addr + task_offset * num. |
float *task_out_addr = *output_addr; |
|
|
|
const size_t thread_num = 8; |
|
|
|
// NOTE(review): the array size repeats the literal 8 instead of using thread_num; keep the two in sync. |
std::thread threads[8]; |
|
|
|
size_t memcpy_lens = mem_dest_addr_list.size(); |
|
|
|
size_t start = 0; |
|
|
|
size_t ones_copy_lens = (memcpy_lens + thread_num - 1) / thread_num; |
|
|
|
// Ceiling division: how many indices each task receives (the last one may get fewer, see below). |
size_t task_proc_lens = (indices_lens_ + thread_num - 1) / thread_num; |
|
|
|
// Declared outside the loop because its final value (threads started) bounds the join loop. |
size_t i; |
|
|
|
// Index of the first entry not yet handed to a task. |
size_t task_offset = 0; |
|
|
|
MS_LOG(DEBUG) << "indices_lens_: " << indices_lens_ << " one task proc lens:" << task_proc_lens; |
|
|
|
for (i = 0; i < thread_num; i++) { |
|
|
|
if (start > memcpy_lens) { |
|
|
|
// Stop starting threads once every index has been assigned. |
if (task_offset >= indices_lens_) { |
|
|
|
break; |
|
|
|
} |
|
|
|
auto end = (start + ones_copy_lens) > memcpy_lens ? memcpy_lens : start + ones_copy_lens; |
|
|
|
threads[i] = std::thread(memcpy_task, &mem_dest_addr_list, &mem_src_addr_list, start, end, lens); |
|
|
|
start = start + ones_copy_lens; |
|
|
|
// NOTE(review): "task_proc_lenss" in the message is presumably a typo; left as-is (runtime string). |
MS_LOG(DEBUG) << "task_offset: " << task_offset << " task_proc_lenss:" << task_proc_lens; |
|
|
|
// Each task gets its own slice of the indices and the matching slice of the output buffer. |
// NOTE(review): LookUpTable_task reports failures through MS_LOG(EXCEPTION); if that throws, the |
// exception would escape the thread function and terminate the process - confirm the intent. |
threads[i] = |
|
|
|
std::thread(LookUpTable_task, input_addr, task_out_addr + task_offset * num, indices_addr + task_offset, |
|
|
|
task_proc_lens, num, dim0, dim1, dim2, offset_, axis_, input_shape_, input_lens_); |
|
|
|
task_offset += task_proc_lens; |
|
|
|
// Shrink the slice for the next task so it does not run past indices_lens_. |
if (task_offset + task_proc_lens > indices_lens_) { |
|
|
|
task_proc_lens = indices_lens_ - task_offset; |
|
|
|
} |
|
|
|
} |
|
|
|
// Join exactly the threads that were started (i of them). |
for (size_t j = 0; j < i; j++) { |
|
|
|
threads[j].join(); |
|
|
|
} |
|
|
|
// Move the caller's output pointer past everything written by this call. |
*output_addr += num * indices_lens_; |
|
|
|
} |
|
|
|
|
|
|
|
void EmbeddingLookUpCPUKernel::CheckParam(const CNodePtr &kernel_node) { |
|
|
|
|