| @@ -50,6 +50,7 @@ | |||
| #include "frontend/optimizer/irpass/call_graph_tuple_transform.h" | |||
| #include "frontend/optimizer/irpass/recompute_prepare.h" | |||
| #include "frontend/optimizer/irpass/real_op_eliminate.h" | |||
| #include "frontend/optimizer/irpass/ge_tensor_array.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| @@ -273,6 +274,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() { | |||
| // Workaround | |||
| stop_gradient_special_op_ = | |||
| MakeSubstitution(std::make_shared<StopGradientSpecialOp>(), "stop_gradient_special_op", prim::kPrimBiasAddGrad); | |||
| // ge_tensor_array_link_flow | |||
| ge_tensor_array_add_flow_ = MakeSubstitution(std::make_shared<GeTensorArrayAddFlow>(), "ge_tensor_array_add_flow", | |||
| {prim::kPrimTensorArrayWrite, prim::kPrimTensorArrayGather}); | |||
| // ge_tensor_array_cast_index | |||
| ge_tensor_array_cast_index_ = MakeSubstitution(std::make_shared<GeTensorArrayCastIndex>(), | |||
| "ge_tensor_array_cast_index", prim::kPrimTensorArrayWrite); | |||
| } | |||
| ResolveIRPassLib::ResolveIRPassLib() { | |||
| @@ -167,6 +167,10 @@ class OptimizeIRPassLib { | |||
| // Workaround | |||
| SubstitutionPtr stop_gradient_special_op_; | |||
| // ge TensorArray process | |||
| SubstitutionPtr ge_tensor_array_add_flow_; | |||
| SubstitutionPtr ge_tensor_array_cast_index_; | |||
| }; | |||
| // the collection of irpass for resolve action | |||
| @@ -0,0 +1,119 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "frontend/optimizer/irpass/ge_specialized_prepare.h" | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <unordered_map> | |||
| #include "ir/func_graph.h" | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| void GeTensorArrayPrepare::InsertFlowOutputToTA(const std::vector<AnfNodePtr> &all_nodes) { | |||
| FuncGraphPtr root = nullptr; | |||
| if (all_nodes.size() == 0) { | |||
| return; | |||
| } else { | |||
| root = all_nodes[0]->func_graph(); | |||
| } | |||
| for (auto &ta_input_node : all_nodes) { | |||
| if (!ta_input_node->isa<CNode>()) { | |||
| continue; | |||
| } | |||
| auto ta_input_cnode = ta_input_node->cast<CNodePtr>(); | |||
| for (size_t input_index = 0; input_index < ta_input_cnode->inputs().size(); input_index++) { | |||
| auto ta_node = ta_input_cnode->input(input_index); | |||
| if (IsPrimitiveCNode(ta_node, prim::kPrimTensorArray)) { | |||
| auto ta_find = converted_ta_node_.find(ta_node); | |||
| // cached TensorArray node | |||
| if (ta_find != converted_ta_node_.end()) { | |||
| auto new_ta_input_node_input = ta_find->second; | |||
| ta_input_cnode->set_input(input_index, new_ta_input_node_input); | |||
| } else { | |||
| // new a TupleGetItem node and set it's input with TensorArray node and ValueNode(0) | |||
| // set TAInput node input with TupleGetItem node | |||
| int64_t index = 0; | |||
| auto index_value_node = NewValueNode(index); | |||
| auto index_node_abstract = std::make_shared<abstract::AbstractScalar>(index); | |||
| index_value_node->set_abstract(index_node_abstract); | |||
| auto new_tuple_get_cnode = root->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ta_node, index_value_node}); | |||
| auto new_tuple_get_node = new_tuple_get_cnode->cast<AnfNodePtr>(); | |||
| auto tuple_get_node_abstract = ta_node_abstract_cache_[ta_node]; | |||
| new_tuple_get_node->set_abstract(tuple_get_node_abstract); | |||
| converted_ta_node_[ta_node] = new_tuple_get_node; | |||
| ta_input_cnode->set_input(input_index, new_tuple_get_node); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void GeTensorArrayPrepare::TransformTASizeFromAttrToInput(const AnfNodePtr &node) { | |||
| auto ta_node = node->cast<CNodePtr>(); | |||
| int32_t res_size = 0; | |||
| PrimitivePtr prim = GetValueNode<PrimitivePtr>(ta_node->input(0)); | |||
| // get size attr | |||
| if (prim->HasAttr("size")) { | |||
| auto size_value_ptr = prim->GetAttr("size"); | |||
| auto size = GetValue<int64_t>(size_value_ptr); | |||
| res_size = static_cast<int32_t>(size); | |||
| } | |||
| // generate size input | |||
| auto size_node = NewValueNode(MakeValue(res_size)); | |||
| auto node_abstract = std::make_shared<abstract::AbstractScalar>(res_size); | |||
| size_node->set_abstract(node_abstract); | |||
| auto origin_inputs = ta_node->inputs(); | |||
| // set cnode input | |||
| ta_node->add_input(size_node); | |||
| // has monad input | |||
| if (origin_inputs.size() > 1) { | |||
| std::vector<AnfNodePtr> sorted_inputs(origin_inputs); | |||
| sorted_inputs.insert(sorted_inputs.begin() + 1, size_node); | |||
| ta_node->set_inputs(sorted_inputs); | |||
| } | |||
| // get origin abstract | |||
| auto origin_ta_abstract = ta_node->abstract(); | |||
| // new tuple abstract | |||
| std::vector<AbstractBasePtr> abstract_list; | |||
| // push origin abstract | |||
| abstract_list.push_back(origin_ta_abstract); | |||
| // new flow abstract | |||
| float flow_value = 0.0; | |||
| auto flow_abstract = std::make_shared<abstract::AbstractScalar>(flow_value); | |||
| // push flow abstract | |||
| abstract_list.push_back(flow_abstract); | |||
| // cache TensorArray node's abstract | |||
| auto abstract_find = ta_node_abstract_cache_.find(ta_node); | |||
| if (abstract_find == ta_node_abstract_cache_.end()) { | |||
| ta_node_abstract_cache_[ta_node] = ta_node->abstract(); | |||
| } | |||
| // modify TensorArray node output's abstract from Tensor to Tuple | |||
| auto new_ta_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list); | |||
| ta_node->set_abstract(new_ta_abstract); | |||
| } | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,67 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_SPECIALIZED_PREPARE_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_SPECIALIZED_PREPARE_H_ | |||
| #include <vector> | |||
| #include <algorithm> | |||
| #include <unordered_map> | |||
| #include "ir/func_graph.h" | |||
| #include "ir/func_graph_cloner.h" | |||
| #include "frontend/optimizer/optimizer_caller.h" | |||
| #include "ir/pattern_matcher.h" | |||
| #include "frontend/operator/ops.h" | |||
| #include "frontend/optimizer/irpass.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| class GeTensorArrayPrepare { | |||
| public: | |||
| GeTensorArrayPrepare() = default; | |||
| virtual ~GeTensorArrayPrepare() = default; | |||
| bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) { | |||
| AnfNodePtr ret = root->get_return(); | |||
| MS_EXCEPTION_IF_NULL(ret); | |||
| std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret); | |||
| bool change = false; | |||
| for (auto &node : all_nodes) { | |||
| if (IsPrimitiveCNode(node, prim::kPrimTensorArray)) { | |||
| TransformTASizeFromAttrToInput(node); | |||
| change = true; | |||
| } | |||
| } | |||
| if (change) { | |||
| InsertFlowOutputToTA(all_nodes); | |||
| } | |||
| return change; | |||
| } | |||
| private: | |||
| // Add a const input with value `size` to TensorArray node | |||
| void TransformTASizeFromAttrToInput(const AnfNodePtr &node); | |||
| void InsertFlowOutputToTA(const std::vector<AnfNodePtr> &all_nodes); | |||
| std::unordered_map<AnfNodePtr, AnfNodePtr> converted_ta_node_; | |||
| std::unordered_map<AnfNodePtr, AbstractBasePtr> ta_node_abstract_cache_; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_SPECIALIZED_PREPARE_H_ | |||
| @@ -0,0 +1,121 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_TENSOR_ARRAY_H_ | |||
| #define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_TENSOR_ARRAY_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <algorithm> | |||
| #include "frontend/optimizer/irpass.h" | |||
| #include "frontend/optimizer/optimizer.h" | |||
| #include "frontend/optimizer/anf_visitor.h" | |||
| #include "frontend/operator/ops.h" | |||
| namespace mindspore { | |||
| namespace opt { | |||
| namespace irpass { | |||
| class GeTensorArrayAddFlow : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimTensorArrayWrite, {IsNode, IsNode, IsNode, IsNode})(node); | |||
| AnfVisitor::Match(prim::kPrimTensorArrayGather, {IsNode, IsNode, IsNode})(node); | |||
| // Check if the pattern matches. | |||
| if (!is_match_ || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto ta_node = node->cast<CNodePtr>(); | |||
| float flow_value = 0.0; | |||
| // generate flow input | |||
| auto flow_node = NewValueNode(MakeValue(flow_value)); | |||
| // set abstract | |||
| auto node_abstract = std::make_shared<abstract::AbstractScalar>(flow_value); | |||
| flow_node->set_abstract(node_abstract); | |||
| // add cnode input | |||
| auto ta_node_inputs = ta_node->inputs(); | |||
| if (HasAbstractMonad(ta_node_inputs.back())) { | |||
| auto input_size = ta_node_inputs.size(); | |||
| std::vector<AnfNodePtr> new_inputs; | |||
| new_inputs.assign(ta_node_inputs.begin(), ta_node_inputs.end()); | |||
| new_inputs.insert(new_inputs.begin() + input_size - 1, flow_node); | |||
| ta_node->set_inputs(new_inputs); | |||
| } else { | |||
| ta_node->add_input(flow_node); | |||
| } | |||
| return ta_node; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { is_match_ = true; } | |||
| void Reset() { is_match_ = false; } | |||
| private: | |||
| bool is_match_{false}; | |||
| }; | |||
| class GeTensorArrayCastIndex : public AnfVisitor { | |||
| public: | |||
| AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { | |||
| Reset(); | |||
| AnfVisitor::Match(prim::kPrimTensorArrayWrite, {IsNode, IsNode, IsNode, IsNode, IsNode})(node); | |||
| // Check if the pattern matches. | |||
| if (!is_match_ || node->func_graph() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| const size_t index_input_index = 2; | |||
| auto index_input_node = node->cast<CNodePtr>()->input(index_input_index); | |||
| // Get cast prim | |||
| auto cast_primitive = std::make_shared<Primitive>(prim::kPrimCast->name()); | |||
| TypePtr src_type = TypeIdToType(TypeId::kNumberTypeInt64); | |||
| TypePtr dst_type = TypeIdToType(TypeId::kNumberTypeInt32); | |||
| auto src_attr_value = MakeValue(src_type); | |||
| auto dst_attr_value = MakeValue(dst_type); | |||
| auto prim = std::make_shared<Primitive>(cast_primitive->AddAttr("dst_type", dst_attr_value)); | |||
| prim = std::make_shared<Primitive>(prim->AddAttr("DstT", dst_attr_value)); | |||
| prim = std::make_shared<Primitive>(prim->AddAttr("SrcT", src_attr_value)); | |||
| // Insert cast | |||
| auto type_node = NewValueNode(dst_type); | |||
| type_node->set_abstract(dst_type->ToAbstract()); | |||
| auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), index_input_node, type_node}); | |||
| auto cast_abstract = index_input_node->abstract(); | |||
| cast_abstract->set_type(dst_type); | |||
| new_node->set_abstract(cast_abstract); | |||
| auto cnode = node->cast<CNodePtr>(); | |||
| cnode->set_input(index_input_index, new_node); | |||
| return node; | |||
| } | |||
| void Visit(const AnfNodePtr &node) override { is_match_ = true; } | |||
| void Reset() { is_match_ = false; } | |||
| private: | |||
| bool is_match_{false}; | |||
| }; | |||
| } // namespace irpass | |||
| } // namespace opt | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_TENSOR_ARRAY_H_ | |||
| @@ -1169,6 +1169,8 @@ bool PipelineSplitAction(const ResourcePtr &res) { return PipelineSplitPass(res) | |||
| bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); } | |||
| bool GeSpecializedAction(const ResourcePtr &res) { return GeSpecializedPass(res); } | |||
| bool SetMindIRGraphAction(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| res->set_is_load(true); | |||
| @@ -1348,6 +1350,7 @@ std::vector<ActionItem> GePipeline() { | |||
| (void)actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub)); | |||
| (void)actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction)); | |||
| (void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction)); | |||
| (void)actions.emplace_back(std::make_pair("ge_specialized_prepare", GeSpecializedAction)); | |||
| (void)actions.emplace_back(std::make_pair("validate", ValidateAction)); | |||
| return actions; | |||
| } | |||
| @@ -48,6 +48,7 @@ | |||
| #include "pipeline/jit/static_analysis/auto_monad.h" | |||
| #include "frontend/optimizer/irpass/branch_culling.h" | |||
| #include "frontend/optimizer/irpass/meta_fg_eliminate.h" | |||
| #include "frontend/optimizer/irpass/ge_specialized_prepare.h" | |||
| #include "frontend/optimizer/irpass/parameter_eliminate.h" | |||
| #include "frontend/optimizer/irpass/updatestate_eliminate.h" | |||
| #if ((defined ENABLE_CPU) && (!defined _WIN32)) | |||
| @@ -294,6 +295,13 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| }); | |||
| } | |||
| opt::OptPassConfig GetGeTensorArrayPass(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| return opt::OptPassConfig({ | |||
| irpass.ge_tensor_array_add_flow_, | |||
| irpass.ge_tensor_array_cast_index_, | |||
| }); | |||
| } | |||
| OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| opt::OptPassConfig a_1 = GetOptPassA1(irpass); | |||
| opt::OptPassConfig a_2 = opt::OptPassConfig( | |||
| @@ -493,6 +501,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &) { | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetGeSpecializedPhases() { | |||
| opt::OptPassConfig ge_ta_size_group = opt::OptPassConfig(opt::irpass::GeTensorArrayPrepare()); | |||
| opt::irpass::OptimizeIRPassLib irpass; | |||
| opt::OptPassConfig ge_tensor_array_passes = GetGeTensorArrayPass(irpass); | |||
| OptPassGroupMap map({ | |||
| {"ge_ta_size_group", ge_ta_size_group}, | |||
| {"ge_ta_passes", ge_tensor_array_passes}, | |||
| }); | |||
| return map; | |||
| } | |||
| OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) { | |||
| auto opt_a = GetOptPassesA(irpass); | |||
| auto a3 = opt_a[opt_a.size() - 1]; | |||
| @@ -672,6 +691,18 @@ bool CconvPass(const ResourcePtr &res) { | |||
| bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); } | |||
| bool GeSpecializedPass(const ResourcePtr &res) { | |||
| // valid null ptr | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| FuncGraphPtr func_graph = res->func_graph(); | |||
| MS_EXCEPTION_IF_NULL(func_graph); | |||
| // get phases | |||
| auto ge_specialized_map = GetGeSpecializedPhases(); | |||
| auto ge_specialized_opt = opt::Optimizer::MakeOptimizer("ge_specialized", res, ge_specialized_map, true); | |||
| (void)ge_specialized_opt->step(func_graph, false); | |||
| return true; | |||
| } | |||
| bool ValidatePass(const ResourcePtr &res) { | |||
| MS_EXCEPTION_IF_NULL(res); | |||
| MS_EXCEPTION_IF_NULL(res->func_graph()); | |||
| @@ -41,6 +41,7 @@ extern std::vector<PassItem> kPynativePasses; | |||
| bool CconvPass(const ResourcePtr &res); | |||
| bool PipelineSplitPass(const ResourcePtr &res); | |||
| bool ValidatePass(const ResourcePtr &res); | |||
| bool GeSpecializedPass(const ResourcePtr &res); | |||
| bool ConvertPrepareAdapt(const ResourcePtr &res); | |||
| bool AddCacheEmbeddingPass(const ResourcePtr &res); | |||
| bool InferenceOptPreparePass(const ResourcePtr &res); | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2019-2021 Huawei Technologies Co., Ltd | |||
| * Copyright 2019-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -343,6 +343,9 @@ constexpr const char kNameResizeNearestNeighborV2[] = "ResizeNearestNeighborV2"; | |||
| constexpr const char kNameConv2DBackpropInputV2[] = "Conv2DBackpropInputV2"; | |||
| constexpr const char kNameConcatV2D[] = "ConcatV2D"; | |||
| constexpr const char kNameFillV1[] = "FillV1"; | |||
| constexpr const char kNameTensorArray[] = "TensorArray"; | |||
| constexpr const char kNameTensorArrayWrite[] = "TensorArrayWrite"; | |||
| constexpr const char kNameTensorArrayGather[] = "TensorArrayGather"; | |||
| class OpAdapterMap { | |||
| public: | |||
| @@ -0,0 +1,42 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "transform/graph_ir/op_declare/data_flow_ops_declare.h" | |||
| #include <vector> | |||
namespace mindspore::transform {
// GE adapter for TensorArray: the `size` value (moved from attr to input by
// GeTensorArrayPrepare) maps to GE input `size`; outputs are (handle, flow).
INPUT_MAP(TensorArray) = {{1, INPUT_DESC(size)}};
ATTR_MAP(TensorArray) = {{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
                         {"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())},
                         {"dynamic_size", ATTR_DESC(dynamic_size, AnyTraits<bool>())},
                         {"clear_after_read", ATTR_DESC(clear_after_read, AnyTraits<bool>())},
                         {"identical_element_shapes", ATTR_DESC(identical_element_shapes, AnyTraits<bool>())},
                         {"tensor_array_name", ATTR_DESC(tensor_array_name, AnyTraits<std::string>())}};
OUTPUT_MAP(TensorArray) = {{0, OUTPUT_DESC(handle)}, {1, OUTPUT_DESC(flow)}};
REG_ADPT_DESC(TensorArray, kNameTensorArray, ADPT_DESC(TensorArray))

// GE adapter for TensorArrayWrite: 4 data inputs including the flow-in value
// added by GeTensorArrayAddFlow; output is the flow-out value only.
INPUT_MAP(TensorArrayWrite) = {
  {1, INPUT_DESC(handle)}, {2, INPUT_DESC(index)}, {3, INPUT_DESC(value)}, {4, INPUT_DESC(flow_in)}};
ATTR_MAP(TensorArrayWrite) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TensorArrayWrite) = {{0, OUTPUT_DESC(flow_out)}};
REG_ADPT_DESC(TensorArrayWrite, kNameTensorArrayWrite, ADPT_DESC(TensorArrayWrite))

// GE adapter for TensorArrayGather: gathers elements at `indices`; the
// flow-in input serializes it after preceding TensorArray operations.
INPUT_MAP(TensorArrayGather) = {{1, INPUT_DESC(handle)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(flow_in)}};
ATTR_MAP(TensorArrayGather) = {{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
                               {"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(TensorArrayGather) = {{0, OUTPUT_DESC(value)}};
REG_ADPT_DESC(TensorArrayGather, kNameTensorArrayGather, ADPT_DESC(TensorArrayGather))
}  // namespace mindspore::transform
| @@ -0,0 +1,35 @@ | |||
| /** | |||
| * Copyright 2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_ | |||
| #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_ | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include "transform/graph_ir/op_declare/op_declare_macro.h" | |||
| #include "ops/data_flow_ops.h" | |||
| namespace mindspore::transform { | |||
| DECLARE_OP_ADAPTER(TensorArray) | |||
| DECLARE_OP_USE_OUTPUT(TensorArray) | |||
| DECLARE_OP_ADAPTER(TensorArrayWrite) | |||
| DECLARE_OP_USE_OUTPUT(TensorArrayWrite) | |||
| DECLARE_OP_ADAPTER(TensorArrayGather) | |||
| DECLARE_OP_USE_OUTPUT(TensorArrayGather) | |||
| } // namespace mindspore::transform | |||
| #endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_ | |||
| @@ -932,6 +932,9 @@ MS_CORE_API inline const PrimitivePtr kPrimStandardNormal = std::make_shared<Pri | |||
| // RL Ops | |||
| MS_CORE_API inline const PrimitivePtr kPrimTensorArrayStack = std::make_shared<Primitive>("TensorArrayStack"); | |||
| MS_CORE_API inline const PrimitivePtr kPrimTensorArray = std::make_shared<Primitive>("TensorArray"); | |||
| MS_CORE_API inline const PrimitivePtr kPrimTensorArrayWrite = std::make_shared<Primitive>("TensorArrayWrite"); | |||
| MS_CORE_API inline const PrimitivePtr kPrimTensorArrayGather = std::make_shared<Primitive>("TensorArrayGather"); | |||
| class DoSignaturePrimitive : public Primitive { | |||
| public: | |||
| @@ -26,6 +26,9 @@ class TensorArray(PrimitiveWithInfer): | |||
| r""" | |||
| TensorArrayCreate used to create a TensorArray and return an unique handle. | |||
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Args: | |||
| dtype (mindspore.dtype): the data type in the TensorArray. | |||
| element_shape (tuple[int]): the shape of each tensor in a TensorArray. | |||
| @@ -72,6 +75,9 @@ class TensorArrayWrite(PrimitiveWithInfer): | |||
| r""" | |||
| TensorArrayWrite used to write tensor into a created TensorArray. | |||
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Inputs: | |||
| - **index** (Tensor[int64]) - The position to write. | |||
| - **value** (Tensor) - The value to add into the TensorArray. | |||
| @@ -109,6 +115,9 @@ class TensorArrayRead(PrimitiveWithInfer): | |||
| r""" | |||
| TensorArrayRead used to read tensor from a created TensorArray by the given index. | |||
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Args: | |||
| dtype (mindspore.dtype): the data type in the TensorArray. | |||
| element_shape (tuple[int]): the shape of each tensor in a TensorArray. | |||
| @@ -157,6 +166,9 @@ class TensorArrayClose(PrimitiveWithInfer): | |||
| r""" | |||
| TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted. | |||
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Inputs: | |||
| - **handle** (mindspore.int64) - The handle pointed to the TensorArray. | |||
| @@ -190,6 +202,9 @@ class TensorArrayClear(PrimitiveWithInfer): | |||
| r""" | |||
TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still available.
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Inputs: | |||
| - **handle** (mindspore.int64) - The handle pointed to the TensorArray. | |||
| @@ -223,6 +238,9 @@ class TensorArrayStack(Primitive): | |||
| r""" | |||
| TensorArrayStack used to stack the tensors in a created TensorArray into one tensor. | |||
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Args: | |||
| dtype (mindspore.dtype): the data type in the TensorArray. | |||
| element_shape (tuple[int]): the shape of each tensor in a TensorArray. | |||
| @@ -264,6 +282,9 @@ class TensorArraySize(PrimitiveWithInfer): | |||
| r""" | |||
| TensorArraySize used to get the logical size of the created TensorArray. | |||
| .. warning:: | |||
| This is an experimental prototype that is subject to change and/or deletion. | |||
| Inputs: | |||
| - **handle** (mindspore.int64) - The handle pointed to the TensorArray. | |||
| @@ -291,3 +312,49 @@ class TensorArraySize(PrimitiveWithInfer): | |||
    def infer_dtype(self, handle_type):
        # The size of a TensorArray is always reported as an int64 scalar;
        # only the handle's type needs validation here.
        validator.check_type_name("handle", handle_type, (ms.int64), self.name)
        return mstype.int64
class TensorArrayGather(PrimitiveWithInfer):
    r"""
    TensorArrayGather used to gather specified elements from the created TensorArray.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        dtype (mindspore.dtype): the data type in the TensorArray.
        element_shape (tuple[int]): the shape of each tensor in a TensorArray.

    Inputs:
        - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
        - **indices** (mindspore.int32) - The locations of the gathered elements.

    Outputs:
        - **output** (Tensor) - The gathered value from the TensorArray.

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> from mindspore import numpy as mnp
        >>> create_op = ops.TensorArray(mindspore.float32, dynamic_size=False, element_shape=(8,))
        >>> handle = create_op()
        >>> indices = mnp.range(0, 25, 1, mindspore.int32)
        >>> gather_op = ops.TensorArrayGather(dtype=mindspore.float32, element_shape=(8,))
        >>> gather_result = gather_op(handle, indices)
    """
    @prim_attr_register
    def __init__(self, dtype, element_shape):
        self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
        # Gathering reads TensorArray state held outside the graph.
        self.add_prim_attr("side_effect_mem", True)
        self.dtype = dtype
        self.element_shape = element_shape

    def infer_shape(self, handle, indices):
        # Output stacks one element per index, so the result shape is
        # [len(indices), *element_shape]; `indices` must be a 1-D tensor.
        if len(indices) != 1:
            # Bug fix: the original code used `return ValueError(...)`, which
            # silently returned the exception object instead of raising it.
            raise ValueError("indices dimension should be equal to 1")
        return [indices[0]] + list(self.element_shape)

    def infer_dtype(self, handle, indices):
        validator.check_type_name("handle", handle, (ms.int64), self.name)
        validator.check_type_name("indices", indices, (ms.int32), self.name)
        return self.dtype