Browse Source

Support exponent tensor broadcast for power op

tags/v1.0.0
zhanyuan 5 years ago
parent
commit
afcb3e9b45
2 changed files with 5 additions and 3 deletions
  1. +3
    -1
      mindspore/lite/src/ops/power.cc
  2. +2
    -2
      mindspore/lite/src/runtime/kernel/arm/fp32/power.cc

+ 3
- 1
mindspore/lite/src/ops/power.cc View File

@@ -64,7 +64,9 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
return RET_OK;
}
if (exp_tensor != nullptr) {
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
if ((exp_tensor->shape().size() > 1 && exp_tensor->shape() != x_tensor->shape()) ||
(exp_tensor->shape().size() == 1 && exp_tensor->shape()[0] != 1) ||
exp_tensor->data_type() != x_tensor->data_type()) {
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
return RET_INPUT_TENSOR_ERROR;
}


+ 2
- 2
mindspore/lite/src/runtime/kernel/arm/fp32/power.cc View File

@@ -64,11 +64,11 @@ int PowerCPUKernel::RunImpl(int task_id) {
bool broadcast = true;
if (in_tensors_.size() == 2) {
exp_addr = reinterpret_cast<float *>(in_tensors_[1]->Data());
broadcast = false;
broadcast = in_tensors_[0]->shape() == in_tensors_[1]->shape() ? false : true;
}
float *cur_exp = nullptr;
if (broadcast) {
cur_exp = &power_;
cur_exp = in_tensors_.size() == 2 ? exp_addr : &power_;
} else {
cur_exp = exp_addr + stride * task_id;
}


Loading…
Cancel
Save