|
|
|
@@ -15,13 +15,15 @@ |
|
|
|
*/ |
|
|
|
|
|
|
|
#include <vector> |
|
|
|
#include "ops/while.h" |
|
|
|
#include "ops/op_utils.h" |
|
|
|
#include "tools/converter/ops/while.h" |
|
|
|
#include "utils/check_convert_utils.h" |
|
|
|
#include "abstract/primitive_infer_map.h" |
|
|
|
|
|
|
|
constexpr auto kCondSubgraphIndex = "cond_subgraph_index"; |
|
|
|
constexpr auto kBodySubgraphIndex = "body_subgraph_index"; |
|
|
|
|
|
|
|
namespace mindspore { |
|
|
|
namespace ops { |
|
|
|
namespace lite { |
|
|
|
void While::Init(const int64_t cond_subgraph_index, const int64_t body_subgraph_index) { |
|
|
|
this->set_cond_subgraph_index(cond_subgraph_index); |
|
|
|
this->set_body_subgraph_index(body_subgraph_index); |
|
|
|
@@ -44,6 +46,7 @@ int64_t While::get_body_subgraph_index() const { |
|
|
|
auto value_ptr = this->GetAttr(kBodySubgraphIndex); |
|
|
|
return GetValue<int64_t>(value_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, |
|
|
|
const std::vector<AbstractBasePtr> &input_args) { |
|
|
|
MS_EXCEPTION_IF_NULL(primitive); |
|
|
|
@@ -58,6 +61,5 @@ AbstractBasePtr WhileInfer(const abstract::AnalysisEnginePtr &, const PrimitiveP |
|
|
|
} |
|
|
|
return std::make_shared<abstract::AbstractTuple>(output); |
|
|
|
} |
|
|
|
REGISTER_PRIMITIVE_C(kNameWhile, While); |
|
|
|
} // namespace ops |
|
|
|
} // namespace lite |
|
|
|
} // namespace mindspore |