diff --git a/imperative/src/include/megbrain/imperative/utils/value_shape.h b/imperative/src/include/megbrain/imperative/utils/value_shape.h index a9b60bf0..c52584c2 100644 --- a/imperative/src/include/megbrain/imperative/utils/value_shape.h +++ b/imperative/src/include/megbrain/imperative/utils/value_shape.h @@ -23,8 +23,7 @@ namespace mgb::imperative { * */ struct ValueShape { - size_t shape[TensorShape::MAX_NDIM]; - int ndim = 0; + size_t shape[TensorShape::MAX_NDIM], ndim = 0; ValueShape() = default; ValueShape(std::initializer_list dims) { @@ -70,19 +69,14 @@ struct ValueShape { return buffer; } - static ValueShape from(TensorShape tensor_shape) { + static const ValueShape& from(const TensorShape& tensor_shape) { mgb_assert(tensor_shape.ndim); - return Span{tensor_shape.shape, tensor_shape.ndim}; + return reinterpret_cast(tensor_shape); } - TensorShape as_tensor_shape() const { + const TensorShape& as_tensor_shape() const { mgb_assert(ndim != 0); - TensorShape ret; - for (size_t i = 0; i < ndim; ++i) { - ret.shape[i] = shape[i]; - } - ret.ndim = ndim; - return ret; + return reinterpret_cast(*this); } bool operator==(const ValueShape& rhs) const {