Browse Source

modify dynamic shape condition to fit ME

tags/v1.1.0
wilfChen 5 years ago
parent
commit
07d8622c7e
2 changed files with 6 additions and 6 deletions
  1. +3
    -3
      mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
  2. +3
    -3
      mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc

+ 3
- 3
mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h View File

@@ -173,9 +173,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
BroadcastOpType op_type_;
bool need_broadcast_;
bool is_comp_op_;
int input1_num_;
int input2_num_;
int output_num_;
size_t input1_num_;
size_t input2_num_;
size_t output_num_;
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;


+ 3
- 3
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc View File

@@ -47,18 +47,18 @@ constexpr size_t kNopNodeRealInputIndex = 1;

bool IsShapeDynamic(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
return std::any_of(shape->shape().begin(), shape->shape().end(), [](int64_t s) { return s < 0; });
}

bool IsShapeDynamic(const std::vector<size_t> &shape) {
return std::any_of(shape.begin(), shape.end(), [](int s) { return s < 0; });
return std::any_of(shape.begin(), shape.end(), [](int64_t s) { return s < 0; });
}

std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
MS_EXCEPTION_IF_NULL(shape);
std::vector<size_t> shape_size_t;
if (IsShapeDynamic(shape)) {
if (std::all_of(shape->max_shape().begin(), shape->max_shape().end(), [](int s) { return s >= 0; })) {
if (std::all_of(shape->max_shape().begin(), shape->max_shape().end(), [](int64_t s) { return s >= 0; })) {
std::transform(shape->max_shape().begin(), shape->max_shape().end(), std::back_inserter(shape_size_t),
LongToSize);
} else {


Loading…
Cancel
Save