
!4512 Expand the vectors for tensor inputs recursively

Merge pull request !4512 from YuJianfeng/internal_output
tags/v0.7.0-beta
mindspore-ci-bot · 5 years ago
commit 19c800a758
1 changed file with 37 additions and 26 deletions
mindspore/ccsrc/vm/backend.cc  +37  -26

@@ -91,37 +91,48 @@ VectorRef MsBackend::MsSimuRunGraph(const GraphId &g, const VectorRef &args) {
   return VectorRef(outputs);
 }
 
+namespace {
+void PushInputTensor(const BaseRef &arg, std::vector<tensor::TensorPtr> *inputs) {
+  MS_EXCEPTION_IF_NULL(inputs);
+  if (utils::isa<tensor::TensorPtr>(arg)) {
+    auto value = utils::cast<tensor::TensorPtr>(arg);
+    inputs->push_back(value);
+  } else if (utils::isa<ValuePtr>(arg)) {
+    auto value = utils::cast<ValuePtr>(arg);
+    MS_EXCEPTION_IF_NULL(value);
+    if (value->isa<ValueTuple>()) {
+      auto value_tuple = value->cast<ValueTuplePtr>();
+      MS_EXCEPTION_IF_NULL(value_tuple);
+      auto tuple_value = value_tuple->value();
+      for (const auto &v : tuple_value) {
+        PushInputTensor(v, inputs);
+      }
+    } else if (value->isa<Scalar>()) {
+      tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
+      inputs->push_back(scalar_tensor);
+    } else {
+      inputs->push_back(value->cast<tensor::TensorPtr>());
+    }
+  } else if (utils::isa<PyObjectRef>(arg)) {
+    auto value = utils::cast<PyObjectRef>(arg).object_;
+    inputs->push_back(py::cast<tensor::TensorPtr>(value));
+  } else if (utils::isa<VectorRefPtr>(arg)) {
+    const auto &args_new = utils::cast<VectorRef>(arg);
+    for (const auto &v : args_new) {
+      PushInputTensor(v, inputs);
+    }
+  } else {
+    MS_LOG(WARNING) << "Invalid input type.";
+  }
+}
+}  // namespace
+
 VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
   MS_LOG(DEBUG) << "start ms graph run:" << args.size() << ", g:" << g;
   // Run graph
   std::vector<tensor::TensorPtr> inputs;
   for (const auto &arg : args) {
-    if (utils::isa<tensor::TensorPtr>(arg)) {
-      auto value = utils::cast<tensor::TensorPtr>(arg);
-      inputs.push_back(value);
-    } else if (utils::isa<ValuePtr>(arg)) {
-      auto value = utils::cast<ValuePtr>(arg);
-      if (value->isa<ValueTuple>()) {
-        (void)std::transform(value->cast<ValueTuplePtr>()->value().begin(), value->cast<ValueTuplePtr>()->value().end(),
-                             std::back_inserter(inputs),
-                             [](const ValuePtr &v) { return v->cast<tensor::TensorPtr>(); });
-      } else if (value->isa<Scalar>()) {
-        tensor::TensorPtr scalar_tensor = ScalarToTensor(value->cast<ScalarPtr>());
-        MS_EXCEPTION_IF_NULL(scalar_tensor);
-        inputs.push_back(scalar_tensor);
-      } else {
-        inputs.push_back(value->cast<tensor::TensorPtr>());
-      }
-    } else if (utils::isa<PyObjectRef>(arg)) {
-      auto value = utils::cast<PyObjectRef>(arg).object_;
-      inputs.push_back(py::cast<tensor::TensorPtr>(value));
-    } else if (utils::isa<VectorRefPtr>(arg)) {
-      auto args_new = utils::cast<VectorRef>(arg);
-      (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs),
-                           [](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); });
-    } else {
-      MS_LOG(WARNING) << "Invalid input type.";
-    }
+    PushInputTensor(arg, &inputs);
   }
 
   VectorRef outputs;


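The change replaces the one-level std::transform flattening in MsRunGraph with the recursive PushInputTensor helper, so tensors inside nested ValueTuple or VectorRef arguments are still collected. The standalone sketch below illustrates the same recursive-expansion idea with plain C++17 types; Input, InputList, and PushInput are illustrative stand-ins, not the MindSpore BaseRef/TensorPtr API.

#include <iostream>
#include <memory>
#include <variant>
#include <vector>

struct Input;
using InputList = std::vector<Input>;

// An input is either a leaf "tensor" (modelled as an int id) or a nested
// list of inputs, mirroring a tuple/vector argument that contains tensors.
struct Input {
  std::variant<int, std::shared_ptr<InputList>> value;
};

// Recursive expansion: leaves are appended directly; nested lists are walked
// again, so tensors are collected no matter how deeply they are nested.
void PushInput(const Input &arg, std::vector<int> *inputs) {
  if (std::holds_alternative<int>(arg.value)) {
    inputs->push_back(std::get<int>(arg.value));
    return;
  }
  for (const auto &v : *std::get<std::shared_ptr<InputList>>(arg.value)) {
    PushInput(v, inputs);
  }
}

int main() {
  // Argument list ((1, 2), 3): a tuple nested inside the outer arguments.
  auto inner = std::make_shared<InputList>(InputList{Input{1}, Input{2}});
  InputList args{Input{inner}, Input{3}};

  std::vector<int> flat;
  for (const auto &arg : args) {
    PushInput(arg, &flat);
  }
  for (int id : flat) {
    std::cout << id << ' ';  // prints: 1 2 3
  }
  std::cout << '\n';
  return 0;
}

A single pass that only converts the outer elements would stop at the first nesting level and leave the grouped (1, 2) entry unexpanded, which is the kind of case the recursive helper in this merge request handles.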