Browse Source

fix gpu shape bug

pull/14118/head
yeyunpeng2020 4 years ago
parent
commit
12ea873ee0
3 changed files with 8 additions and 6 deletions
  1. +5
    -6
      mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc
  2. +2
    -0
      mindspore/lite/test/models_gpu_fp32.cfg
  3. +1
    -0
      mindspore/lite/tools/cropper/build_cropper_config.sh

+ 5
- 6
mindspore/lite/src/runtime/kernel/opencl/kernel/fill.cc View File

@@ -55,13 +55,12 @@ int FillOpenCLKernel::RunShape() {
auto tensor_shape = in_tensors_[0]->shape(); auto tensor_shape = in_tensors_[0]->shape();
void *tensor_shape_data = tensor_shape.data(); void *tensor_shape_data = tensor_shape.data();
for (int i = 0; i < tensor_shape.size(); ++i) { for (int i = 0; i < tensor_shape.size(); ++i) {
fill_value.s[0] = reinterpret_cast<float *>(tensor_shape_data)[i];
size_t index = static_cast<size_t>(i);
auto src_origin = cl::array<cl::size_type, 3U>{0, index, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
fill_value.s[i] = reinterpret_cast<float *>(tensor_shape_data)[i];
} }
auto src_origin = cl::array<cl::size_type, 3U>{0, 0, 0};
auto region = cl::array<cl::size_type, 3U>{1, 1, 1};
cl::Image2D *out_image = reinterpret_cast<cl::Image2D *>(allocator_->GetImage(src_data));
ocl_runtime_->GetDefaultCommandQueue()->enqueueFillImage(*out_image, fill_value, src_origin, region);
return RET_OK; return RET_OK;
} }




+ 2
- 0
mindspore/lite/test/models_gpu_fp32.cfg View File

@@ -23,3 +23,5 @@ landmark
PoseNet_dla_17_x512 PoseNet_dla_17_x512
age_new age_new
plat_isface plat_isface
Q_hand_0812.pb
Q_dila-small-mix-full-fineturn-390000-nopixel-nosigmoid.pb

+ 1
- 0
mindspore/lite/tools/cropper/build_cropper_config.sh View File

@@ -190,6 +190,7 @@ generateOpsList
getCommonFile getCommonFile
# get src/ops # get src/ops
getOpsFile "Registry\(schema::PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/ops" "prototype" & getOpsFile "Registry\(schema::PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/ops" "prototype" &
getOpsFile "REG_POPULATE\(PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/ops" "prototype" &
getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_HOME}/mindspore/lite/nnacl/infer" "prototype" & getOpsFile "REG_INFER\(.*?, PrimType_" "${MINDSPORE_HOME}/mindspore/lite/nnacl/infer" "prototype" &
getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat32, PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/runtime/kernel/arm" "kNumberTypeFloat32" & getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat32, PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/runtime/kernel/arm" "kNumberTypeFloat32" &
getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat16, PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/runtime/kernel/arm" "kNumberTypeFloat16" & getOpsFile "REG_KERNEL\(.*?, kNumberTypeFloat16, PrimitiveType_" "${MINDSPORE_HOME}/mindspore/lite/src/runtime/kernel/arm" "kNumberTypeFloat16" &


Loading…
Cancel
Save