|
|
@@ -31,7 +31,7 @@ namespace mindspore::kernel { |
|
|
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, |
|
|
int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, |
|
|
int out_thread_stride) { |
|
|
int out_thread_stride) { |
|
|
if (dim > break_pos_) { |
|
|
if (dim > break_pos_) { |
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt) { |
|
|
|
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) { |
|
|
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride, |
|
|
return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride, |
|
|
reinterpret_cast<int *>(input1) + out_thread_stride, |
|
|
reinterpret_cast<int *>(input1) + out_thread_stride, |
|
|
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count); |
|
|
reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count); |
|
|
@@ -44,7 +44,7 @@ int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *o |
|
|
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i; |
|
|
int pos0_ = param_->in_shape0_[dim] == 1 ? 0 : i; |
|
|
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; |
|
|
int pos1_ = param_->in_shape1_[dim] == 1 ? 0 : i; |
|
|
int error_code; |
|
|
int error_code; |
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt) { |
|
|
|
|
|
|
|
|
if (in_tensors_[0]->data_type() == kNumberTypeInt || in_tensors_[0]->data_type() == kNumberTypeInt32) { |
|
|
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim], |
|
|
error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * param_->in_strides0_[dim], |
|
|
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim], |
|
|
reinterpret_cast<int *>(input1) + pos1_ * param_->in_strides1_[dim], |
|
|
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count, |
|
|
reinterpret_cast<uint8_t *>(output) + i * param_->out_strides_[dim], dim + 1, out_count, |
|
|
|