Browse Source

"GE TensorArray ops support"

feature/build-system-rewrite
陈劢 4 years ago
parent
commit
6fc6476066
13 changed files with 505 additions and 1 deletions
  1. +8
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.cc
  2. +4
    -0
      mindspore/ccsrc/frontend/optimizer/irpass.h
  3. +119
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/ge_specialized_prepare.cc
  4. +67
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/ge_specialized_prepare.h
  5. +121
    -0
      mindspore/ccsrc/frontend/optimizer/irpass/ge_tensor_array.h
  6. +3
    -0
      mindspore/ccsrc/pipeline/jit/action.cc
  7. +31
    -0
      mindspore/ccsrc/pipeline/jit/pass.cc
  8. +1
    -0
      mindspore/ccsrc/pipeline/jit/pass.h
  9. +4
    -1
      mindspore/ccsrc/transform/graph_ir/op_adapter_map.h
  10. +42
    -0
      mindspore/ccsrc/transform/graph_ir/op_declare/data_flow_ops_declare.cc
  11. +35
    -0
      mindspore/ccsrc/transform/graph_ir/op_declare/data_flow_ops_declare.h
  12. +3
    -0
      mindspore/core/base/core_ops.h
  13. +67
    -0
      mindspore/python/mindspore/ops/operations/_tensor_array.py

+ 8
- 0
mindspore/ccsrc/frontend/optimizer/irpass.cc View File

@@ -50,6 +50,7 @@
#include "frontend/optimizer/irpass/call_graph_tuple_transform.h"
#include "frontend/optimizer/irpass/recompute_prepare.h"
#include "frontend/optimizer/irpass/real_op_eliminate.h"
#include "frontend/optimizer/irpass/ge_tensor_array.h"

