| @@ -91,37 +91,48 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { | |||
| return VectorRef(outputs); | |||
| } | |||
| namespace { | |||
| void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) { | |||
| MS_EXCEPTION_IF_NULL(inputs); | |||
| if (utils::isa<tensor::TensorPtr>(arg)) { | |||
| auto value = utils::cast<tensor::TensorPtr>(arg); | |||
| inputs->push_back(value); | |||
| } else if (utils::isa<ValuePtr>(arg)) { | |||
| auto value = utils::cast<ValuePtr>(arg); | |||
| MS_EXCEPTION_IF_NULL(value); | |||
| if (value->isa<ValueTuple>()) { | |||
| auto value_tuple = value->cast<ValueTuplePtr>(); | |||
| MS_EXCEPTION_IF_NULL(value_tuple); | |||
| auto tuple_value = value_tuple->value(); | |||
| for (const auto &v : tuple_value) { | |||
| PushInputTensor(v, inputs); | |||
| } | |||
| } else if (value->isa<Scalar>()) { | |||
| tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>()); | |||
| inputs->push_back(scalar_tensor); | |||
| } else { | |||
| inputs->push_back(value->cast<tensor::TensorPtr>()); | |||
| } | |||
| } else if (utils::isa<PyObjectRef>(arg)) { | |||
| auto value = utils::cast<PyObjectRef>(arg).object_; | |||
| inputs->push_back(py::cast<tensor::TensorPtr>(value)); | |||
| } else if (utils::isa<VectorRefPtr>(arg)) { | |||
| const auto &args_new = utils::cast<VectorRef>(arg); | |||
| for (const auto &v : args_new) { | |||
| PushInputTensor(v, inputs); | |||
| } | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid input type."; | |||
| } | |||
| } | |||
| } // namespace | |||
| VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { | |||
| MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; | |||
| // Run graph | |||
| std::vector<tensor::TensorPtr> inputs; | |||
| for (const auto &arg : args) { | |||
| if (utils::isa<tensor::TensorPtr>(arg)) { | |||
| auto value = utils::cast<tensor::TensorPtr>(arg); | |||
| inputs.push_back(value); | |||
| } else if (utils::isa<ValuePtr>(arg)) { | |||
| auto value = utils::cast<ValuePtr>(arg); | |||
| if (value->isa<ValueTuple>()) { | |||
| (void)std::transform(value->cast<ValueTuplePtr>()->value().begin(), value->cast<ValueTuplePtr>()->value().end(), | |||
| std::back_inserter(inputs), | |||
| [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); }); | |||
| } else if (value->isa<Scalar>()) { | |||
| tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>()); | |||
| MS_EXCEPTION_IF_NULL(scalar_tensor); | |||
| inputs.push_back(scalar_tensor); | |||
| } else { | |||
| inputs.push_back(value->cast<tensor::TensorPtr>()); | |||
| } | |||
| } else if (utils::isa<PyObjectRef>(arg)) { | |||
| auto value = utils::cast<PyObjectRef>(arg).object_; | |||
| inputs.push_back(py::cast<tensor::TensorPtr>(value)); | |||
| } else if (utils::isa<VectorRefPtr>(arg)) { | |||
| auto args_new = utils::cast<VectorRef>(arg); | |||
| (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs), | |||
| [](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); }); | |||
| } else { | |||
| MS_LOG(WARNING) << "Invalid input type."; | |||
| } | |||
| PushInputTensor(arg, &inputs); | |||
| } | |||
| VectorRef outputs; | |||