Browse Source

fix review code

tags/v0.5.0-beta
kswang 5 years ago
parent
commit
2a8f0a75be
8 changed files with 62 additions and 115 deletions
  1. +35
    -34
      mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc
  2. +1
    -16
      mindspore/ccsrc/device/ascend/ascend_stream_assign.cc
  3. +0
    -5
      mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc
  4. +1
    -1
      mindspore/ccsrc/device/cpu/cpu_device_address.cc
  5. +18
    -10
      mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc
  6. +3
    -2
      mindspore/ccsrc/device/kernel_runtime_manager.cc
  7. +2
    -45
      mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc
  8. +2
    -2
      mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc

+ 35
- 34
mindspore/ccsrc/device/ascend/ascend_kernel_runtime.cc View File

@@ -52,6 +52,38 @@ namespace mindspore {
namespace device { namespace device {
namespace ascend { namespace ascend {
static const size_t PRAMATER_OUTPUT_INDEX = 0; static const size_t PRAMATER_OUTPUT_INDEX = 0;
namespace {
std::string GetRankId() {
std::string rank_id_str;
#ifdef ENABLE_MPI
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) {
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
try {
int rank_offset = std::stoi(offset);
rank_id += rank_offset;
} catch (std::invalid_argument) {
MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
} catch (std::out_of_range) {
MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
}
}
rank_id_str = std::to_string(rank_id);
} else {
rank_id_str = std::getenv("RANK_ID");
}
#else
rank_id_str = std::getenv("RANK_ID");
#endif
if (rank_id_str.empty()) {
MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
}
return rank_id_str;
}
} // namespace


AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); } AscendKernelRuntime::~AscendKernelRuntime() { graph_model_map_.clear(); }


@@ -497,7 +529,6 @@ bool AscendKernelRuntime::HcclInit() {
if (!context_ptr->IsTsdOpened()) { if (!context_ptr->IsTsdOpened()) {
MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open"; MS_LOG(EXCEPTION) << "Hccl dependent tsd is not open";
} }

MS_LOG(INFO) << "do hcom init"; MS_LOG(INFO) << "do hcom init";
auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH"); auto config_path_str = std::getenv("MINDSPORE_HCCL_CONFIG_PATH");
if (config_path_str == nullptr) { if (config_path_str == nullptr) {
@@ -507,44 +538,14 @@ bool AscendKernelRuntime::HcclInit() {
} }
return false; return false;
} }
std::string rank_id_str = GetRankId();
auto full_path = realpath(config_path_str, nullptr); auto full_path = realpath(config_path_str, nullptr);
if (full_path == nullptr) { if (full_path == nullptr) {
MS_LOG(ERROR) << "file path " << config_path_str << " does not exist"; MS_LOG(ERROR) << "file path " << config_path_str << " does not exist";
return false; return false;
} }
const char *identify = nullptr;
#ifdef ENABLE_MPI
std::string rank_id_tmp;
auto mpi_config_ptr = MpiConfig::GetInstance();
MS_EXCEPTION_IF_NULL(mpi_config_ptr);
if (mpi_config_ptr->enable_mpi()) {
int rank_id = device::cpu::MPIAdapter::Instance().GetRankId();
const char *offset = std::getenv("RANK_OFFSET");
if (offset != nullptr) {
try {
int rank_offset = std::stoi(offset);
rank_id += rank_offset;
} catch (std::invalid_argument) {
MS_LOG(EXCEPTION) << "stoi invalid argument:" << offset;
} catch (std::out_of_range) {
MS_LOG(EXCEPTION) << "stoi out_of_range:" << offset;
}
}
rank_id_tmp = std::to_string(rank_id);
identify = rank_id_tmp.c_str();
} else {
identify = std::getenv("RANK_ID");
}
#else
identify = std::getenv("RANK_ID");
#endif
if (identify == nullptr) {
MS_LOG(ERROR) << "get hccl rankid failed, please set env RANK_ID";
free(full_path);
return false;
}
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << identify;
hcclResult_t res = hcom_init(full_path, identify);
MS_LOG(INFO) << "MINDSPORE_HCCL_CONFIG_PATH : " << full_path << ", RANK_ID: " << rank_id_str;
hcclResult_t res = hcom_init(full_path, rank_id_str.c_str());
free(full_path); free(full_path);
if (res != HCCL_SUCCESS) { if (res != HCCL_SUCCESS) {
MS_LOG(ERROR) << "hcom init failed, res is " << static_cast<int>(res); MS_LOG(ERROR) << "hcom init failed, res is " << static_cast<int>(res);


+ 1
- 16
mindspore/ccsrc/device/ascend/ascend_stream_assign.cc View File

@@ -303,15 +303,12 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::s
fusion_hcom_index.emplace_back(i); fusion_hcom_index.emplace_back(i);
} }
} }

