Browse Source

fix some problems for power

tags/v1.1.0
Pengyongrong 5 years ago
parent
commit
d97bd37f57
1 changed files with 3 additions and 7 deletions
  1. +3
    -7
      mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc

+ 3
- 7
mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc View File

@@ -32,16 +32,10 @@ using mindspore::schema::PrimitiveType_Power;
namespace mindspore::kernel {

int PowerOpenCLKernel::CheckSpecs() {
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
broadcast_ = param->broadcast_;
if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size();
return RET_ERROR;
}
if (in_tensors_.size() == 1 && !broadcast_) {
MS_LOG(ERROR) << "broadcast is supported when in_tensors_.size() == 1 ";
return RET_ERROR;
}
if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) {
MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size()
<< "!=" << in_tensors_.at(1)->shape().size();
@@ -143,12 +137,14 @@ void PowerOpenCLKernel::SetGlobalLocal() {
}

int PowerOpenCLKernel::Prepare() {
if (in_tensors_.size() == 1) {
broadcast_ = true;
}
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
std::string kernel_name = "power";
std::string source = power_source;
std::string program_name = "power";
broadcast_ = param->broadcast_;
if (broadcast_ && in_tensors_.size() == 1) {
power_ = param->power_;
kernel_name += "_broadcast";


Loading…
Cancel
Save