|
|
|
@@ -107,6 +107,7 @@ bool Tensor::operator==(const Tensor &tensor) { |
|
|
|
} |
|
|
|
|
|
|
|
int32_t Tensor::Batch() const { |
|
|
|
// Only 2D or 4D tensors have valid batch. |
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) { |
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); |
|
|
|
return RET_ERROR; |
|
|
|
@@ -123,8 +124,14 @@ int32_t Tensor::Batch() const { |
|
|
|
return this->shape_[0]; |
|
|
|
case mindspore::HWCK: |
|
|
|
case mindspore::CHWK: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[3]; |
|
|
|
case mindspore::HWKC: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[2]; |
|
|
|
case mindspore::CKHW: |
|
|
|
return this->shape_[1]; |
|
|
|
@@ -135,6 +142,7 @@ int32_t Tensor::Batch() const { |
|
|
|
} |
|
|
|
|
|
|
|
int32_t Tensor::Channel() const { |
|
|
|
// Only 2D or 4D tensors have valid channel. |
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) { |
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); |
|
|
|
return RET_ERROR; |
|
|
|
@@ -146,12 +154,18 @@ int32_t Tensor::Channel() const { |
|
|
|
case mindspore::NC4: |
|
|
|
return this->shape_[1]; |
|
|
|
case mindspore::HWCK: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[2]; |
|
|
|
case mindspore::HWKC: |
|
|
|
case mindspore::NHWC: |
|
|
|
case mindspore::NHWC4: |
|
|
|
case mindspore::NC4HW4: |
|
|
|
case mindspore::KHWC: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[3]; |
|
|
|
case mindspore::CKHW: |
|
|
|
case mindspore::CHWK: |
|
|
|
@@ -162,6 +176,7 @@ int32_t Tensor::Channel() const { |
|
|
|
} |
|
|
|
|
|
|
|
int32_t Tensor::Height() const { |
|
|
|
// Only 2D or 4D tensors have valid height. |
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) { |
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); |
|
|
|
return RET_ERROR; |
|
|
|
@@ -170,6 +185,9 @@ int32_t Tensor::Height() const { |
|
|
|
case mindspore::NCHW: |
|
|
|
case mindspore::KCHW: |
|
|
|
case mindspore::CKHW: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[2]; |
|
|
|
case mindspore::NHWC: |
|
|
|
case mindspore::NHWC4: |
|
|
|
@@ -189,6 +207,7 @@ int32_t Tensor::Height() const { |
|
|
|
} |
|
|
|
|
|
|
|
int32_t Tensor::Width() const { |
|
|
|
// Only 2D or 4D tensors have valid width. |
|
|
|
if (this->shape_.size() != 4 && this->shape_.size() != 2) { |
|
|
|
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size(); |
|
|
|
return RET_ERROR; |
|
|
|
@@ -197,12 +216,18 @@ int32_t Tensor::Width() const { |
|
|
|
case mindspore::NCHW: |
|
|
|
case mindspore::KCHW: |
|
|
|
case mindspore::CKHW: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[3]; |
|
|
|
case mindspore::KHWC: |
|
|
|
case mindspore::NHWC: |
|
|
|
case mindspore::NHWC4: |
|
|
|
case mindspore::NC4HW4: |
|
|
|
case mindspore::CHWK: |
|
|
|
if (this->shape_.size() != 4) { |
|
|
|
return RET_ERROR; |
|
|
|
} |
|
|
|
return this->shape_[2]; |
|
|
|
case mindspore::HWCK: |
|
|
|
case mindspore::HWKC: |
|
|
|
|