if (fusion_hcom_index.size() < 2) { if (fusion_hcom_index.size() < 2) {
MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them"; MS_LOG(INFO) << "fusion hcom size is less than 2, no need insert event between them";
return; return;
} }

uint32_t first_index = fusion_hcom_index[0]; uint32_t first_index = fusion_hcom_index[0];
uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1]; uint32_t last_index = fusion_hcom_index[fusion_hcom_index.size() - 1];

uint32_t cur_event_id = total_event_num_; uint32_t cur_event_id = total_event_num_;
uint32_t pre_hcom_stream_id = UINT32_MAX; uint32_t pre_hcom_stream_id = UINT32_MAX;
std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders)); std::copy(cnode_ptr_list.begin(), cnode_ptr_list.begin() + first_index, std::back_inserter(orders));
@@ -322,13 +319,11 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::s
orders.emplace_back(cur_cnode); orders.emplace_back(cur_cnode);
continue; continue;
} }

auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode); auto cur_hcom_stream_id = AnfAlgo::GetStreamId(cur_cnode);
if (cur_hcom_stream_id == pre_hcom_stream_id) { if (cur_hcom_stream_id == pre_hcom_stream_id) {
orders.emplace_back(cur_cnode); orders.emplace_back(cur_cnode);
continue; continue;
} }

if (i == first_index) { if (i == first_index) {
// first fusion hcom // first fusion hcom
orders.emplace_back(cur_cnode); orders.emplace_back(cur_cnode);
@@ -348,15 +343,12 @@ void AscendStreamAssign::InsertSendRecvForDiffHcom(const shared_ptr<mindspore::s
auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id); auto send = CreateSendApplyKernel(graph_ptr, cur_event_id, cur_hcom_stream_id);
orders.emplace_back(send); orders.emplace_back(send);
} }

pre_hcom_stream_id = cur_hcom_stream_id; pre_hcom_stream_id = cur_hcom_stream_id;
} }

std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders)); std::copy(cnode_ptr_list.begin() + last_index + 1, cnode_ptr_list.end(), std::back_inserter(orders));
graph_ptr->set_execution_order(orders); graph_ptr->set_execution_order(orders);
total_event_num_ = cur_event_id; total_event_num_ = cur_event_id;
MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]";
MS_LOG(INFO) << "end";
MS_LOG(INFO) << "after indsert between allreduce, total event nums[" << total_event_num_ << "]\n end";
} }


void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) { void AscendStreamAssign::InsertSendRecvForHcomParallel(const shared_ptr<mindspore::session::KernelGraph> &graph_ptr) {
@@ -826,7 +818,6 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
std::vector<CNodePtr> exe_orders; std::vector<CNodePtr> exe_orders;
std::vector<CNodePtr> independents; std::vector<CNodePtr> independents;
std::vector<CNodePtr> others; std::vector<CNodePtr> others;

auto cnode_ptr_list = graph_ptr->execution_order(); auto cnode_ptr_list = graph_ptr->execution_order();
MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size(); MS_LOG(INFO) << "before reorder, graph orders size:" << cnode_ptr_list.size();
for (size_t i = 0; i < cnode_ptr_list.size(); ++i) { for (size_t i = 0; i < cnode_ptr_list.size(); ++i) {
@@ -838,19 +829,16 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
others.emplace_back(cur_cnode_ptr); others.emplace_back(cur_cnode_ptr);
} }
} }

