diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 528e4092..4be8930f 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -312,7 +312,7 @@ void ChannelImpl::dispatch_default_cpu( HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn); // use `put` for consistency auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false)); - mgb_assert(info->desc.layout.ndim != 0); + mgb_assert(info->shape_valid()); output_infos.push_back(info); outputs->push_back(reinterpret_cast<Handle>(info)); } @@ -406,7 +406,7 @@ SmallVector<Handle> ChannelImpl::apply_op( MGB_LOCK_GUARD(m_spin); mgb_assert(check_available(), "Channel already closed"); auto* input = reinterpret_cast<TensorInfo*>(inputs[0]); - if (op->same_type<GetVarShape>() && input->desc.layout.ndim) { + if (op->same_type<GetVarShape>() && input->shape_valid()) { size_t ndim = input->desc.layout.ndim; auto& gvs = op->cast_final_safe<GetVarShape>(); if (gvs.axis == MEGDNN_MAX_NDIM) { @@ -477,11 +477,11 @@ TensorShape ChannelImpl::get_shape(Handle handle) { m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p", handle); auto info = reinterpret_cast<TensorInfo*>(handle); - if (info->desc.layout.ndim != 0) { + if (info->shape_valid()) { return info->desc.layout; } TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout(); - mgb_assert(ret.ndim != 0); + mgb_assert(ret.ndim > 0); return ret; } @@ -694,12 +694,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(), ptr->raw_ptr_not_for_readwrite()); // update tensor desc for static infer - if (dest->desc.layout.ndim) { - mgb_assert( - dest->desc.layout.eq_shape(ptr->layout()), - "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(), - ptr->layout().to_string().c_str()); - } + dest->update_layout(ptr->layout()); // in order to avoid performance impact, // memory forwarding is
disabled when DTR is enabled if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) { diff --git a/imperative/src/impl/interpreter/tensor_info.h b/imperative/src/impl/interpreter/tensor_info.h index f8719436..1da086ee 100644 --- a/imperative/src/impl/interpreter/tensor_info.h +++ b/imperative/src/impl/interpreter/tensor_info.h @@ -48,6 +48,7 @@ struct TensorInfo { // Lock interpreter when visiting `ptr`. TensorPtr ptr; LogicalTensorDesc desc; + Spinlock lock; double compute_time; size_t memory; @@ -158,6 +159,26 @@ struct TensorInfo { // UINT_MAX as a magic default value size_t cand_index = UINT_MAX; + + bool shape_valid() { + MGB_LOCK_GUARD(lock); + return desc.layout.ndim; + } + + void update_layout(const TensorLayout& layout) { + MGB_LOCK_GUARD(lock); + mgb_assert(desc.layout.dtype == layout.dtype, "dtype mismatch"); + mgb_assert(desc.layout.format == layout.format, "format mismatch"); + if (desc.layout.ndim) { + mgb_assert( + desc.layout.eq_shape(layout), "shape infer error, %s vs %s", + desc.layout.to_string().c_str(), layout.to_string().c_str()); + // ignore strides + } else { + static_cast<TensorShape&>(desc.layout) = layout; + desc.layout.init_contiguous_stride(); + } + } }; } // namespace interpreter::intl