| @@ -19,7 +19,7 @@ | |||||
| namespace mindspore { | namespace mindspore { | ||||
| tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { | tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std::vector<int64_t> &shape) { | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | ||||
| size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); | |||||
| size_t mem_size = IntToSize(tensor->ElementsNum()); | |||||
| auto tensor_data = tensor->data_c(); | auto tensor_data = tensor->data_c(); | ||||
| char *data = reinterpret_cast<char *>(tensor_data); | char *data = reinterpret_cast<char *>(tensor_data); | ||||
| MS_EXCEPTION_IF_NULL(data); | MS_EXCEPTION_IF_NULL(data); | ||||
| @@ -30,11 +30,11 @@ tensor::TensorPtr TensorConstructUtils::CreateZerosTensor(TypeId type, const std | |||||
| tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { | tensor::TensorPtr TensorConstructUtils::CreateOnesTensor(TypeId type, const std::vector<int64_t> &shape) { | ||||
| tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(type, shape); | ||||
| size_t mem_size = GetTypeByte(tensor->type()) * IntToSize(tensor->ElementsNum()); | |||||
| size_t mem_size = IntToSize(tensor->ElementsNum()); | |||||
| if (tensor->data_type() == kNumberTypeFloat32) { | if (tensor->data_type() == kNumberTypeFloat32) { | ||||
| SetTensorData(tensor->data_c(), 1.0, mem_size); | |||||
| SetTensorData<float>(tensor->data_c(), 1.0, mem_size); | |||||
| } else if (tensor->data_type() == kNumberTypeInt) { | } else if (tensor->data_type() == kNumberTypeInt) { | ||||
| SetTensorData(tensor->data_c(), 1, mem_size); | |||||
| SetTensorData<int>(tensor->data_c(), 1, mem_size); | |||||
| } | } | ||||
| return tensor; | return tensor; | ||||
| } | } | ||||