Browse Source

fix lite train resize issue

feature/build-system-rewrite
zhengjun10 4 years ago
parent
commit
4855e427d6
2 changed files with 6 additions and 1 deletions
  1. +5
    -0
      mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c
  2. +1
    -1
      mindspore/lite/test/config/models_ms_train.cfg

+ 5
- 0
mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/common_infer.c View File

@@ -385,6 +385,11 @@ int CommonGradInferShape(const TensorC *const *inputs, size_t inputs_size, Tenso
return NNACL_INFER_INVALID;
}
MS_CHECK_TRUE_RET(inputs[0]->shape_size_ == inputs[1]->shape_size_, NNACL_ERR);
for (int i = 0; i < inputs[0]->shape_size_; i++) {
if (inputs[0]->shape_[i] != inputs[1]->shape_[i]) {
return NNACL_ERR;
}
}
SetShapeTensor(outputs[0], inputs[0]);
return NNACL_OK;
}


+ 1
- 1
mindspore/lite/test/config/models_ms_train.cfg View File

@@ -45,4 +45,4 @@ unified_api code_example
train_lenet code_example
train_lenet_java code_example
# LAST
test_resize inputShapes 16,10,10,1:16,10,10,1 0.5
#test_resize inputShapes 16,10,10,1:16,10,10,1 0.5

Loading…
Cancel
Save