@@ -1233,6 +1233,27 @@ FuncGraphPtr TensorSlice::ExpandADim(const FuncGraphPtr &ret_graph, const AnfNod
   return ret_graph;
 }
 
+FuncGraphPtr TupleGetItemTensor::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  // select indexed item
+  // args: tuple of items, index
+  const std::string op_name = std::string("TupleGetItemTensor");
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  AbstractTuplePtr branches_abs = abstract::CheckArg<AbstractTuple>(op_name, args_spec_list, 0);
+  AbstractBasePtrList branches = branches_abs->elements();
+  if (branches.size() > 0 && branches[0] != nullptr && branches[0]->isa<AbstractFunction>()) {
+    FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
+    ret_graph->set_flags(FUNC_GRAPH_FLAG_CORE, true);
+    AnfNodePtr functions = ret_graph->add_parameter();
+    auto index = ret_graph->add_parameter();
+    ret_graph->set_output(ret_graph->NewCNode({NewValueNode(prim::kPrimSwitchLayer), index, functions}));
+    return ret_graph;
+  }
+
+  MS_LOG(EXCEPTION) << "TupleGetItemTensor does not support indexing " << branches_abs->ToString() << ".";
+}
+
 REGISTER_PYBIND_DEFINE(TupleAdd_, ([](const py::module *m) {
                          (void)py::class_<TupleAdd, MetaFuncGraph, std::shared_ptr<TupleAdd>>(*m, "TupleAdd_")
                            .def(py::init<std::string &>());
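For reference, the graph generated above takes the tuple of branch functions and the index as parameters and returns switch_layer(index, functions). A minimal Python-level sketch of the equivalent behavior (the branch cells and input shape are assumed for illustration):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import functional as F

class PickBranch(nn.Cell):
    """Hypothetical cell: switch_layer selects one callable out of a tuple of branches."""
    def __init__(self):
        super(PickBranch, self).__init__()
        self.branches = (nn.ReLU(), nn.Sigmoid())  # tuple of function-like branches

    def construct(self, index, x):
        # Mirrors the CNode {kPrimSwitchLayer, index, functions} built in GenerateFuncGraph.
        return F.switch_layer(index, self.branches)(x)

net = PickBranch()
out = net(Tensor(0), Tensor(np.ones([2, 2]).astype(np.float32)))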
@@ -1247,5 +1268,11 @@ REGISTER_PYBIND_DEFINE(TensorSlice_, ([](const py::module *m) {
                          (void)py::class_<TensorSlice, MetaFuncGraph, std::shared_ptr<TensorSlice>>(*m, "TensorSlice_")
                            .def(py::init<std::string &>());
                        }));
+
+REGISTER_PYBIND_DEFINE(TupleGetItemTensor_, ([](const py::module *m) {
+                         (void)py::class_<TupleGetItemTensor, MetaFuncGraph, std::shared_ptr<TupleGetItemTensor>>(
+                           *m, "TupleGetItemTensor_")
+                           .def(py::init<std::string &>());
+                       }));
 }  // namespace prim
 }  // namespace mindspore
@@ -210,6 +210,18 @@ class TensorSlice : public MetaFuncGraph {
   FuncGraphPtr ExpandADim(const FuncGraphPtr &ret_graph, const AnfNodePtr &tensor_node) const;
 };
 using TensorSlicePtr = std::shared_ptr<TensorSlice>;
 
+class TupleGetItemTensor : public MetaFuncGraph {
+ public:
+  explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {}
+  ~TupleGetItemTensor() override = default;
+  MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+
+  friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) {
+    return lhs.name_ == rhs.name_;
+  }
+};
+using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>;
 }  // namespace prim
 }  // namespace mindspore
@@ -129,22 +129,27 @@ AbstractBasePtr InferImplSwitch(const AnalysisEnginePtr &, const PrimitivePtr &,
 AbstractBasePtr InferImplSwitchLayer(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                      const AbstractBasePtrList &args_spec_list) {
   // Inputs: index, branch
-  if (args_spec_list.size() != 2) {
-    MS_LOG(EXCEPTION) << "SwitchLayer evaluator requires 2 parameters, while the input size is "
-                      << args_spec_list.size() << ".";
-  }
-  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(primitive->name(), args_spec_list, 1);
+  const std::string op_name = primitive->name();
+  abstract::CheckArgsSize(op_name, args_spec_list, 2);
+  (void)CheckArg<AbstractTensor>(op_name, args_spec_list, 0);
+  AbstractTuplePtr branches_abs = CheckArg<AbstractTuple>(op_name, args_spec_list, 1);
   AbstractBasePtrList branches = branches_abs->elements();
   const size_t maximum_layer_num = 1000;
   if (branches.size() < 1 || branches.size() > maximum_layer_num) {
| MS_EXCEPTION(ValueError) << "SwitchLayer support at least 1 and at most " << maximum_layer_num << " but got " | |||||
| MS_EXCEPTION(ValueError) << op_name << " support at least 1 and at most " << maximum_layer_num << " but got " | |||||
| << branches.size() << " branches."; | << branches.size() << " branches."; | ||||
| } | } | ||||
-  MS_EXCEPTION_IF_NULL(branches[0]);
+  for (size_t i = 0; i < branches.size(); i++) {
+    MS_EXCEPTION_IF_NULL(branches[i]);
+    if (!branches[i]->isa<AbstractFunction>()) {
+      MS_LOG(EXCEPTION) << op_name << " requires that the 2nd argument be a tuple of functions, but got "
+                        << branches[i]->ToString() << " as the " << i << "th element.";
+    }
+  }
+
   auto b = branches[0];
   for (size_t i = 1; i < branches.size(); i++) {
-    MS_EXCEPTION_IF_NULL(branches[i]);
     b = b->Join(branches[i]);
   }
   return b;
@@ -18,13 +18,13 @@
 """Basic composite operations."""
 
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
-    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_
+    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
 from ...common import dtype as mstype
 from ...common.api import ms_function
 from .. import functional as F
 from .. import operations as P
 
-__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_]
+__all__ = [EnvInstance_, TensorSlice_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]
 
 
 def add_flags(fn, **flags):
@@ -72,6 +72,28 @@ _tensor_slice = _TensorSlice('tensor_slice')
 """_tensor_slice is a metafuncgraph object which will slice a tensor."""
 
 
+class _TupleGetItemTensor(base.TupleGetItemTensor_):
+    """
+    Getting item of tuple by tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items.
+        index (Tensor): The index, given as a tensor.
+
+    Outputs:
+        Type, the same as the element type of `data`.
+    """
+
+    def __init__(self, name):
+        base.TupleGetItemTensor_.__init__(self, name)
+
+    def __call__(self, *args):
+        pass
+
+
+_tuple_get_item_tensor = _TupleGetItemTensor('tuple_get_item_tensor')
+"""_tuple_get_item_tensor is a metafuncgraph object which will select the indexed item."""
+
+
 @getitem.register("Tuple", "Number")
 def _tuple_getitem_by_number(data, number_index):
     """
@@ -102,6 +124,21 @@ def _tuple_getitem_by_slice(data, slice_index):
     return _tuple_slice(data, slice_index)
 
 
+@getitem.register("Tuple", "Tensor")
+def _tuple_getitem_by_tensor(data, tensor_index):
+    """
+    Getting item out of tuple by tensor index.
+
+    Inputs:
+        data (tuple): A tuple of items to index.
+        tensor_index (Tensor): Index to select an item.
+
+    Outputs:
+        Type, the same as the element type of `data`.
+    """
+    return _tuple_get_item_tensor(data, tensor_index)
+
+
 @getitem.register("List", "Number")
 def _list_getitem_by_number(data, number_index):
     """
@@ -387,7 +387,38 @@ def test_switch_layer():
             ret = F.switch_layer(index, self.layers)(x) * self.z3
             return ret
 
+    index = Tensor(0)
     net = SwitchLayerCell()
-    net(1, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
-    C.grad_all(net)(0, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+
+
+def test_index_to_switch_layer():
+    class Layer1(nn.Cell):
+        def __init__(self):
+            super(Layer1, self).__init__()
+            self.z1 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z1')
+
+        def construct(self, x):
+            return x * self.z1
+
+    class Layer2(nn.Cell):
+        def __init__(self):
+            super(Layer2, self).__init__()
+            self.z2 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z2')
+
+        def construct(self, x):
+            return x * self.z2
+
+    class SwitchLayerCell(nn.Cell):
+        def __init__(self):
+            super(SwitchLayerCell, self).__init__()
+            self.layers = (Layer1(), Layer2())
+            self.z3 = Parameter(Tensor(np.full([128, 96], 0.6, dtype=np.float32)), name='z3')
+
+        def construct(self, index, x):
+            ret = self.layers[index](x) * self.z3
+            return ret
+
+    index = Tensor(0)
+    net = SwitchLayerCell()
+    net(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_by_list(net, ParameterTuple(net.trainable_params()))(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))
+    C.grad_all(net)(index, Tensor(np.full([128, 96], 0.6, dtype=np.float32)))