Browse Source

fix bug of construct tensor

tags/v1.2.0-rc1
LianLiguang 4 years ago
parent
commit
079b78f2e4
1 changed files with 4 additions and 4 deletions
  1. +4
    -4
      mindspore/core/utils/tensor_construct_utils.cc

+ 4
- 4
mindspore/core/utils/tensor_construct_utils.cc View File

@@ -19,7 +19,7 @@
namespace mindspore {
tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
size_t mem_size = IntToSize(tensor->ElementsNum());
auto tensor_data = tensor->data_c();
char *data = reinterpret_cast<char *>(tensor_data);
MS_EXCEPTION_IF_NULL(data);
@@ -30,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std

tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) {
tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape);
size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum());
size_t mem_size = IntToSize(tensor->ElementsNum());
if (tensor->data_type() == kNumberTypeFloat32) {
SetTensorData(tensor->data_c(), 1.0, mem_size);
SetTensorData<float>(tensor->data_c(), 1.0, mem_size);
} else if (tensor->data_type() == kNumberTypeInt) {
SetTensorData(tensor->data_c(), 1, mem_size);
SetTensorData<int>(tensor->data_c(), 1, mem_size);
}
return tensor;
}


Loading…
Cancel
Save