|
|
|
@@ -19,7 +19,7 @@ |
|
|
|
namespace mindspore { |
|
|
|
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { |
|
|
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); |
|
|
|
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); |
|
|
|
size_t mem_size = IntToSize(tensor->ElementsNum()); |
|
|
|
auto tensor_data = tensor->data_c(); |
|
|
|
char *data = reinterpret_cast<char *>(tensor_data); |
|
|
|
MS_EXCEPTION_IF_NULL(data); |
|
|
|
@@ -30,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std |
|
|
|
|
|
|
|
tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { |
|
|
|
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); |
|
|
|
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); |
|
|
|
size_t mem_size = IntToSize(tensor->ElementsNum()); |
|
|
|
if (tensor->data_type() == kNumberTypeFloat32) { |
|
|
|
SetTensorData(tensor->data_c(), 1.0, mem_size); |
|
|
|
SetTensorData<float>(tensor->data_c(), 1.0, mem_size); |
|
|
|
} else if (tensor->data_type() == kNumberTypeInt) { |
|
|
|
SetTensorData(tensor->data_c(), 1, mem_size); |
|
|
|
SetTensorData<int>(tensor->data_c(), 1, mem_size); |
|
|
|
} |
|
|
|
return tensor; |
|
|
|
} |
|
|
|
|