/**
 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
 *
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "operator/composite/zip_operation.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "pipeline/static_analysis/abstract_value.h"
#include "ir/anf.h"
#include "pipeline/static_analysis/dshape.h"
#include "pipeline/static_analysis/param_validator.h"
#include "operator/cc_implementations.h"
#include "optimizer/opt.h"
#include "utils/symbolic.h"
#include "./common.h"
#include "pybind_api/api_register.h"

namespace mindspore {
// namespace to support composite operators definition
namespace prim {
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractTuple;

/**
 * Builds a FuncGraph that zips its tuple arguments element-wise.
 *
 * The generated graph has one parameter per entry of `args_spec_list`. Its output is
 * MakeTuple(MakeTuple(p0[0], p1[0], ...), MakeTuple(p0[1], p1[1], ...), ...), with one
 * inner tuple per index in [0, length of the shortest input tuple). Longer inputs are
 * therefore truncated.
 *
 * @param args_spec_list abstract values of the zip arguments; each entry must be a
 *        non-null AbstractTuple.
 * @return the generated graph, with FUNC_GRAPH_FLAG_CORE set.
 *
 * Errors are reported through MS_LOG(EXCEPTION) when `args_spec_list` is empty or when
 * an argument is not a tuple. MS_EXCEPTION_IF_NULL fires on a null argument.
 */
FuncGraphPtr ZipOperation::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.empty()) {
    MS_LOG(EXCEPTION) << "zip arguments input should not be empty";
  }

  // Validate every argument and find the shortest tuple length in one pass, so each
  // argument is cast exactly once (instead of all_of + a min_element that re-casts on
  // every comparison). Checks run in argument order, as before.
  size_t min_size = 0;
  for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); ++arg_idx) {
    const AbstractBasePtr &abs = args_spec_list[arg_idx];
    MS_EXCEPTION_IF_NULL(abs);
    auto tuple = abs->cast<abstract::AbstractTuplePtr>();
    if (tuple == nullptr) {
      MS_LOG(EXCEPTION) << "zip input args should be tuple";
    }
    min_size = (arg_idx == 0) ? tuple->size() : std::min<size_t>(min_size, tuple->size());
  }

  FuncGraphPtr ret_graph = std::make_shared<FuncGraph>();
  ret_graph->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  for (size_t idx = 0; idx < args_spec_list.size(); ++idx) {
    (void)ret_graph->add_parameter();
  }
  // All parameters are added above; nothing below adds more, so this view stays valid.
  const auto &params = ret_graph->parameters();

  // Generate the tuple output of the zipped input arguments.
  std::vector<AnfNodePtr> make_tuple_nodes;
  make_tuple_nodes.reserve(min_size + 1);
  make_tuple_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
  for (size_t idx = 0; idx < min_size; ++idx) {
    std::vector<AnfNodePtr> make_tuple_zip_nodes;
    make_tuple_zip_nodes.reserve(args_spec_list.size() + 1);
    make_tuple_zip_nodes.push_back(NewValueNode(prim::kPrimMakeTuple));
    for (size_t arg_idx = 0; arg_idx < args_spec_list.size(); ++arg_idx) {
      // TupleGetItem(param_{arg_idx}, idx)
      std::vector<AnfNodePtr> tuple_get_item_nodes{NewValueNode(prim::kPrimTupleGetItem), params[arg_idx],
                                                   NewValueNode(SizeToInt(idx))};
      make_tuple_zip_nodes.push_back(ret_graph->NewCNode(tuple_get_item_nodes));
    }
    make_tuple_nodes.push_back(ret_graph->NewCNode(make_tuple_zip_nodes));
  }
  ret_graph->set_output(ret_graph->NewCNode(make_tuple_nodes));
  return ret_graph;
}

REGISTER_PYBIND_DEFINE(ZipOperation_, ([](const py::module *m) {
                         (void)py::class_<ZipOperation, MetaFuncGraph, std::shared_ptr<ZipOperation>>(*m,
                                                                                                     "ZipOperation_")
                           .def(py::init<std::string &>());
                       }));
}  // namespace prim
}  // namespace mindspore