|
|
|
@@ -50,55 +50,68 @@ py::object TensorToPyData(const tensor::TensorPtr &tensor) { |
|
|
|
return v[0]; |
|
|
|
} |
|
|
|
|
|
|
|
py::object ScalarPtrToPyData(const ScalarPtr &value) { |
|
|
|
py::int_ int_v; |
|
|
|
py::float_ float_v; |
|
|
|
py::bool_ bool_v; |
|
|
|
TypeId scalar_type = value->type()->type_id(); |
|
|
|
switch (scalar_type) { |
|
|
|
case kNumberTypeUInt8: |
|
|
|
MS_LOG(DEBUG) << "uint8"; |
|
|
|
int_v = value->cast<UInt8ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeUInt16: |
|
|
|
MS_LOG(DEBUG) << "uint16"; |
|
|
|
int_v = value->cast<UInt16ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeUInt32: |
|
|
|
MS_LOG(DEBUG) << "uint32"; |
|
|
|
int_v = value->cast<UInt32ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeUInt64: |
|
|
|
MS_LOG(DEBUG) << "uint64"; |
|
|
|
int_v = value->cast<UInt64ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeInt8: |
|
|
|
MS_LOG(DEBUG) << "int8"; |
|
|
|
int_v = value->cast<Int8ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeInt16: |
|
|
|
MS_LOG(DEBUG) << "int16"; |
|
|
|
int_v = value->cast<Int16ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeInt32: |
|
|
|
MS_LOG(DEBUG) << "int32"; |
|
|
|
int_v = value->cast<Int32ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeInt64: |
|
|
|
MS_LOG(DEBUG) << "int64"; |
|
|
|
int_v = value->cast<Int64ImmPtr>()->value(); |
|
|
|
return std::move(int_v); |
|
|
|
case kNumberTypeFloat32: |
|
|
|
MS_LOG(DEBUG) << "float"; |
|
|
|
float_v = value->cast<FP32ImmPtr>()->value(); |
|
|
|
return std::move(float_v); |
|
|
|
case kNumberTypeFloat64: |
|
|
|
MS_LOG(DEBUG) << "double"; |
|
|
|
float_v = value->cast<FP64ImmPtr>()->value(); |
|
|
|
return std::move(float_v); |
|
|
|
case kNumberTypeBool: |
|
|
|
MS_LOG(DEBUG) << "bool"; |
|
|
|
bool_v = value->cast<BoolImmPtr>()->value(); |
|
|
|
return std::move(bool_v); |
|
|
|
default: |
|
|
|
MS_EXCEPTION(TypeError) << "Unsupported scalar converted to py data: " << value->ToString(); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
py::object ValuePtrToPyData(const ValuePtr &value) { |
|
|
|
if (value == nullptr) { |
|
|
|
MS_LOG(EXCEPTION) << "value is null"; |
|
|
|
} |
|
|
|
py::object ret; |
|
|
|
if (value->isa<Int8Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "int8"; |
|
|
|
py::int_ v = value->cast<Int8ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<Int16Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "int16"; |
|
|
|
py::int_ v = value->cast<Int16ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<Int32Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "int32"; |
|
|
|
py::int_ v = value->cast<Int32ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<Int64Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "int64"; |
|
|
|
py::int_ v = value->cast<Int64ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<UInt8Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "uint8"; |
|
|
|
py::int_ v = value->cast<UInt8ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<UInt16Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "uint16"; |
|
|
|
py::int_ v = value->cast<UInt16ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<UInt32Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "uint32"; |
|
|
|
py::int_ v = value->cast<UInt32ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<UInt64Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "uint64"; |
|
|
|
py::int_ v = value->cast<UInt64ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<BoolImm>()) { |
|
|
|
MS_LOG(DEBUG) << "bool"; |
|
|
|
py::bool_ v = value->cast<BoolImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<FP64Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "double"; |
|
|
|
py::float_ v = value->cast<FP64ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
} else if (value->isa<FP32Imm>()) { |
|
|
|
MS_LOG(DEBUG) << "float"; |
|
|
|
py::float_ v = value->cast<FP32ImmPtr>()->value(); |
|
|
|
ret = v; |
|
|
|
if (value->isa<Scalar>()) { |
|
|
|
ret = ScalarPtrToPyData(value->cast<ScalarPtr>()); |
|
|
|
} else if (value->isa<StringImm>()) { |
|
|
|
MS_LOG(DEBUG) << "String"; |
|
|
|
py::str v = value->cast<StringImmPtr>()->value(); |
|
|
|
@@ -117,28 +130,27 @@ py::object ValuePtrToPyData(const ValuePtr &value) { |
|
|
|
py::tuple v(1); |
|
|
|
v[0] = value->cast<RefKeyPtr>(); |
|
|
|
ret = v[0]; |
|
|
|
} else if (value->isa<ValueTuple>()) { |
|
|
|
MS_LOG(DEBUG) << "tuple"; |
|
|
|
auto value_tuple = value->cast<ValueTuplePtr>()->value(); |
|
|
|
py::tuple rets(value_tuple.size()); |
|
|
|
} else if (value->isa<ValueSequeue>()) { |
|
|
|
MS_LOG(DEBUG) << "tuple or list"; |
|
|
|
auto value_sequeue = value->cast<ValueSequeuePtr>()->value(); |
|
|
|
py::tuple ret_sequeue(value_sequeue.size()); |
|
|
|
|
|
|
|
size_t i = 0; |
|
|
|
for (auto &v : value_tuple) { |
|
|
|
rets[i] = ValuePtrToPyData(v); |
|
|
|
i++; |
|
|
|
for (size_t i = 0; i < value_sequeue.size(); i++) { |
|
|
|
ret_sequeue[i] = ValuePtrToPyData(value_sequeue[i]); |
|
|
|
} |
|
|
|
ret = rets; |
|
|
|
} else if (value->isa<ValueList>()) { |
|
|
|
MS_LOG(DEBUG) << "list"; |
|
|
|
auto value_list = value->cast<ValueListPtr>()->value(); |
|
|
|
py::list rets(value_list.size()); |
|
|
|
|
|
|
|
size_t i = 0; |
|
|
|
for (auto &v : value_list) { |
|
|
|
rets[i] = ValuePtrToPyData(v); |
|
|
|
i++; |
|
|
|
if (value->isa<ValueTuple>()) { |
|
|
|
ret = ret_sequeue; |
|
|
|
} else { |
|
|
|
ret = ret_sequeue.cast<py::list>(); |
|
|
|
} |
|
|
|
ret = rets; |
|
|
|
} else if (value->isa<ValueDictionary>()) { |
|
|
|
MS_LOG(DEBUG) << "dict"; |
|
|
|
auto value_list = value->cast<ValueDictionaryPtr>()->value(); |
|
|
|
py::dict ret_dict; |
|
|
|
for (const auto &v : value_list) { |
|
|
|
ret_dict[py::str(v.first)] = ValuePtrToPyData(v.second); |
|
|
|
} |
|
|
|
ret = ret_dict; |
|
|
|
} else if (value->isa<Ellipsis>()) { |
|
|
|
ret = py::ellipsis(); |
|
|
|
} else if (value->isa<ValueSlice>()) { |
|
|
|
@@ -152,15 +164,9 @@ py::object ValuePtrToPyData(const ValuePtr &value) { |
|
|
|
py::tuple v(1); |
|
|
|
v[0] = value->cast<TypePtr>(); |
|
|
|
ret = v[0]; |
|
|
|
} else if (value->isa<AnyValue>()) { |
|
|
|
ret = py::none(); |
|
|
|
} else if (value->isa<None>()) { |
|
|
|
ret = py::none(); |
|
|
|
} else if (value->isa<FuncGraph>()) { |
|
|
|
} else if (value->isa<AnyValue>() || value->isa<None>() || value->isa<Monad>() || value->isa<FuncGraph>()) { |
|
|
|
// FuncGraph is not used in the backend, return None |
|
|
|
ret = py::none(); |
|
|
|
} else if (value->isa<Monad>()) { |
|
|
|
ret = py::none(); |
|
|
|
} else { |
|
|
|
MS_LOG(EXCEPTION) << "Unsupported convert value: " << value->ToString() << " to a PyData."; |
|
|
|
} |
|
|
|
|