diff --git a/mindspore/core/utils/tensor_construct_utils.cc b/mindspore/core/utils/tensor_construct_utils.cc index 21c93d9d79..1563ecc074 100644 --- a/mindspore/core/utils/tensor_construct_utils.cc +++ b/mindspore/core/utils/tensor_construct_utils.cc @@ -19,7 +19,7 @@ namespace mindspore { tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector &shape) { tensor::TensorPtr tensor = std::make_shared(type, shape); - size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); + size_t mem_size = IntToSize(tensor->ElementsNum()); auto tensor_data = tensor->data_c(); char *data = reinterpret_cast(tensor_data); MS_EXCEPTION_IF_NULL(data); @@ -30,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector &shape) { tensor::TensorPtr tensor = std::make_shared(type, shape); - size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); + size_t mem_size = IntToSize(tensor->ElementsNum()); if (tensor->data_type() == kNumberTypeFloat32) { - SetTensorData(tensor->data_c(), 1.0, mem_size); + SetTensorData(tensor->data_c(), 1.0, mem_size); } else if (tensor->data_type() == kNumberTypeInt) { - SetTensorData(tensor->data_c(), 1, mem_size); + SetTensorData(tensor->data_c(), 1, mem_size); } return tensor; }