| @@ -32,16 +32,10 @@ using mindspore::schema::PrimitiveType_Power; | |||||
| namespace mindspore::kernel { | namespace mindspore::kernel { | ||||
| int PowerOpenCLKernel::CheckSpecs() { | int PowerOpenCLKernel::CheckSpecs() { | ||||
| auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_); | |||||
| broadcast_ = param->broadcast_; | |||||
| if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) { | if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) { | ||||
| MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size(); | MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size(); | ||||
| return RET_ERROR; | return RET_ERROR; | ||||
| } | } | ||||
| if (in_tensors_.size() == 1 && !broadcast_) { | |||||
| MS_LOG(ERROR) << "broadcast is supported when in_tensors_.size() == 1 "; | |||||
| return RET_ERROR; | |||||
| } | |||||
| if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) { | if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) { | ||||
| MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size() | MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size() | ||||
| << "!=" << in_tensors_.at(1)->shape().size(); | << "!=" << in_tensors_.at(1)->shape().size(); | ||||
| @@ -143,12 +137,14 @@ void PowerOpenCLKernel::SetGlobalLocal() { | |||||
| } | } | ||||
| int PowerOpenCLKernel::Prepare() { | int PowerOpenCLKernel::Prepare() { | ||||
| if (in_tensors_.size() == 1) { | |||||
| broadcast_ = true; | |||||
| } | |||||
| use_fp16_enable_ = ocl_runtime_->GetFp16Enable(); | use_fp16_enable_ = ocl_runtime_->GetFp16Enable(); | ||||
| auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_); | auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_); | ||||
| std::string kernel_name = "power"; | std::string kernel_name = "power"; | ||||
| std::string source = power_source; | std::string source = power_source; | ||||
| std::string program_name = "power"; | std::string program_name = "power"; | ||||
| broadcast_ = param->broadcast_; | |||||
| if (broadcast_ && in_tensors_.size() == 1) { | if (broadcast_ && in_tensors_.size() == 1) { | ||||
| power_ = param->power_; | power_ = param->power_; | ||||
| kernel_name += "_broadcast"; | kernel_name += "_broadcast"; | ||||