diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index b94c275596..f9fc225b3f 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -91,37 +91,48 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { return VectorRef(outputs); } +namespace { +void PushInputTensor(const BaseRef &arg, std::vector *inputs) { + MS_EXCEPTION_IF_NULL(inputs); + if (utils::isa(arg)) { + auto value = utils::cast(arg); + inputs->push_back(value); + } else if (utils::isa(arg)) { + auto value = utils::cast(arg); + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + auto value_tuple = value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + auto tuple_value = value_tuple->value(); + for (const auto &v : tuple_value) { + PushInputTensor(v, inputs); + } + } else if (value->isa()) { + tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast()); + inputs->push_back(scalar_tensor); + } else { + inputs->push_back(value->cast()); + } + } else if (utils::isa(arg)) { + auto value = utils::cast(arg).object_; + inputs->push_back(py::cast(value)); + } else if (utils::isa(arg)) { + const auto &args_new = utils::cast(arg); + for (const auto &v : args_new) { + PushInputTensor(v, inputs); + } + } else { + MS_LOG(WARNING) << "Invalid input type."; + } +} +} // namespace + VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; // Run graph std::vector inputs; for (const auto &arg : args) { - if (utils::isa(arg)) { - auto value = utils::cast(arg); - inputs.push_back(value); - } else if (utils::isa(arg)) { - auto value = utils::cast(arg); - if (value->isa()) { - (void)std::transform(value->cast()->value().begin(), value->cast()->value().end(), - std::back_inserter(inputs), - [](const ValuePtr &v) { return v->cast(); }); - } else if (value->isa()) { - tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast()); - MS_EXCEPTION_IF_NULL(scalar_tensor); - inputs.push_back(scalar_tensor); - } else { - inputs.push_back(value->cast()); - } - } else if (utils::isa(arg)) { - auto value = utils::cast(arg).object_; - inputs.push_back(py::cast(value)); - } else if (utils::isa(arg)) { - auto args_new = utils::cast(arg); - (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs), - [](const BaseRef &v) { return utils::cast(v); }); - } else { - MS_LOG(WARNING) << "Invalid input type."; - } + PushInputTensor(arg, &inputs); } VectorRef outputs;