|
|
|
@@ -58,6 +58,10 @@ int ArithmeticCompareCPUKernel::Init() { |
|
|
|
int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; } |
|
|
|
|
|
|
|
int ArithmeticCompareCPUKernel::DoExecute(int task_id) { |
|
|
|
if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) { |
|
|
|
MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! "; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
int elements_num = in_tensors_.at(0)->ElementsNum(); |
|
|
|
int stride = UP_DIV(elements_num, op_parameter_->thread_num_); |
|
|
|
int offset = task_id * stride; |
|
|
|
|