From 079b78f2e40363f6b622885a023b11d2a10fd9bb Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Thu, 4 Mar 2021 15:59:58 +0800 Subject: [PATCH] fix bug of construct tensor --- mindspore/core/utils/tensor_construct_utils.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mindspore/core/utils/tensor_construct_utils.cc b/mindspore/core/utils/tensor_construct_utils.cc index 21c93d9d79..1563ecc074 100644 --- a/mindspore/core/utils/tensor_construct_utils.cc +++ b/mindspore/core/utils/tensor_construct_utils.cc @@ -19,7 +19,7 @@ namespace mindspore { tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector &shape) { tensor::TensorPtr tensor = std::make_shared(type, shape); - size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); + size_t mem_size = IntToSize(tensor->ElementsNum()); auto tensor_data = tensor->data_c(); char *data = reinterpret_cast(tensor_data); MS_EXCEPTION_IF_NULL(data); @@ -30,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector &shape) { tensor::TensorPtr tensor = std::make_shared(type, shape); - size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); + size_t mem_size = IntToSize(tensor->ElementsNum()); if (tensor->data_type() == kNumberTypeFloat32) { - SetTensorData(tensor->data_c(), 1.0, mem_size); + SetTensorData(tensor->data_c(), 1.0, mem_size); } else if (tensor->data_type() == kNumberTypeInt) { - SetTensorData(tensor->data_c(), 1, mem_size); + SetTensorData(tensor->data_c(), 1, mem_size); } return tensor; }