| @@ -447,6 +447,10 @@ TensorDataPtr MakeTensorData(TypeId data_type, const ShapeVector &shape, const A | |||||
| return std::make_shared<TensorDataImpl<float>>(shape, args...); | return std::make_shared<TensorDataImpl<float>>(shape, args...); | ||||
| case kNumberTypeFloat64: | case kNumberTypeFloat64: | ||||
| return std::make_shared<TensorDataImpl<double>>(shape, args...); | return std::make_shared<TensorDataImpl<double>>(shape, args...); | ||||
| case kObjectTypeString: | |||||
| return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...); | |||||
| case kObjectTypeTensorType: | |||||
| return std::make_shared<TensorDataImpl<int>>(shape, args...); | |||||
| default: | default: | ||||
| break; | break; | ||||
| } | } | ||||
| @@ -549,8 +553,8 @@ Tensor &Tensor::AssignValue(const Tensor &tensor) { | |||||
| abstract::AbstractBasePtr Tensor::ToAbstract() { | abstract::AbstractBasePtr Tensor::ToAbstract() { | ||||
| auto tens = shared_from_base<Tensor>(); | auto tens = shared_from_base<Tensor>(); | ||||
| auto dtype = tens->Dtype(); | auto dtype = tens->Dtype(); | ||||
| if (!IsSubType(dtype, kNumber)) { | |||||
| MS_LOG(EXCEPTION) << "Expect tensor type kNumber but got: " << dtype->ToString() << "."; | |||||
| if (!IsSubType(dtype, kNumber) && !IsSubType(dtype, kString) && !IsSubType(dtype, kTensorType)) { | |||||
| MS_LOG(EXCEPTION) << "Expect tensor type kNumber or kString or kTensor but got: " << dtype->ToString() << "."; | |||||
| } | } | ||||
| auto tensor_shape = tens->shape(); | auto tensor_shape = tens->shape(); | ||||
| auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | auto abs_tensor = std::make_shared<abstract::AbstractTensor>(dtype, tensor_shape); | ||||