From: @zhaozhenlong Reviewed-by: @zhang_xue_tong,@zhanghaibo5 Signed-off-by: @zhang_xue_tongtags/v1.1.0
| @@ -82,7 +82,8 @@ Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator); | |||||
| namespace { | namespace { | ||||
| constexpr size_t kOneHotInputNum = 4; | constexpr size_t kOneHotInputNum = 4; | ||||
| } | |||||
| constexpr size_t kOneHotInputNumOpt = 3; | |||||
| } // namespace | |||||
| int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) { | ||||
| if (this->primitive_ == nullptr) { | if (this->primitive_ == nullptr) { | ||||
| return RET_NULL_PTR; | return RET_NULL_PTR; | ||||
| @@ -90,8 +91,10 @@ int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outpu | |||||
| int axis = GetAxis(); | int axis = GetAxis(); | ||||
| // indices, depth, on_value, off_value | // indices, depth, on_value, off_value | ||||
| if (inputs.size() != kOneHotInputNum) { | |||||
| MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum; | |||||
| // indices, depth, on_off_value(contain 2 values); | |||||
| if (inputs.size() != kOneHotInputNum && inputs.size() != kOneHotInputNumOpt) { | |||||
| MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum << " or " | |||||
| << kOneHotInputNumOpt; | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto depth_tensor = inputs.at(1); | auto depth_tensor = inputs.at(1); | ||||
| @@ -43,7 +43,7 @@ int SpaceToBatchCPUKernel::ReSize() { | |||||
| MS_ASSERT(input_tensor); | MS_ASSERT(input_tensor); | ||||
| auto output_tensor = out_tensors_.at(0); | auto output_tensor = out_tensors_.at(0); | ||||
| MS_ASSERT(output_tensor); | MS_ASSERT(output_tensor); | ||||
| MS_ASSERT(param); | |||||
| MS_ASSERT(param_); | |||||
| for (size_t i = 0; i < DIMENSION_4D; i++) { | for (size_t i = 0; i < DIMENSION_4D; i++) { | ||||
| param_->input_shape_[i] = input_tensor->shape().at(i); | param_->input_shape_[i] = input_tensor->shape().at(i); | ||||
| param_->output_shape_[i] = output_tensor->shape().at(i); | param_->output_shape_[i] = output_tensor->shape().at(i); | ||||
| @@ -34,15 +34,18 @@ int SqueezeCPUKernel::ReSize() { return RET_OK; } | |||||
| int SqueezeCPUKernel::Run() { | int SqueezeCPUKernel::Run() { | ||||
| mindspore::lite::STATUS ret = RET_ERROR; | mindspore::lite::STATUS ret = RET_ERROR; | ||||
| size_t data_size = in_tensors_.front()->Size(); | size_t data_size = in_tensors_.front()->Size(); | ||||
| MS_ASSERT(input_ptr); | |||||
| MS_ASSERT(output_ptr); | |||||
| if (in_tensors_.front()->data_type() == kNumberTypeInt32) { | if (in_tensors_.front()->data_type() == kNumberTypeInt32) { | ||||
| auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->MutableData()); | auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData()); | auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData()); | ||||
| MS_ASSERT(input_ptr); | |||||
| MS_ASSERT(output_ptr); | |||||
| ret = DoSqueezeInt32(input_ptr, output_ptr, data_size); | ret = DoSqueezeInt32(input_ptr, output_ptr, data_size); | ||||
| } else { | } else { | ||||
| auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->MutableData()); | auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->MutableData()); | ||||
| auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->MutableData()); | auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->MutableData()); | ||||
| MS_ASSERT(input_ptr); | |||||
| MS_ASSERT(output_ptr); | |||||
| ret = DoSqueeze(input_ptr, output_ptr, data_size); | ret = DoSqueeze(input_ptr, output_ptr, data_size); | ||||
| } | } | ||||
| @@ -61,7 +61,7 @@ int SqueezeInt8CPUKernel::Init() { | |||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| auto in_quant_args = in_tensors_.front()->quant_params(); | auto in_quant_args = in_tensors_.front()->quant_params(); | ||||
| MS_ASSERT(quant_args.size() > 0); | |||||
| MS_ASSERT(in_quant_args.size() > 0); | |||||
| quant_squeeze_param_->in_quant_args_->scale_ = in_quant_args.front().scale; | quant_squeeze_param_->in_quant_args_->scale_ = in_quant_args.front().scale; | ||||
| quant_squeeze_param_->in_quant_args_->zp_ = in_quant_args.front().zeroPoint; | quant_squeeze_param_->in_quant_args_->zp_ = in_quant_args.front().zeroPoint; | ||||