|
|
|
@@ -456,9 +456,7 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) { |
|
|
|
} |
|
|
|
|
|
|
|
int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_; |
|
|
|
if (src_a.data_type_ == LDataType::BOOL) { |
|
|
|
DivideImpl<bool>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::INT8) { |
|
|
|
if (src_a.data_type_ == LDataType::INT8) { |
|
|
|
DivideImpl<int8_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::UINT8) { |
|
|
|
DivideImpl<uint8_t>(src_a, src_b, *dst, total_size); |
|
|
|
@@ -484,5 +482,102 @@ bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) { |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) { |
|
|
|
for (size_t i = 0; i < total_size; i++) { |
|
|
|
dst[i] = src0[i] * src1[i]; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) { |
|
|
|
int64_t x = 0; |
|
|
|
#ifdef USE_NEON |
|
|
|
const int64_t step = 32; |
|
|
|
for (; x <= total_size - step; x += step) { |
|
|
|
uint8x16_t v_src00 = vld1q_u8(src0 + x); |
|
|
|
uint8x16_t v_src01 = vld1q_u8(src0 + x + 16); |
|
|
|
uint8x16_t v_src10 = vld1q_u8(src1 + x); |
|
|
|
uint8x16_t v_src11 = vld1q_u8(src1 + x + 16); |
|
|
|
uint8x16_t v_dst_l, v_dst_h; |
|
|
|
|
|
|
|
v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10)); |
|
|
|
v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10)); |
|
|
|
vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h))); |
|
|
|
|
|
|
|
v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11)); |
|
|
|
v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11)); |
|
|
|
vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h))); |
|
|
|
} |
|
|
|
#endif |
|
|
|
for (; x < total_size; x++) { |
|
|
|
int32_t val = src0[x] * src1[x]; |
|
|
|
dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(), |
|
|
|
std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) { |
|
|
|
for (size_t i = 0; i < total_size; i++) { |
|
|
|
int32_t val = src0[i] * src1[i]; |
|
|
|
dst[i] = std::max<int32_t>(std::numeric_limits<uint16_t>::min(), |
|
|
|
std::min<int32_t>(std::numeric_limits<uint16_t>::max(), val)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <> |
|
|
|
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) { |
|
|
|
for (size_t i = 0; i < total_size; i++) { |
|
|
|
int64_t val = src0[i] * src1[i]; |
|
|
|
dst[i] = std::max<int64_t>(std::numeric_limits<uint32_t>::min(), |
|
|
|
std::min<int64_t>(std::numeric_limits<uint32_t>::max(), val)); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) { |
|
|
|
if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (src_a.data_type_ != src_b.data_type_) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
if (dst->IsEmpty()) { |
|
|
|
dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_); |
|
|
|
} else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) { |
|
|
|
return false; |
|
|
|
} else if (src_a.data_type_ != dst->data_type_) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_; |
|
|
|
if (src_a.data_type_ == LDataType::INT8) { |
|
|
|
MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::UINT8) { |
|
|
|
MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::INT16) { |
|
|
|
MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::UINT16) { |
|
|
|
MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::INT32) { |
|
|
|
MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::UINT32) { |
|
|
|
MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::INT64) { |
|
|
|
MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::UINT64) { |
|
|
|
MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::FLOAT32) { |
|
|
|
MultiplyImpl<float>(src_a, src_b, *dst, total_size); |
|
|
|
} else if (src_a.data_type_ == LDataType::FLOAT64) { |
|
|
|
MultiplyImpl<double>(src_a, src_b, *dst, total_size); |
|
|
|
} else { |
|
|
|
return false; |
|
|
|
} |
|
|
|
return true; |
|
|
|
} |
|
|
|
|
|
|
|
} // namespace dataset |
|
|
|
} // namespace mindspore |