if (others.empty()) { if (others.empty()) {
std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders)); std::copy(independents.begin(), independents.end(), std::back_inserter(exe_orders));
graph_ptr->set_execution_order(exe_orders); graph_ptr->set_execution_order(exe_orders);
return; return;
} }

if (independents.empty()) { if (independents.empty()) {
std::copy(others.begin(), others.end(), std::back_inserter(exe_orders)); std::copy(others.begin(), others.end(), std::back_inserter(exe_orders));
graph_ptr->set_execution_order(exe_orders); graph_ptr->set_execution_order(exe_orders);
return; return;
} }

std::vector<CNodePtr> processed; std::vector<CNodePtr> processed;
for (size_t i = 0; i < others.size(); i++) { for (size_t i = 0; i < others.size(); i++) {
auto begin = others.begin() + i; auto begin = others.begin() + i;
@@ -862,7 +850,6 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
if (it != processed.end()) { if (it != processed.end()) {
continue; continue;
} }

auto res = FindTargetOp(begin, end, cur_independent); auto res = FindTargetOp(begin, end, cur_independent);
if (res != end) { if (res != end) {
flag = true; flag = true;
@@ -872,12 +859,10 @@ void AscendStreamAssign::ReorderIndependentOrders(const shared_ptr<mindspore::se
break; break;
} }
} }

if (!flag) { if (!flag) {
exe_orders.emplace_back(*begin); exe_orders.emplace_back(*begin);
} }
} }

MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size(); MS_LOG(INFO) << "after reorder, graph orders size:" << exe_orders.size();
graph_ptr->set_execution_order(exe_orders); graph_ptr->set_execution_order(exe_orders);
} }


+ 0
- 5
mindspore/ccsrc/device/ascend/profiling/profiling_manager.cc View File

@@ -121,7 +121,6 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
MS_LOG(ERROR) << "Register profiling Engine failed."; MS_LOG(ERROR) << "Register profiling Engine failed.";
return false; return false;
} }

auto context = MsContext::GetInstance(); auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context); MS_EXCEPTION_IF_NULL(context);
const string prof_options_str = context->profiling_options(); const string prof_options_str = context->profiling_options();
@@ -130,7 +129,6 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!"; MS_LOG(WARNING) << "Profiling is enabled, but profiling option is not set!";
return true; return true;
} }

// current one docker only use one device` // current one docker only use one device`
Json p_device; Json p_device;
// JOBID // JOBID
@@ -149,7 +147,6 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
// only one device, but sProfMgrStartUp API require for device list // only one device, but sProfMgrStartUp API require for device list
Json devices; Json devices;
devices[0] = p_device; devices[0] = p_device;

Json startCfg; Json startCfg;
startCfg["startCfg"] = devices; startCfg["startCfg"] = devices;


