|
|
|
@@ -53,8 +53,14 @@ tensor::TensorPtr TensorConstructUtils::CreateTensor(const TypePtr type_ptr, con |
|
|
|
|
|
|
|
TypeId TensorConstructUtils::ExtractTypeId(const TypePtr type_ptr) { |
|
|
|
MS_EXCEPTION_IF_NULL(type_ptr); |
|
|
|
auto tensor_type = type_ptr->cast<TensorTypePtr>(); |
|
|
|
auto type_id = tensor_type->element()->type_id(); |
|
|
|
TypeId type_id; |
|
|
|
if (type_ptr->isa<TensorType>()) { |
|
|
|
auto tensor_type = type_ptr->cast<TensorTypePtr>(); |
|
|
|
MS_EXCEPTION_IF_NULL(tensor_type); |
|
|
|
type_id = tensor_type->element()->type_id(); |
|
|
|
} else { |
|
|
|
type_id = type_ptr->type_id(); |
|
|
|
} |
|
|
|
return type_id; |
|
|
|
} |
|
|
|
} // namespace mindspore |