namespace mindspore {
namespace opt {
@@ -273,6 +274,13 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
// Workaround
stop_gradient_special_op_ =
MakeSubstitution(std::make_shared<StopGradientSpecialOp>(), "stop_gradient_special_op", prim::kPrimBiasAddGrad);

// ge_tensor_array_link_flow
ge_tensor_array_add_flow_ = MakeSubstitution(std::make_shared<GeTensorArrayAddFlow>(), "ge_tensor_array_add_flow",
{prim::kPrimTensorArrayWrite, prim::kPrimTensorArrayGather});
// ge_tensor_array_cast_index
ge_tensor_array_cast_index_ = MakeSubstitution(std::make_shared<GeTensorArrayCastIndex>(),
"ge_tensor_array_cast_index", prim::kPrimTensorArrayWrite);
}

ResolveIRPassLib::ResolveIRPassLib() {


+ 4
- 0
mindspore/ccsrc/frontend/optimizer/irpass.h View File

@@ -167,6 +167,10 @@ class OptimizeIRPassLib {

// Workaround
SubstitutionPtr stop_gradient_special_op_;

// ge TensorArray process
SubstitutionPtr ge_tensor_array_add_flow_;
SubstitutionPtr ge_tensor_array_cast_index_;
};

// the collection of irpass for resolve action


+ 119
- 0
mindspore/ccsrc/frontend/optimizer/irpass/ge_specialized_prepare.cc View File

@@ -0,0 +1,119 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "frontend/optimizer/irpass/ge_specialized_prepare.h"

#include <memory>
#include <utility>
#include <unordered_map>

#include "ir/func_graph.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
void GeTensorArrayPrepare::InsertFlowOutputToTA(const std::vector<AnfNodePtr> &all_nodes) {
FuncGraphPtr root = nullptr;
if (all_nodes.size() == 0) {
return;
} else {
root = all_nodes[0]->func_graph();
}

for (auto &ta_input_node : all_nodes) {
if (!ta_input_node->isa<CNode>()) {
continue;
}
auto ta_input_cnode = ta_input_node->cast<CNodePtr>();
for (size_t input_index = 0; input_index < ta_input_cnode->inputs().size(); input_index++) {
auto ta_node = ta_input_cnode->input(input_index);
if (IsPrimitiveCNode(ta_node, prim::kPrimTensorArray)) {
auto ta_find = converted_ta_node_.find(ta_node);
// cached TensorArray node
if (ta_find != converted_ta_node_.end()) {
auto new_ta_input_node_input = ta_find->second;
ta_input_cnode->set_input(input_index, new_ta_input_node_input);
} else {
// new a TupleGetItem node and set it's input with TensorArray node and ValueNode(0)
// set TAInput node input with TupleGetItem node
int64_t index = 0;

auto index_value_node = NewValueNode(index);
auto index_node_abstract = std::make_shared<abstract::AbstractScalar>(index);
index_value_node->set_abstract(index_node_abstract);

auto new_tuple_get_cnode = root->NewCNode({NewValueNode(prim::kPrimTupleGetItem), ta_node, index_value_node});
auto new_tuple_get_node = new_tuple_get_cnode->cast<AnfNodePtr>();

auto tuple_get_node_abstract = ta_node_abstract_cache_[ta_node];
new_tuple_get_node->set_abstract(tuple_get_node_abstract);

converted_ta_node_[ta_node] = new_tuple_get_node;
ta_input_cnode->set_input(input_index, new_tuple_get_node);
}
}
}
}
}

void GeTensorArrayPrepare::TransformTASizeFromAttrToInput(const AnfNodePtr &node) {
auto ta_node = node->cast<CNodePtr>();
int32_t res_size = 0;
PrimitivePtr prim = GetValueNode<PrimitivePtr>(ta_node->input(0));
// get size attr
if (prim->HasAttr("size")) {
auto size_value_ptr = prim->GetAttr("size");
auto size = GetValue<int64_t>(size_value_ptr);
res_size = static_cast<int32_t>(size);
}
// generate size input
auto size_node = NewValueNode(MakeValue(res_size));
auto node_abstract = std::make_shared<abstract::AbstractScalar>(res_size);
size_node->set_abstract(node_abstract);
auto origin_inputs = ta_node->inputs();
// set cnode input
ta_node->add_input(size_node);
// has monad input
if (origin_inputs.size() > 1) {
std::vector<AnfNodePtr> sorted_inputs(origin_inputs);
sorted_inputs.insert(sorted_inputs.begin() + 1, size_node);
ta_node->set_inputs(sorted_inputs);
}

// get origin abstract
auto origin_ta_abstract = ta_node->abstract();
// new tuple abstract
std::vector<AbstractBasePtr> abstract_list;
// push origin abstract
abstract_list.push_back(origin_ta_abstract);
// new flow abstract
float flow_value = 0.0;
auto flow_abstract = std::make_shared<abstract::AbstractScalar>(flow_value);
// push flow abstract
abstract_list.push_back(flow_abstract);
// cache TensorArray node's abstract
auto abstract_find = ta_node_abstract_cache_.find(ta_node);
if (abstract_find == ta_node_abstract_cache_.end()) {
ta_node_abstract_cache_[ta_node] = ta_node->abstract();
}
// modify TensorArray node output's abstract from Tensor to Tuple
auto new_ta_abstract = std::make_shared<abstract::AbstractTuple>(abstract_list);
ta_node->set_abstract(new_ta_abstract);
}
} // namespace irpass
} // namespace opt
} // namespace mindspore

+ 67
- 0
mindspore/ccsrc/frontend/optimizer/irpass/ge_specialized_prepare.h View File

@@ -0,0 +1,67 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_SPECIALIZED_PREPARE_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_SPECIALIZED_PREPARE_H_

#include <vector>
#include <algorithm>
#include <unordered_map>

#include "ir/func_graph.h"
#include "ir/func_graph_cloner.h"
#include "frontend/optimizer/optimizer_caller.h"
#include "ir/pattern_matcher.h"
#include "frontend/operator/ops.h"
#include "frontend/optimizer/irpass.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Graph-level pass that adapts TensorArray nodes for the GE backend: the
// "size" attribute becomes an input and consumers are rewired through a
// TupleGetItem once the node's output turns into a (handle, flow) tuple.
class GeTensorArrayPrepare {
 public:
  GeTensorArrayPrepare() = default;
  virtual ~GeTensorArrayPrepare() = default;