@@ -157,9 +154,7 @@ bool ProfilingManager::StartupProfiling(uint32_t device_id) {
std::stringstream ss; std::stringstream ss;
ss << startCfg; ss << startCfg;
std::string cfg = ss.str(); std::string cfg = ss.str();

MS_LOG(INFO) << "profiling config " << cfg; MS_LOG(INFO) << "profiling config " << cfg;

auto ret = rtProfilerStart(); auto ret = rtProfilerStart();
if (ret != RT_ERROR_NONE) { if (ret != RT_ERROR_NONE) {
MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret; MS_LOG(INFO) << "Call rtProfilerStart failed, ret:" << ret;


+ 1
- 1
mindspore/ccsrc/device/cpu/cpu_device_address.cc View File

@@ -33,7 +33,7 @@ bool CPUDeviceAddress::SyncDeviceToHost(const std::vector<int> & /*shape*/, size
} }


if (type == type_id_) { if (type == type_id_) {
auto ret_code = memcpy_s(host_ptr, size, ptr_, size);
auto ret_code = memcpy_s(host_ptr, size, ptr_, size_);
if (ret_code != EOK) { if (ret_code != EOK) {
MS_LOG(ERROR) << "Failed to copy tensor!"; MS_LOG(ERROR) << "Failed to copy tensor!";
return false; return false;


+ 18
- 10
mindspore/ccsrc/device/cpu/cpu_kernel_runtime.cc View File

@@ -34,6 +34,23 @@ namespace mindspore {
namespace device { namespace device {
namespace cpu { namespace cpu {
const size_t INIT_NODE_REF = 1; const size_t INIT_NODE_REF = 1;
namespace {
TypeId GetCPUSupportOutputTypeId(const TypeId type_id) {
TypeId support_type_id = type_id;
if (type_id == kNumberTypeUInt32) {
support_type_id = kNumberTypeInt32;
}
if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
type_id == kNumberTypeFloat64) {
support_type_id = kNumberTypeFloat32;
}
if (support_type_id != kNumberTypeInt32 && support_type_id != kNumberTypeFloat32) {
MS_LOG(EXCEPTION) << "Check output type failed.";
}
return support_type_id;
}
} // namespace

void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) { void CPUKernelRuntime::AssignKernelAddress(session::KernelGraph *kernel_graph) {
AssignValueNodeAddress(kernel_graph); AssignValueNodeAddress(kernel_graph);
AssignInputNodeAddress(kernel_graph); AssignInputNodeAddress(kernel_graph);
@@ -149,16 +166,7 @@ BaseRef CPUKernelRuntime::CreatTensorForOutput(const AnfNodePtr &input_node, siz
std::vector<int> temp_shape; std::vector<int> temp_shape;
(void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end()); (void)temp_shape.insert(temp_shape.end(), shape.begin(), shape.end());
TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index); TypeId type_id = AnfAlgo::GetOutputInferDataType(node, index);
if (type_id == kNumberTypeUInt32) {
type_id = kNumberTypeInt32;
}
if (type_id == kNumberTypeFloat || type_id == kNumberTypeFloat16 || type_id == kNumberTypeFloat32 ||
type_id == kNumberTypeFloat64) {
type_id = kNumberTypeFloat32;
}
if (type_id != kNumberTypeInt32 && type_id != kNumberTypeFloat32) {
MS_LOG(EXCEPTION) << "Check output type failed.";
}
type_id = GetCPUSupportOutputTypeId(type_id);
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape);
MS_EXCEPTION_IF_NULL(tensor); MS_EXCEPTION_IF_NULL(tensor);
if (address->ref_count_ > 0 && address->ptr_ != nullptr) { if (address->ref_count_ > 0 && address->ptr_ != nullptr) {


+ 3
- 2
mindspore/ccsrc/device/kernel_runtime_manager.cc View File

@@ -54,8 +54,9 @@ KernelRuntime *KernelRuntimeManager::GetSingleKernelRuntime(const std::string &d
return runtime_iter->second.get(); return runtime_iter->second.get();
} else if (runtime_map_.size() > 0) { } else if (runtime_map_.size() > 0) {
auto cur_runtime_key = runtime_map_.begin()->first; auto cur_runtime_key = runtime_map_.begin()->first;
if (cur_runtime_key.rfind('_') != std::string::npos) {
auto cur_device_id = cur_runtime_key.substr(cur_runtime_key.rfind('_') + 1);
auto find_pos = cur_runtime_key.rfind('_');
if (find_pos != std::string::npos) {
auto cur_device_id = cur_runtime_key.substr(find_pos + 1);
MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id MS_LOG(EXCEPTION) << "Can't change device id in runtime, already set device id: " << cur_device_id
<< ", set device id: " << device_id << " failed"; << ", set device id: " << device_id << " failed";
} }


+ 2
- 45
mindspore/ccsrc/kernel/cpu/embedding_look_up_cpu_kernel.cc View File

@@ -24,50 +24,32 @@ namespace mindspore {
namespace kernel { namespace kernel {
void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) { void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
CheckParam(kernel_node); CheckParam(kernel_node);

input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input_lens_ = 1; input_lens_ = 1;
for (auto shape : input_shape_) { for (auto shape : input_shape_) {
MS_LOG(INFO) << "input shape: " << shape;
input_lens_ = input_lens_ * shape; input_lens_ = input_lens_ * shape;
} }
MS_LOG(INFO) << "input lens: " << input_lens_;

indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1); indices_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
indices_lens_ = 1; indices_lens_ = 1;
for (auto shape : indices_shape_) { for (auto shape : indices_shape_) {
MS_LOG(INFO) << "indice shape: " << shape;
indices_lens_ = indices_lens_ * shape; indices_lens_ = indices_lens_ * shape;
} }
MS_LOG(INFO) << "indice lens: " << indices_lens_;

output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0); output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (auto shape : output_shape_) {
MS_LOG(INFO) << "output shape: " << shape;
}
auto output_type = AnfAlgo::GetOutputInferDataType(kernel_node, 0);
MS_LOG(INFO) << "output type: " << output_type;

axis_ = 4 - input_shape_.size(); axis_ = 4 - input_shape_.size();
MS_LOG(INFO) << "axis_: " << axis_;
reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag"); reduce_scatter_flag_ = AnfAlgo::GetNodeAttr<bool>(kernel_node, "reduce_scatter_flag");
MS_LOG(INFO) << "reduce_scatter_flag: " << reduce_scatter_flag_;
#ifdef ENABLE_MPI #ifdef ENABLE_MPI
if (reduce_scatter_flag_) { if (reduce_scatter_flag_) {
size_t gatherv2_out_lens = 1; size_t gatherv2_out_lens = 1;
for (int i = 0; i < SizeToInt(input_shape_.size()); i++) { for (int i = 0; i < SizeToInt(input_shape_.size()); i++) {
if (i == 0) { if (i == 0) {
for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) { for (int j = 0; j < SizeToInt(indices_shape_.size()); j++) {
MS_LOG(DEBUG) << "gatherv2 out shape: " << indices_shape_[j];
gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j]; gatherv2_out_lens = gatherv2_out_lens * indices_shape_[j];
} }
} else { } else {
MS_LOG(DEBUG) << "gatherv2 out shape: " << input_shape_[i];
gatherv2_out_lens = gatherv2_out_lens * input_shape_[i]; gatherv2_out_lens = gatherv2_out_lens * input_shape_[i];
} }
} }
gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float); gatherv2_out_lens_ = gatherv2_out_lens * sizeof(float);
MS_LOG(INFO) << "gatherv2 out lens: " << gatherv2_out_lens_;
gather_v2_out_ = malloc(gatherv2_out_lens_); gather_v2_out_ = malloc(gatherv2_out_lens_);
if (gather_v2_out_ == nullptr) { if (gather_v2_out_ == nullptr) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_; MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel malloc failed, malloc lens: " << gatherv2_out_lens_;
@@ -76,9 +58,7 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
if (ret != 0) { if (ret != 0) {
MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed"; MS_LOG(EXCEPTION) << "EmbeddingLookUpCPUKernel memset gatherv2 out buff failed";
} }

split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num"); split_num_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "split_num");
MS_LOG(INFO) << "split_num: " << split_num_;
} }
#else #else
if (reduce_scatter_flag_) { if (reduce_scatter_flag_) {
@@ -86,7 +66,6 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
#endif #endif
offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset"); offset_ = AnfAlgo::GetNodeAttr<int>(kernel_node, "offset");
MS_LOG(INFO) << "offset: " << offset_;
CPUKernelUtils::ExpandDimsTo4(&input_shape_); CPUKernelUtils::ExpandDimsTo4(&input_shape_);
CPUKernelUtils::ExpandDimsTo4(&output_shape_); CPUKernelUtils::ExpandDimsTo4(&output_shape_);
} }
@@ -94,21 +73,11 @@ void EmbeddingLookUpCPUKernel::InitKernel(const CNodePtr &kernel_node) {
bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs, bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/, const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) { const std::vector<kernel::AddressPtr> &outputs) {
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#endif
auto output_addr = reinterpret_cast<float *>(outputs[0]->addr); auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
MS_LOG(DEBUG) << "output addr: " << output_addr << "output size: " << outputs[0]->size;
float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr; float *gather_out_addr = reduce_scatter_flag_ ? reinterpret_cast<float *>(gather_v2_out_) : output_addr;
MS_LOG(DEBUG) << "gatherv2 out addr: " << gather_out_addr;

size_t dim0 = input_shape_[0]; size_t dim0 = input_shape_[0];
size_t dim1 = input_shape_[1]; size_t dim1 = input_shape_[1];
size_t dim2 = input_shape_[2]; size_t dim2 = input_shape_[2];

if (axis_ == 3) { if (axis_ == 3) {
for (size_t i = 0; i < dim0; ++i) { for (size_t i = 0; i < dim0; ++i) {
for (size_t j = 0; j < dim1; ++j) { for (size_t j = 0; j < dim1; ++j) {
@@ -130,7 +99,6 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
} else if (axis_ == 0) { } else if (axis_ == 0) {
LookUpTable(inputs, 0, 0, 0, &gather_out_addr); LookUpTable(inputs, 0, 0, 0, &gather_out_addr);
} }

#ifdef ENABLE_MPI #ifdef ENABLE_MPI
if (reduce_scatter_flag_) { if (reduce_scatter_flag_) {
size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float); size_t one_split_lens = gatherv2_out_lens_ / split_num_ / sizeof(float);
@@ -143,21 +111,10 @@ bool EmbeddingLookUpCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inp
} }
} }
#endif #endif

#if defined(_WIN32) || defined(_WIN64)
auto end_time = std::chrono::steady_clock::now();
std::chrono::duration<double, std::ratio<1, 1000000>> cost = end_time - start_time;
MS_LOG(INFO) << "EmbeddingLookUpCPUKernel, used time: " << cost.count() << " us";
#else
(void)gettimeofday(&end_time, nullptr);
uint64_t time = 1000000 * static_cast<uint64_t>(end_time.tv_sec - start_time.tv_sec);
time += static_cast<uint64_t>(end_time.tv_usec - start_time.tv_usec);
MS_LOG(INFO) << "EmbeddingLookUpCPUKernel, used time: " << time << " us";
#endif
return true; return true;
} }


void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num,
void LookUpTable_task(const float *input_addr, float *output_addr, int *indices_addr, size_t indices_lens, size_t num,
size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, std::vector<size_t> input_shape, size_t dim0, size_t dim1, size_t dim2, int offset, size_t axis, std::vector<size_t> input_shape,
size_t input_lens) { size_t input_lens) {
size_t lens = num * sizeof(float); size_t lens = num * sizeof(float);
@@ -182,7 +139,6 @@ void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr,
if (ret != EOK) { if (ret != EOK) {
MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed."; MS_LOG(EXCEPTION) << "LookUpTable task memcpy failed.";
} }

} else { } else {
auto ret = memset_s(output_addr, lens, 0, lens); auto ret = memset_s(output_addr, lens, 0, lens);
if (ret != EOK) { if (ret != EOK) {
@@ -204,6 +160,7 @@ void LookUpTable_task(float *input_addr, float *output_addr, int *indices_addr,
output_addr += num; output_addr += num;
} }
} }

void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, void EmbeddingLookUpCPUKernel::LookUpTable(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
size_t dim2, float **output_addr) { size_t dim2, float **output_addr) {
auto input_addr = reinterpret_cast<float *>(inputs[0]->addr); auto input_addr = reinterpret_cast<float *>(inputs[0]->addr);


+ 2
- 2
mindspore/ccsrc/kernel/cpu/sub_cpu_kernel.cc View File

@@ -30,7 +30,7 @@ void SubCPUKernel::InitKernel(const CNodePtr &kernel_node) {
} }
} }


void sub_task(int *in_addr, int *out_addr, size_t lens, int offset) {
void sub_task(const int *in_addr, int *out_addr, size_t lens, int offset) {
for (size_t i = 0; i < lens; i++) { for (size_t i = 0; i < lens; i++) {
out_addr[i] = in_addr[i] - offset; out_addr[i] = in_addr[i] - offset;
} }
@@ -55,7 +55,7 @@ bool SubCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
output_addr[i] = input_addr[i] - offset_; output_addr[i] = input_addr[i] - offset_;
} }
} else { } else {
size_t thread_num = 4;
const size_t thread_num = 4;
std::thread threads[4]; std::thread threads[4];
size_t process_lens = (lens + thread_num - 1) / thread_num; size_t process_lens = (lens + thread_num - 1) / thread_num;
size_t process_offset = 0; size_t process_offset = 0;


Loading…
Cancel
Save