Browse Source

[CPU]fix accuracy problem for CheckValid

tags/v1.6.0
lianghao23 4 years ago
parent
commit
605e90e2d7
1 changed files with 13 additions and 7 deletions
  1. +13
    -7
      mindspore/ccsrc/backend/kernel_compiler/cpu/check_valid_cpu_kernel.cc

+ 13
- 7
mindspore/ccsrc/backend/kernel_compiler/cpu/check_valid_cpu_kernel.cc View File

@@ -24,6 +24,9 @@ namespace kernel {
namespace {
constexpr size_t kInputSize = 2;
constexpr size_t kOutputSize = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
} // namespace

template <typename T>
@@ -45,12 +48,15 @@ bool CheckValidCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &input
auto output = reinterpret_cast<bool *>(outputs[0]->addr);
const size_t elem_num = inputs[0]->size / sizeof(T) / COORDINATE;

auto task = [this, &anchor_box, &img_metas, &output](size_t start, size_t end) {
const double offset = 1.0;
auto height = static_cast<double>(img_metas[kIndex0]);
auto width = static_cast<double>(img_metas[kIndex1]);
auto ratio = static_cast<double>(img_metas[kIndex2]);
auto img_width_x = width * ratio - offset;
auto img_height_y = height * ratio - offset;

auto task = [this, &anchor_box, &img_width_x, &img_height_y, &output](size_t start, size_t end) {
const T ZERO = static_cast<T>(0);
const T ONE = static_cast<T>(1);
constexpr size_t OFFSET_ZERO = 0;
constexpr size_t OFFSET_ONE = 1;
constexpr size_t OFFSET_TWO = 2;
for (size_t i = start; i < end; i++) {
const size_t left_x = i * 4;
const size_t left_y = i * 4 + 1;
@@ -60,8 +66,8 @@ bool CheckValidCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &input
bool valid_flag = false;
valid_flag |= std::less<T>()(anchor_box[left_x], ZERO);
valid_flag |= std::less<T>()(anchor_box[left_y], ZERO);
valid_flag |= std::less<T>()(img_metas[OFFSET_ONE] * img_metas[OFFSET_TWO] - ONE, anchor_box[right_x]);
valid_flag |= std::less<T>()(img_metas[OFFSET_ZERO] * img_metas[OFFSET_TWO] - ONE, anchor_box[right_y]);
valid_flag |= std::less<double>()(img_width_x, static_cast<double>(anchor_box[right_x]));
valid_flag |= std::less<double>()(img_height_y, static_cast<double>(anchor_box[right_y]));

output[i] = !valid_flag;
}


Loading…
Cancel
Save