  // Runs the pass over the whole graph; returns true when any TensorArray
  // node was rewritten.
  bool operator()(const FuncGraphPtr &root, const OptimizerPtr &optimizer) {
    AnfNodePtr ret = root->get_return();
    MS_EXCEPTION_IF_NULL(ret);
    std::vector<AnfNodePtr> all_nodes = DeepScopedGraphSearch(ret);

    bool transformed = false;
    for (const auto &candidate : all_nodes) {
      if (!IsPrimitiveCNode(candidate, prim::kPrimTensorArray)) {
        continue;
      }
      TransformTASizeFromAttrToInput(candidate);
      transformed = true;
    }
    if (transformed) {
      InsertFlowOutputToTA(all_nodes);
    }
    return transformed;
  }

 private:
  // Add a const input with value `size` to a TensorArray node.
  void TransformTASizeFromAttrToInput(const AnfNodePtr &node);
  // Rewire consumers of rewritten TensorArray nodes through TupleGetItem.
  void InsertFlowOutputToTA(const std::vector<AnfNodePtr> &all_nodes);
  // TensorArray node -> the TupleGetItem that replaces it at use sites.
  std::unordered_map<AnfNodePtr, AnfNodePtr> converted_ta_node_;
  // TensorArray node -> its abstract before the tuple rewrite.
  std::unordered_map<AnfNodePtr, AbstractBasePtr> ta_node_abstract_cache_;
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // #ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_SPECIALIZED_PREPARE_H_

+ 121
- 0
mindspore/ccsrc/frontend/optimizer/irpass/ge_tensor_array.h View File

@@ -0,0 +1,121 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_TENSOR_ARRAY_H_
#define MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_TENSOR_ARRAY_H_

#include <vector>
#include <memory>
#include <algorithm>

#include "frontend/optimizer/irpass.h"
#include "frontend/optimizer/optimizer.h"
#include "frontend/optimizer/anf_visitor.h"
#include "frontend/operator/ops.h"

namespace mindspore {
namespace opt {
namespace irpass {
// Appends a scalar float "flow" input to TensorArrayWrite / TensorArrayGather
// nodes; when the node carries a trailing monad input, the flow input is
// placed just before it so the monad stays last.
class GeTensorArrayAddFlow : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimTensorArrayWrite, {IsNode, IsNode, IsNode, IsNode})(node);
    AnfVisitor::Match(prim::kPrimTensorArrayGather, {IsNode, IsNode, IsNode})(node);

    // Bail out unless one of the patterns matched and the node is attached to a graph.
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }

    auto write_or_gather = node->cast<CNodePtr>();
    float flow_value = 0.0;
    // Build the constant flow input and give it a scalar abstract.
    auto flow_input = NewValueNode(MakeValue(flow_value));
    flow_input->set_abstract(std::make_shared<abstract::AbstractScalar>(flow_value));

    const auto &current_inputs = write_or_gather->inputs();
    if (HasAbstractMonad(current_inputs.back())) {
      // Keep the monad as the last input: splice the flow input in just before it.
      std::vector<AnfNodePtr> updated_inputs(current_inputs.begin(), current_inputs.end());
      updated_inputs.insert(updated_inputs.end() - 1, flow_input);
      write_or_gather->set_inputs(updated_inputs);
    } else {
      write_or_gather->add_input(flow_input);
    }
    return write_or_gather;
  }

  void Visit(const AnfNodePtr &node) override { is_match_ = true; }

  void Reset() { is_match_ = false; }

 private:
  bool is_match_{false};
};

// Inserts a Cast on the index input of TensorArrayWrite, converting it from
// int64 to int32 (the index type the GE op expects).
class GeTensorArrayCastIndex : public AnfVisitor {
 public:
  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
    Reset();
    AnfVisitor::Match(prim::kPrimTensorArrayWrite, {IsNode, IsNode, IsNode, IsNode, IsNode})(node);

    // Check if the pattern matches.
    if (!is_match_ || node->func_graph() == nullptr) {
      return nullptr;
    }

