| @@ -31,14 +31,15 @@ using mindspore::schema::PrimitiveType_OneHot; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| namespace { | namespace { | ||||
| constexpr size_t kInputNum = 4; | constexpr size_t kInputNum = 4; | ||||
| constexpr size_t kInputNumOpt = 3; | |||||
| constexpr size_t kOutputNum = 1; | constexpr size_t kOutputNum = 1; | ||||
| } // namespace | } // namespace | ||||
| int OneHotCPUKernel::Init() { | int OneHotCPUKernel::Init() { | ||||
| // indices depth on_value off_value | // indices depth on_value off_value | ||||
| if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) { | |||||
| MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size() | |||||
| << ", output size should be" << kOutputNum << ", got " << out_tensors_.size(); | |||||
| if ((in_tensors_.size() != kInputNum && in_tensors_.size() != kInputNumOpt) || out_tensors_.size() != kOutputNum) { | |||||
| MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << " or " << kInputNumOpt << ", got " | |||||
| << in_tensors_.size() << ", output size should be" << kOutputNum << ", got " << out_tensors_.size(); | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (context_ == nullptr) { | if (context_ == nullptr) { | ||||
| @@ -132,27 +133,42 @@ int OneHotCPUKernel::GetParams() { | |||||
| } | } | ||||
| one_hot_param->depth_ = *depth; | one_hot_param->depth_ = *depth; | ||||
| auto on_value_tensor = in_tensors_.at(2); | |||||
| if (on_value_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const float *on_value = static_cast<float *>(on_value_tensor->MutableData()); | |||||
| if (on_value == nullptr) { | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| one_hot_param->on_value_ = *on_value; | |||||
| auto off_value_tensor = in_tensors_.at(3); | |||||
| if (off_value_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const float *off_value = static_cast<float *>(off_value_tensor->MutableData()); | |||||
| if (off_value == nullptr) { | |||||
| return RET_NULL_PTR; | |||||
| if (in_tensors_.size() == kInputNum) { | |||||
| auto on_value_tensor = in_tensors_.at(2); | |||||
| if (on_value_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const float *on_value = static_cast<float *>(on_value_tensor->MutableData()); | |||||
| if (on_value == nullptr) { | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| one_hot_param->on_value_ = *on_value; | |||||
| auto off_value_tensor = in_tensors_.at(3); | |||||
| if (off_value_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "OneHot inputs[3] off_value nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const float *off_value = static_cast<float *>(off_value_tensor->MutableData()); | |||||
| if (off_value == nullptr) { | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| one_hot_param->off_value_ = *off_value; | |||||
| } else { | |||||
| auto off_on_tensor = in_tensors_.at(2); | |||||
| if (off_on_tensor == nullptr) { | |||||
| MS_LOG(ERROR) << "OneHot inputs[2] on_value nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| const int64_t *off_on_values = static_cast<int64_t *>(off_on_tensor->MutableData()); | |||||
| if (off_on_values == nullptr) { | |||||
| MS_LOG(ERROR) << "OneHot input[2] data is nullptr"; | |||||
| return RET_NULL_PTR; | |||||
| } | |||||
| one_hot_param->off_value_ = static_cast<float>(off_on_values[0]); | |||||
| one_hot_param->on_value_ = static_cast<float>(off_on_values[1]); | |||||
| } | } | ||||
| one_hot_param->off_value_ = *off_value; | |||||
| one_hot_param->outer_size_ = outer_size_; | one_hot_param->outer_size_ = outer_size_; | ||||
| one_hot_param->inner_size_ = inner_size_; | one_hot_param->inner_size_ = inner_size_; | ||||