Browse Source

!9009 fix onehot 3 inputs convert fail

From: @zhaozhenlong
Reviewed-by: @zhang_xue_tong,@zhanghaibo5
Signed-off-by: @zhang_xue_tong
tags/v1.1.0
mindspore-ci-bot Gitee 5 years ago
parent
commit
67265bf677
4 changed files with 13 additions and 7 deletions
  1. +6
    -3
      mindspore/lite/src/ops/one_hot.cc
  2. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc
  3. +5
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc
  4. +1
    -1
      mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc

+ 6
- 3
mindspore/lite/src/ops/one_hot.cc View File

@@ -82,7 +82,8 @@ Registry OneHotRegistry(schema::PrimitiveType_OneHot, OneHotCreator);

namespace {
constexpr size_t kOneHotInputNum = 4;
}
constexpr size_t kOneHotInputNumOpt = 3;
} // namespace
int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outputs) {
if (this->primitive_ == nullptr) {
return RET_NULL_PTR;
@@ -90,8 +91,10 @@ int OneHot::InferShape(std::vector<Tensor *> inputs, std::vector<Tensor *> outpu

int axis = GetAxis();
// indices, depth, on_value, off_value
if (inputs.size() != kOneHotInputNum) {
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum;
// indices, depth, on_off_value(contain 2 values);
if (inputs.size() != kOneHotInputNum && inputs.size() != kOneHotInputNumOpt) {
MS_LOG(ERROR) << "OneHot got inputs num " << inputs.size() << ", should be " << kOneHotInputNum << " or "
<< kOneHotInputNumOpt;
return RET_ERROR;
}
auto depth_tensor = inputs.at(1);


+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/fp32/space_to_batch_fp32.cc View File

@@ -43,7 +43,7 @@ int SpaceToBatchCPUKernel::ReSize() {
MS_ASSERT(input_tensor);
auto output_tensor = out_tensors_.at(0);
MS_ASSERT(output_tensor);
MS_ASSERT(param);
MS_ASSERT(param_);
for (size_t i = 0; i < DIMENSION_4D; i++) {
param_->input_shape_[i] = input_tensor->shape().at(i);
param_->output_shape_[i] = output_tensor->shape().at(i);


+ 5
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/squeeze_fp32.cc View File

@@ -34,15 +34,18 @@ int SqueezeCPUKernel::ReSize() { return RET_OK; }
int SqueezeCPUKernel::Run() {
mindspore::lite::STATUS ret = RET_ERROR;
size_t data_size = in_tensors_.front()->Size();
MS_ASSERT(input_ptr);
MS_ASSERT(output_ptr);

if (in_tensors_.front()->data_type() == kNumberTypeInt32) {
auto input_ptr = reinterpret_cast<int32_t *>(in_tensors_.front()->MutableData());
auto output_ptr = reinterpret_cast<int32_t *>(out_tensors_.front()->MutableData());
MS_ASSERT(input_ptr);
MS_ASSERT(output_ptr);
ret = DoSqueezeInt32(input_ptr, output_ptr, data_size);
} else {
auto input_ptr = reinterpret_cast<float *>(in_tensors_.front()->MutableData());
auto output_ptr = reinterpret_cast<float *>(out_tensors_.front()->MutableData());
MS_ASSERT(input_ptr);
MS_ASSERT(output_ptr);
ret = DoSqueeze(input_ptr, output_ptr, data_size);
}



+ 1
- 1
mindspore/lite/src/runtime/kernel/arm/int8/squeeze_int8.cc View File

@@ -61,7 +61,7 @@ int SqueezeInt8CPUKernel::Init() {
return RET_ERROR;
}
auto in_quant_args = in_tensors_.front()->quant_params();
MS_ASSERT(quant_args.size() > 0);
MS_ASSERT(in_quant_args.size() > 0);
quant_squeeze_param_->in_quant_args_->scale_ = in_quant_args.front().scale;
quant_squeeze_param_->in_quant_args_->zp_ = in_quant_args.front().zeroPoint;



Loading…
Cancel
Save