|
|
|
@@ -447,6 +447,10 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A |
|
|
|
return std::make_shared<TensorDataImpl<float>>(shape, args...); |
|
|
|
case kNumberTypeFloat64: |
|
|
|
return std::make_shared<TensorDataImpl<double>>(shape, args...); |
|
|
|
case kObjectTypeString: |
|
|
|
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...); |
|
|
|
case kObjectTypeTensorType: |
|
|
|
return std::make_shared<TensorDataImpl<int>>(shape, args...); |
|
|
|
default: |
|
|
|
break; |
|
|
|
} |
|
|
|
@@ -549,8 +553,8 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { |
|
|
|
abstract::AbstractBasePtr Tensor::ToAbstract() { |
|
|
|
auto tens = shared_from_base<Tensor>(); |
|
|
|
auto dtype = tens->Dtype(); |
|
|
|
if (!IsSubType(dtype, kNumber)) { |
|
|
|
MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << "."; |
|
|
|
if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) { |
|
|
|
MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << "."; |
|
|
|
} |
|
|
|
auto tensor_shape = tens->shape(); |
|
|
|
auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); |
|
|
|
|