|
|
|
@@ -30,8 +30,6 @@ |
|
|
|
namespace mindspore { |
|
|
|
namespace tensor { |
|
|
|
|
|
|
|
using Bool = unsigned char; |
|
|
|
|
|
|
|
static std::string MakeId() { |
|
|
|
// Use atomic to make id generator thread safe. |
|
|
|
static std::atomic<uint64_t> last_id{1}; |
|
|
|
@@ -50,10 +48,7 @@ template <typename T> |
|
|
|
std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_type) { |
|
|
|
const size_t count = SizeOf(shape); |
|
|
|
switch (data_type) { |
|
|
|
case kNumberTypeBool: { |
|
|
|
auto buf = static_cast<Bool *>(data); |
|
|
|
return std::vector<T>(buf, buf + count); |
|
|
|
} |
|
|
|
case kNumberTypeBool: |
|
|
|
case kNumberTypeUInt8: { |
|
|
|
auto buf = static_cast<uint8_t *>(data); |
|
|
|
return std::vector<T>(buf, buf + count); |
|
|
|
@@ -104,14 +99,6 @@ std::vector<T> CopyData(const std::vector<int> &shape, void *data, TypeId data_t |
|
|
|
MS_LOG(EXCEPTION) << "Cannot construct Tensor because of unsupported data type: " << data_type << "."; |
|
|
|
} |
|
|
|
|
|
|
|
// Convert to bool is not allowed. |
|
|
|
template <> |
|
|
|
std::vector<Bool> CopyData<Bool>(const std::vector<int> &shape, void *data, TypeId data_type) { |
|
|
|
MS_LOG(EXCEPTION) << "Cannot convert from " << TypeIdLabel(data_type) << " to " << TypeIdLabel(kNumberTypeBool) |
|
|
|
<< "."; |
|
|
|
return {}; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename T> |
|
|
|
std::vector<T> CopyData(const std::vector<int> &shape, void *data, size_t data_len) { |
|
|
|
size_t size = SizeOf(shape); |
|
|
|
@@ -192,10 +179,6 @@ template <typename... Args> |
|
|
|
TensorDataPtr MakeTensorData(TypeId data_type, const std::vector<int> &shape, Args... args) { |
|
|
|
switch (data_type) { |
|
|
|
case kNumberTypeBool: |
|
|
|
// std::vector<bool> is a specialization of std::vector, |
|
|
|
// it may use single bit instead of sizeof(bool) bytes, |
|
|
|
// so we use std::vector<Bool> for bool tensors. |
|
|
|
return std::make_shared<TensorDataImpl<Bool>>(shape, args...); |
|
|
|
case kNumberTypeUInt8: |
|
|
|
return std::make_shared<TensorDataImpl<uint8_t>>(shape, args...); |
|
|
|
case kNumberTypeInt8: |
|
|
|
|