    const size_t index_input_index = 2;
    auto cnode = node->cast<CNodePtr>();
    auto index_input_node = cnode->input(index_input_index);

    TypePtr src_type = TypeIdToType(TypeId::kNumberTypeInt64);
    TypePtr dst_type = TypeIdToType(TypeId::kNumberTypeInt32);
    // Build the Cast primitive once and attach all attrs to it, instead of
    // copy-constructing a fresh Primitive after every AddAttr.
    auto prim = std::make_shared<Primitive>(prim::kPrimCast->name());
    (void)prim->AddAttr("dst_type", MakeValue(dst_type));
    (void)prim->AddAttr("DstT", MakeValue(dst_type));
    (void)prim->AddAttr("SrcT", MakeValue(src_type));

    // Insert cast
    auto type_node = NewValueNode(dst_type);
    type_node->set_abstract(dst_type->ToAbstract());

    auto new_node = node->func_graph()->NewCNode({NewValueNode(prim), index_input_node, type_node});
    // Bug fix: clone before retyping. Calling set_type on the shared abstract
    // would also retype the original index node's abstract in place.
    MS_EXCEPTION_IF_NULL(index_input_node->abstract());
    auto cast_abstract = index_input_node->abstract()->Clone();
    cast_abstract->set_type(dst_type);
    new_node->set_abstract(cast_abstract);

    cnode->set_input(index_input_index, new_node);
    return node;
  }

  void Visit(const AnfNodePtr &node) override { is_match_ = true; }

  void Reset() { is_match_ = false; }

 private:
  bool is_match_{false};
};
} // namespace irpass
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_FRONTEND_OPTIMIZER_IRPASS_GE_TENSOR_ARRAY_H_

+ 3
- 0
mindspore/ccsrc/pipeline/jit/action.cc View File

@@ -1169,6 +1169,8 @@ bool PipelineSplitAction(const ResourcePtr &res) { return PipelineSplitPass(res)

bool ValidateAction(const ResourcePtr &res) { return ValidatePass(res); }

// Pipeline action wrapper: runs the GE-specialized preparation pass on the resource.
bool GeSpecializedAction(const ResourcePtr &res) { return GeSpecializedPass(res); }

bool SetMindIRGraphAction(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
res->set_is_load(true);
@@ -1348,6 +1350,7 @@ std::vector<ActionItem> GePipeline() {
(void)actions.emplace_back(std::make_pair("py_opt", OptActionGePyStub));
(void)actions.emplace_back(std::make_pair("remove_value_node_duplications", RemoveValueNodeDuplicationsAction));
(void)actions.emplace_back(std::make_pair("auto_monad_reorder", OrderEnforceAction));
(void)actions.emplace_back(std::make_pair("ge_specialized_prepare", GeSpecializedAction));
(void)actions.emplace_back(std::make_pair("validate", ValidateAction));
return actions;
}


+ 31
- 0
mindspore/ccsrc/pipeline/jit/pass.cc View File

@@ -48,6 +48,7 @@
#include "pipeline/jit/static_analysis/auto_monad.h"
#include "frontend/optimizer/irpass/branch_culling.h"
#include "frontend/optimizer/irpass/meta_fg_eliminate.h"
#include "frontend/optimizer/irpass/ge_specialized_prepare.h"
#include "frontend/optimizer/irpass/parameter_eliminate.h"
#include "frontend/optimizer/irpass/updatestate_eliminate.h"
#if ((defined ENABLE_CPU) && (!defined _WIN32))
@@ -294,6 +295,13 @@ opt::OptPassConfig GetOptPassA1(const opt::irpass::OptimizeIRPassLib &irpass) {
});
}

// Groups the substitution-based GE TensorArray rewrites: first attach the flow
// input, then cast the write index from int64 to int32.
opt::OptPassConfig GetGeTensorArrayPass(const opt::irpass::OptimizeIRPassLib &irpass) {
  return opt::OptPassConfig({irpass.ge_tensor_array_add_flow_, irpass.ge_tensor_array_cast_index_});
}

OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
opt::OptPassConfig a_1 = GetOptPassA1(irpass);
opt::OptPassConfig a_2 = opt::OptPassConfig(
@@ -493,6 +501,17 @@ OptPassGroupMap GetControlPhases(const opt::irpass::OptimizeIRPassLib &) {
return map;
}

// Assembles the GE-specialized phase group: the TensorArray size/flow
// preparation first, then the substitution-based TensorArray rewrites.
OptPassGroupMap GetGeSpecializedPhases() {
  opt::OptPassConfig prepare_group = opt::OptPassConfig(opt::irpass::GeTensorArrayPrepare());
  opt::irpass::OptimizeIRPassLib irpass;
  opt::OptPassConfig rewrite_group = GetGeTensorArrayPass(irpass);
  return OptPassGroupMap({{"ge_ta_size_group", prepare_group}, {"ge_ta_passes", rewrite_group}});
}

OptPassGroupMap GetOptPynativeGradEpiloguePhases(const opt::irpass::OptimizeIRPassLib &irpass) {
auto opt_a = GetOptPassesA(irpass);
auto a3 = opt_a[opt_a.size() - 1];
@@ -672,6 +691,18 @@ bool CconvPass(const ResourcePtr &res) {

bool PipelineSplitPass(const ResourcePtr &res) { return PipelineSplit(res); }

bool GeSpecializedPass(const ResourcePtr &res) {
// valid null ptr
MS_EXCEPTION_IF_NULL(res);
FuncGraphPtr func_graph = res->func_graph();
MS_EXCEPTION_IF_NULL(func_graph);
// get phases
auto ge_specialized_map = GetGeSpecializedPhases();
auto ge_specialized_opt = opt::Optimizer::MakeOptimizer("ge_specialized", res, ge_specialized_map, true);
(void)ge_specialized_opt->step(func_graph, false);
return true;
}

bool ValidatePass(const ResourcePtr &res) {
MS_EXCEPTION_IF_NULL(res);
MS_EXCEPTION_IF_NULL(res->func_graph());


+ 1
- 0
mindspore/ccsrc/pipeline/jit/pass.h View File

@@ -41,6 +41,7 @@ extern std::vector<PassItem> kPynativePasses;
bool CconvPass(const ResourcePtr &res);
bool PipelineSplitPass(const ResourcePtr &res);
bool ValidatePass(const ResourcePtr &res);
bool GeSpecializedPass(const ResourcePtr &res);
bool ConvertPrepareAdapt(const ResourcePtr &res);
bool AddCacheEmbeddingPass(const ResourcePtr &res);
bool InferenceOptPreparePass(const ResourcePtr &res);


+ 4
- 1
mindspore/ccsrc/transform/graph_ir/op_adapter_map.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2019-2021 Huawei Technologies Co., Ltd
* Copyright 2019-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -343,6 +343,9 @@ constexpr const char kNameResizeNearestNeighborV2[] = "ResizeNearestNeighborV2";
constexpr const char kNameConv2DBackpropInputV2[] = "Conv2DBackpropInputV2";
constexpr const char kNameConcatV2D[] = "ConcatV2D";
constexpr const char kNameFillV1[] = "FillV1";
constexpr const char kNameTensorArray[] = "TensorArray";
constexpr const char kNameTensorArrayWrite[] = "TensorArrayWrite";
constexpr const char kNameTensorArrayGather[] = "TensorArrayGather";

class OpAdapterMap {
public:


+ 42
- 0
mindspore/ccsrc/transform/graph_ir/op_declare/data_flow_ops_declare.cc View File

@@ -0,0 +1,42 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "transform/graph_ir/op_declare/data_flow_ops_declare.h"
#include <vector>

namespace mindspore::transform {
// GE adapter for TensorArray: input 1 carries the array size; the remaining
// configuration (dtype, element_shape, ...) is mapped from attributes.
// The op exposes two outputs: the array handle and a flow value.
INPUT_MAP(TensorArray) = {{1, INPUT_DESC(size)}};
ATTR_MAP(TensorArray) = {{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
                         {"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())},
                         {"dynamic_size", ATTR_DESC(dynamic_size, AnyTraits<bool>())},
                         {"clear_after_read", ATTR_DESC(clear_after_read, AnyTraits<bool>())},
                         {"identical_element_shapes", ATTR_DESC(identical_element_shapes, AnyTraits<bool>())},
                         {"tensor_array_name", ATTR_DESC(tensor_array_name, AnyTraits<std::string>())}};
OUTPUT_MAP(TensorArray) = {{0, OUTPUT_DESC(handle)}, {1, OUTPUT_DESC(flow)}};
REG_ADPT_DESC(TensorArray, kNameTensorArray, ADPT_DESC(TensorArray))

// GE adapter for TensorArrayWrite: maps (handle, index, value, flow_in) to the
// op's inputs and exposes flow_out as its single output. No attributes.
INPUT_MAP(TensorArrayWrite) = {
  {1, INPUT_DESC(handle)}, {2, INPUT_DESC(index)}, {3, INPUT_DESC(value)}, {4, INPUT_DESC(flow_in)}};
ATTR_MAP(TensorArrayWrite) = EMPTY_ATTR_MAP;
OUTPUT_MAP(TensorArrayWrite) = {{0, OUTPUT_DESC(flow_out)}};
REG_ADPT_DESC(TensorArrayWrite, kNameTensorArrayWrite, ADPT_DESC(TensorArrayWrite))

// GE adapter for TensorArrayGather: maps (handle, indices, flow_in) to the
// op's inputs; dtype/element_shape come from attributes; output is the value.
INPUT_MAP(TensorArrayGather) = {{1, INPUT_DESC(handle)}, {2, INPUT_DESC(indices)}, {3, INPUT_DESC(flow_in)}};
ATTR_MAP(TensorArrayGather) = {{"dtype", ATTR_DESC(dtype, AnyTraits<GEType>())},
                               {"element_shape", ATTR_DESC(element_shape, AnyTraits<std::vector<int64_t>>())}};
OUTPUT_MAP(TensorArrayGather) = {{0, OUTPUT_DESC(value)}};
REG_ADPT_DESC(TensorArrayGather, kNameTensorArrayGather, ADPT_DESC(TensorArrayGather))
}  // namespace mindspore::transform

+ 35
- 0
mindspore/ccsrc/transform/graph_ir/op_declare/data_flow_ops_declare.h View File

@@ -0,0 +1,35 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_
#define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_

#include <string>
#include <unordered_map>
#include "transform/graph_ir/op_declare/op_declare_macro.h"
#include "ops/data_flow_ops.h"

namespace mindspore::transform {
// Adapter declarations for the GE TensorArray data-flow ops. The corresponding
// input/attr/output maps are defined in data_flow_ops_declare.cc.
DECLARE_OP_ADAPTER(TensorArray)
DECLARE_OP_USE_OUTPUT(TensorArray)

DECLARE_OP_ADAPTER(TensorArrayWrite)
DECLARE_OP_USE_OUTPUT(TensorArrayWrite)

DECLARE_OP_ADAPTER(TensorArrayGather)
DECLARE_OP_USE_OUTPUT(TensorArrayGather)
}  // namespace mindspore::transform
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_DATA_FLOW_OPS_DECLARE_H_

+ 3
- 0
mindspore/core/base/core_ops.h View File

@@ -932,6 +932,9 @@ MS_CORE_API inline const PrimitivePtr kPrimStandardNormal = std::make_shared<Pri

// RL Ops
MS_CORE_API inline const PrimitivePtr kPrimTensorArrayStack = std::make_shared<Primitive>("TensorArrayStack");
MS_CORE_API inline const PrimitivePtr kPrimTensorArray = std::make_shared<Primitive>("TensorArray");
MS_CORE_API inline const PrimitivePtr kPrimTensorArrayWrite = std::make_shared<Primitive>("TensorArrayWrite");
MS_CORE_API inline const PrimitivePtr kPrimTensorArrayGather = std::make_shared<Primitive>("TensorArrayGather");

class DoSignaturePrimitive : public Primitive {
public:


+ 67
- 0
mindspore/python/mindspore/ops/operations/_tensor_array.py View File

@@ -26,6 +26,9 @@ class TensorArray(PrimitiveWithInfer):
r"""
TensorArrayCreate used to create a TensorArray and return a unique handle.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
@@ -72,6 +75,9 @@ class TensorArrayWrite(PrimitiveWithInfer):
r"""
TensorArrayWrite used to write tensor into a created TensorArray.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Inputs:
- **index** (Tensor[int64]) - The position to write.
- **value** (Tensor) - The value to add into the TensorArray.
@@ -109,6 +115,9 @@ class TensorArrayRead(PrimitiveWithInfer):
r"""
TensorArrayRead used to read tensor from a created TensorArray by the given index.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
@@ -157,6 +166,9 @@ class TensorArrayClose(PrimitiveWithInfer):
r"""
TensorArrayClose used to close the created TensorArray. The resources in TensorArray will be deleted.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.

@@ -190,6 +202,9 @@ class TensorArrayClear(PrimitiveWithInfer):
r"""
TensorArrayClear used to reset the created TensorArray. The instance of TensorArray is still available.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.

@@ -223,6 +238,9 @@ class TensorArrayStack(Primitive):
r"""
TensorArrayStack used to stack the tensors in a created TensorArray into one tensor.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Args:
dtype (mindspore.dtype): the data type in the TensorArray.
element_shape (tuple[int]): the shape of each tensor in a TensorArray.
@@ -264,6 +282,9 @@ class TensorArraySize(PrimitiveWithInfer):
r"""
TensorArraySize used to get the logical size of the created TensorArray.

.. warning::
This is an experimental prototype that is subject to change and/or deletion.

Inputs:
- **handle** (mindspore.int64) - The handle pointed to the TensorArray.

@@ -291,3 +312,49 @@ class TensorArraySize(PrimitiveWithInfer):
def infer_dtype(self, handle_type):
validator.check_type_name("handle", handle_type, (ms.int64), self.name)
return mstype.int64


class TensorArrayGather(PrimitiveWithInfer):
    r"""
    TensorArrayGather used to gather specified elements from the created TensorArray.

    .. warning::
        This is an experimental prototype that is subject to change and/or deletion.

    Args:
        dtype (mindspore.dtype): the data type in the TensorArray.
        element_shape (tuple[int]): the shape of each tensor in a TensorArray.

    Inputs:
        - **handle** (mindspore.int64) - The handle pointed to the TensorArray.
        - **indices** (mindspore.int32) - The locations of the gathered elements.

    Outputs:
        - **output** (Tensor) - The gathered value from the TensorArray.

    Raises:
        ValueError: If `indices` is not a 1-D tensor.

    Examples:
        >>> import mindspore
        >>> import mindspore.ops as ops
        >>> from mindspore import numpy as mnp
        >>> create_op = ops.TensorArray(mindspore.float32, dynamic_size=False, element_shape=(8,))
        >>> handle = create_op()
        >>> indices = mnp.range(0, 25, 1, mindspore.int32)
        >>> gather_op = ops.TensorArrayGather(dtype=mindspore.float32, element_shape=(8,))
        >>> gather_result = gather_op(handle, indices)
    """
    @prim_attr_register
    def __init__(self, dtype, element_shape):
        """Initialize TensorArrayGather: register io names and cache dtype/element_shape."""
        self.init_prim_io_names(inputs=['handle', 'indices'], outputs=['value'])
        self.add_prim_attr("side_effect_mem", True)
        self.dtype = dtype
        self.element_shape = element_shape

    def infer_shape(self, handle, indices):
        """Output shape is [len(indices)] + element_shape; `indices` must be 1-D."""
        if len(indices) != 1:
            # Bug fix: the original code *returned* the ValueError instead of
            # raising it, so invalid index shapes passed through silently.
            raise ValueError("indices dimension should be equal to 1")
        return [indices[0]] + list(self.element_shape)

    def infer_dtype(self, handle, indices):
        """Validate input dtypes and return the configured element dtype."""
        validator.check_type_name("handle", handle, (ms.int64), self.name)
        validator.check_type_name("indices", indices, (ms.int32), self.name)
        return self.dtype

Loading…
Cancel
Save