Merge pull request !1928 from yoonlee666/deletereshapetags/v0.5.0-beta
| @@ -50,11 +50,15 @@ class ReshapeSameShapeEliminater : public AnfVisitor { | |||||
| } | } | ||||
| auto src_shape = src_shape_abs->GetShapeTrack(); | auto src_shape = src_shape_abs->GetShapeTrack(); | ||||
| auto tgt_shape = GetValueNode(shape_); | |||||
| if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa<Shape>()) { | |||||
| auto elements = GetValue<std::vector<int>>(tgt_shape); | |||||
| auto tgt_shape_abs = node->abstract(); | |||||
| if (tgt_shape_abs == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto tgt_shape = tgt_shape_abs->GetShapeTrack(); | |||||
| if (src_shape != nullptr && tgt_shape != nullptr && src_shape->isa<Shape>() && tgt_shape->isa<Shape>()) { | |||||
| auto elements = tgt_shape->cast<ShapePtr>(); | |||||
| auto shape = src_shape->cast<ShapePtr>(); | auto shape = src_shape->cast<ShapePtr>(); | ||||
| if (shape->shape() == elements) { | |||||
| if (shape->shape() == elements->shape()) { | |||||
| return x_; | return x_; | ||||
| } | } | ||||
| } | } | ||||
| @@ -219,6 +219,7 @@ TEST_F(TestOptLib, test_elim_reshape_same_shape) { | |||||
| tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); | tensor::TensorPtr x_tensor = std::make_shared<tensor::Tensor>(kFloat32->type_id(), shp); | ||||
| auto x_abstract = x_tensor->ToAbstract(); | auto x_abstract = x_tensor->ToAbstract(); | ||||
| x_node->set_abstract(x_abstract); | x_node->set_abstract(x_abstract); | ||||
| before->output()->set_abstract(x_abstract); | |||||
| } | } | ||||
| auto patterns = std::vector<SubstitutionPtr>({irpass.reshape_eliminate_}); | auto patterns = std::vector<SubstitutionPtr>({irpass.reshape_eliminate_}); | ||||
| ASSERT_TRUE(CheckOpt(before, after, patterns)); | ASSERT_TRUE(CheckOpt(before, after, patterns)); | ||||