|
|
|
@@ -111,15 +111,15 @@ int Scheduler::InferShape(const lite::Model *model, std::vector<Tensor *> *tenso |
|
|
|
MS_LOG(ERROR) << "Op " << node->name_ << " should exist in model!"; |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
STATUS ret = RET_INFER_INVALID; |
|
|
|
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](Tensor *tensor) { |
|
|
|
bool infer_valid = std::all_of(inputs.begin(), inputs.end(), [](const Tensor *tensor) { |
|
|
|
auto shape = tensor->shape(); |
|
|
|
return std::all_of(shape.begin(), shape.end(), [](int dim) { return dim != -1; }); |
|
|
|
return std::all_of(shape.begin(), shape.end(), [](const int dim) { return dim != -1; }); |
|
|
|
}); |
|
|
|
if (infer_valid) { |
|
|
|
primitive->set_infer_flag(!infer_shape_interrupt); |
|
|
|
ret = primitive->InferShape(inputs, outputs); |
|
|
|
if (!infer_valid) { |
|
|
|
infer_shape_interrupt = true; |
|
|
|
} |
|
|
|
primitive->set_infer_flag(!infer_shape_interrupt); |
|
|
|
auto ret = primitive->InferShape(inputs, outputs); |
|
|
|
if (ret == RET_INFER_INVALID) { |
|
|
|
MS_LOG(INFO) << "InferShape shouldn't be done before runtime, name: " << node->name_ |
|
|
|
<< ", type: " << schema::EnumNamePrimitiveType(static_cast<schema::PrimitiveType>(primitive->Type())) |
|
|
|
|