|
|
|
@@ -35,40 +35,49 @@ constexpr size_t kOutputNum = 1; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
int OneHotCPUKernel::Init() { |
|
|
|
if (context_->infer_shape_interrupt_ && !context_->running_) { |
|
|
|
set_need_reinit(); |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
// indices depth on_value off_value |
|
|
|
if (in_tensors_.size() != kInputNum || out_tensors_.size() != kOutputNum) { |
|
|
|
MS_LOG(ERROR) << "OneHot input size should be " << kInputNum << ", got " << in_tensors_.size() |
|
|
|
<< ", output size should be" << kOutputNum << ", got " << out_tensors_.size(); |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
if (context_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "OneHot context nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
thread_num_ = context_->thread_num_; |
|
|
|
|
|
|
|
auto param = reinterpret_cast<OneHotParameter *>(op_parameter_); |
|
|
|
if (param == nullptr) { |
|
|
|
MS_LOG(ERROR) << "OneHot op_parameter_ nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
axis_ = param->axis_; |
|
|
|
|
|
|
|
if (!InferShapeDone()) { |
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
return ReSize(); |
|
|
|
} |
|
|
|
|
|
|
|
int OneHotCPUKernel::ReSize() { |
|
|
|
auto indices = in_tensors_.at(0); |
|
|
|
if (indices == nullptr) { |
|
|
|
MS_LOG(ERROR) << "OneHot inputs[0] indices nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
auto indices_shape = indices->shape(); |
|
|
|
const int indices_rank = static_cast<int>(indices_shape.size()); |
|
|
|
if (axis_ < 0) { |
|
|
|
axis_ += indices_rank + 1; |
|
|
|
} |
|
|
|
|
|
|
|
outer_size_ = 1; |
|
|
|
for (size_t i = 0; i < static_cast<size_t>(axis_); i++) { |
|
|
|
outer_size_ *= indices_shape[i]; |
|
|
|
} |
|
|
|
inner_size_ = indices->ElementsNum() / outer_size_; |
|
|
|
|
|
|
|
if (context_ == nullptr) { |
|
|
|
MS_LOG(ERROR) << "OneHot context nullptr"; |
|
|
|
return RET_NULL_PTR; |
|
|
|
} |
|
|
|
thread_num_ = context_->thread_num_; |
|
|
|
|
|
|
|
const int indices_rank = static_cast<int>(in_tensors_.at(0)->shape().size()); |
|
|
|
if (axis_ < 0) { |
|
|
|
axis_ += indices_rank + 1; |
|
|
|
} |
|
|
|
|
|
|
|
return RET_OK; |
|
|
|
} |
|
|
|
|
|
|
|
|