Browse Source

!1765 Pynative support list input

Merge pull request !1765 from JoyLvliang/pynative-support-list-input
tags/v0.5.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
82b6c83653
1 changed files with 31 additions and 10 deletions
  1. +31
    -10
      mindspore/ccsrc/pynative/pynative_execute.cc

+ 31
- 10
mindspore/ccsrc/pynative/pynative_execute.cc View File

@@ -319,6 +319,27 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso
input_tensors->push_back(tensor_ptr);
}

void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
MS_EXCEPTION_IF_NULL(op_prim);
MS_EXCEPTION_IF_NULL(input_tensors);
MS_EXCEPTION_IF_NULL(tensor_mask);

if (!py::isinstance<py::tuple>(input_object)) {
MS_LOG(EXCEPTION) << "The input should be a tuple!";
}
auto tuple_inputs = py::cast<py::tuple>(input_object);
if (tuple_inputs.size() == 0) {
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
}
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
} else {
ConvertValueTupleToTensor(input_object, input_tensors);
*tensor_mask = kValueNodeTensorMask;
}
}

void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
MS_EXCEPTION_IF_NULL(op_prim);
@@ -333,20 +354,20 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr
} else if (py::isinstance<py::int_>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
*tensor_mask = kValueNodeTensorMask;
} else if (py::isinstance<py::list>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
} else if (py::isinstance<py::array>(input_object)) {
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr);
} else if (py::isinstance<py::none>(input_object)) {
} else if (py::isinstance<py::list>(input_object)) {
auto list_inputs = py::cast<py::list>(input_object);
py::tuple tuple_inputs(list_inputs.size());
for (size_t i = 0; i < tuple_inputs.size(); ++i) {
tuple_inputs[i] = list_inputs[i];
}
ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask);
return;
} else if (py::isinstance<py::tuple>(input_object)) {
auto tuple_inputs = py::cast<py::tuple>(input_object);
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
} else {
ConvertValueTupleToTensor(input_object, input_tensors);
*tensor_mask = kValueNodeTensorMask;
}
ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask);
return;
} else if (py::isinstance<py::none>(input_object)) {
return;
} else {
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";


Loading…
Cancel
Save