|
|
|
@@ -91,37 +91,48 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) { |
|
|
|
return VectorRef(outputs); |
|
|
|
} |
|
|
|
|
|
|
|
namespace { |
|
|
|
void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) { |
|
|
|
MS_EXCEPTION_IF_NULL(inputs); |
|
|
|
if (utils::isa<tensor::TensorPtr>(arg)) { |
|
|
|
auto value = utils::cast<tensor::TensorPtr>(arg); |
|
|
|
inputs->push_back(value); |
|
|
|
} else if (utils::isa<ValuePtr>(arg)) { |
|
|
|
auto value = utils::cast<ValuePtr>(arg); |
|
|
|
MS_EXCEPTION_IF_NULL(value); |
|
|
|
if (value->isa<ValueTuple>()) { |
|
|
|
auto value_tuple = value->cast<ValueTuplePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(value_tuple); |
|
|
|
auto tuple_value = value_tuple->value(); |
|
|
|
for (const auto &v : tuple_value) { |
|
|
|
PushInputTensor(v, inputs); |
|
|
|
} |
|
|
|
} else if (value->isa<Scalar>()) { |
|
|
|
tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>()); |
|
|
|
inputs->push_back(scalar_tensor); |
|
|
|
} else { |
|
|
|
inputs->push_back(value->cast<tensor::TensorPtr>()); |
|
|
|
} |
|
|
|
} else if (utils::isa<PyObjectRef>(arg)) { |
|
|
|
auto value = utils::cast<PyObjectRef>(arg).object_; |
|
|
|
inputs->push_back(py::cast<tensor::TensorPtr>(value)); |
|
|
|
} else if (utils::isa<VectorRefPtr>(arg)) { |
|
|
|
const auto &args_new = utils::cast<VectorRef>(arg); |
|
|
|
for (const auto &v : args_new) { |
|
|
|
PushInputTensor(v, inputs); |
|
|
|
} |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "Invalid input type."; |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace |
|
|
|
|
|
|
|
VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) { |
|
|
|
MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g; |
|
|
|
// Run graph |
|
|
|
std::vector<tensor::TensorPtr> inputs; |
|
|
|
for (const auto &arg : args) { |
|
|
|
if (utils::isa<tensor::TensorPtr>(arg)) { |
|
|
|
auto value = utils::cast<tensor::TensorPtr>(arg); |
|
|
|
inputs.push_back(value); |
|
|
|
} else if (utils::isa<ValuePtr>(arg)) { |
|
|
|
auto value = utils::cast<ValuePtr>(arg); |
|
|
|
if (value->isa<ValueTuple>()) { |
|
|
|
(void)std::transform(value->cast<ValueTuplePtr>()->value().begin(), value->cast<ValueTuplePtr>()->value().end(), |
|
|
|
std::back_inserter(inputs), |
|
|
|
[](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); }); |
|
|
|
} else if (value->isa<Scalar>()) { |
|
|
|
tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>()); |
|
|
|
MS_EXCEPTION_IF_NULL(scalar_tensor); |
|
|
|
inputs.push_back(scalar_tensor); |
|
|
|
} else { |
|
|
|
inputs.push_back(value->cast<tensor::TensorPtr>()); |
|
|
|
} |
|
|
|
} else if (utils::isa<PyObjectRef>(arg)) { |
|
|
|
auto value = utils::cast<PyObjectRef>(arg).object_; |
|
|
|
inputs.push_back(py::cast<tensor::TensorPtr>(value)); |
|
|
|
} else if (utils::isa<VectorRefPtr>(arg)) { |
|
|
|
auto args_new = utils::cast<VectorRef>(arg); |
|
|
|
(void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs), |
|
|
|
[](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); }); |
|
|
|
} else { |
|
|
|
MS_LOG(WARNING) << "Invalid input type."; |
|
|
|
} |
|
|
|
PushInputTensor(arg, &inputs); |
|
|
|
} |
|
|
|
|
|
|
|
VectorRef outputs; |
|
|
|
|