|
|
|
@@ -319,6 +319,27 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tenso |
|
|
|
input_tensors->push_back(tensor_ptr); |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertMultiPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, |
|
|
|
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) { |
|
|
|
MS_EXCEPTION_IF_NULL(op_prim); |
|
|
|
MS_EXCEPTION_IF_NULL(input_tensors); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_mask); |
|
|
|
|
|
|
|
if (!py::isinstance<py::tuple>(input_object)) { |
|
|
|
MS_LOG(EXCEPTION) << "The input should be a tuple!"; |
|
|
|
} |
|
|
|
auto tuple_inputs = py::cast<py::tuple>(input_object); |
|
|
|
if (tuple_inputs.size() == 0) { |
|
|
|
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!"; |
|
|
|
} |
|
|
|
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) { |
|
|
|
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); |
|
|
|
} else { |
|
|
|
ConvertValueTupleToTensor(input_object, input_tensors); |
|
|
|
*tensor_mask = kValueNodeTensorMask; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim, |
|
|
|
std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) { |
|
|
|
MS_EXCEPTION_IF_NULL(op_prim); |
|
|
|
@@ -333,20 +354,20 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr |
|
|
|
} else if (py::isinstance<py::int_>(input_object)) { |
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32); |
|
|
|
*tensor_mask = kValueNodeTensorMask; |
|
|
|
} else if (py::isinstance<py::list>(input_object)) { |
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr); |
|
|
|
} else if (py::isinstance<py::array>(input_object)) { |
|
|
|
tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::array>(input_object), nullptr); |
|
|
|
} else if (py::isinstance<py::none>(input_object)) { |
|
|
|
} else if (py::isinstance<py::list>(input_object)) { |
|
|
|
auto list_inputs = py::cast<py::list>(input_object); |
|
|
|
py::tuple tuple_inputs(list_inputs.size()); |
|
|
|
for (size_t i = 0; i < tuple_inputs.size(); ++i) { |
|
|
|
tuple_inputs[i] = list_inputs[i]; |
|
|
|
} |
|
|
|
ConvertMultiPyObjectToTensor(tuple_inputs, op_prim, input_tensors, tensor_mask); |
|
|
|
return; |
|
|
|
} else if (py::isinstance<py::tuple>(input_object)) { |
|
|
|
auto tuple_inputs = py::cast<py::tuple>(input_object); |
|
|
|
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) { |
|
|
|
PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors); |
|
|
|
} else { |
|
|
|
ConvertValueTupleToTensor(input_object, input_tensors); |
|
|
|
*tensor_mask = kValueNodeTensorMask; |
|
|
|
} |
|
|
|
ConvertMultiPyObjectToTensor(input_object, op_prim, input_tensors, tensor_mask); |
|
|
|
return; |
|
|
|
} else if (py::isinstance<py::none>(input_object)) { |
|
|
|
return; |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Run op inputs type is invalid!"; |
|
|
|
|