@@ -312,7 +312,7 @@ void ChannelImpl::dispatch_default_cpu(
         HostTensorND::make_proxy(tensornd).proxy_to_comp_node(output_cn);
         // use `put` for consistency
         auto info = reinterpret_cast<TensorInfo*>(put_impl(host_tensornd, false));
-        mgb_assert(info->desc.layout.ndim != 0);
+        mgb_assert(info->shape_valid());
         output_infos.push_back(info);
         outputs->push_back(reinterpret_cast<Handle>(info));
     }
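
Across these hunks, the raw desc.layout.ndim != 0 test is replaced by a
shape_valid() predicate on TensorInfo. The helper's definition is not part of
this patch; a minimal sketch of what it presumably looks like, assuming
ndim == 0 stays the sentinel for a not-yet-known shape:

    // Hypothetical sketch -- the real definition lives outside this patch.
    // It only names the intent of the old `ndim != 0` checks; behavior is
    // unchanged.
    struct TensorInfo {
        LogicalTensorDesc desc;
        // ...
        bool shape_valid() const {
            // ndim == 0 means the shape has not been inferred or produced yet
            return desc.layout.ndim != 0;
        }
    };
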
@@ -406,7 +406,7 @@ SmallVector<Handle> ChannelImpl::apply_op(
     MGB_LOCK_GUARD(m_spin);
     mgb_assert(check_available(), "Channel already closed");
     auto* input = reinterpret_cast<TensorInfo*>(inputs[0]);
-    if (op->same_type<GetVarShape>() && input->desc.layout.ndim) {
+    if (op->same_type<GetVarShape>() && input->shape_valid()) {
         size_t ndim = input->desc.layout.ndim;
         auto& gvs = op->cast_final_safe<GetVarShape>();
         if (gvs.axis == MEGDNN_MAX_NDIM) {
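
This guard gates a fast path: when the input shape is already known,
GetVarShape can be answered synchronously from cached metadata instead of
being queued to the worker. The hunk cuts off at the gvs.axis check; a hedged
sketch of how the fast path plausibly continues (assumed, not shown in this
patch):

    // Assumed continuation: materialize the cached shape as a small Int32
    // host tensor and return it without a round trip to the worker thread.
    HostTensorND shape_tensor{input->desc.comp_node, {ndim}, dtype::Int32()};
    auto* data = shape_tensor.ptr<dt_int32>();
    for (size_t i = 0; i < ndim; ++i) {
        data[i] = input->desc.layout[i];
    }
    return {reinterpret_cast<Handle>(put_impl(shape_tensor, false))};
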
@@ -477,11 +477,11 @@ TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(
             m_valid_handle.find(handle) != m_valid_handle.end(), "invalid handle: %p",
             handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
-    if (info->desc.layout.ndim != 0) {
+    if (info->shape_valid()) {
         return info->desc.layout;
     }
     TensorShape ret = wait_tensor(info, TensorProp::Shape)->layout();
-    mgb_assert(ret.ndim != 0);
+    mgb_assert(ret.ndim > 0);
     return ret;
 }
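
The second change in this hunk is cosmetic: ret.ndim is unsigned, so
"ret.ndim > 0" and "ret.ndim != 0" are equivalent; the former simply reads
better as a non-emptiness check. For callers the contract is unchanged, as in
this usage sketch (the handle setup is hypothetical):

    // get_shape() either serves the cached layout immediately or blocks in
    // wait_tensor() until the worker has produced the shape.
    Handle h = channel->put(host_value, /*no_cache=*/false);
    TensorShape shape = channel->get_shape(h);  // may block on the worker
    mgb_assert(shape.ndim > 0);                 // never an empty shape
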
@@ -694,12 +694,7 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
             TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
             ptr->raw_ptr_not_for_readwrite());
     // update tensor desc for static infer
-    if (dest->desc.layout.ndim) {
-        mgb_assert(
-                dest->desc.layout.eq_shape(ptr->layout()),
-                "shape infer error, %s vs %s", dest->desc.layout.to_string().c_str(),
-                ptr->layout().to_string().c_str());
-    }
+    dest->update_layout(ptr->layout());
     // in order to avoid performance impact,
     // memory forwarding is disabled when DTR is enabled
     if (state.options.enable_dtr_auto_drop || state.options.disable_memory_forwarding) {
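
The inline consistency check does not disappear; it is presumably folded into
the new update_layout() helper so that every layout update flows through one
place. A sketch of that helper under the assumption that it keeps the removed
assertion verbatim (its definition is not shown in this patch):

    // Hypothetical sketch: re-check shape consistency against any previously
    // inferred layout, then adopt the produced layout as the descriptor.
    void TensorInfo::update_layout(const TensorLayout& layout) {
        if (desc.layout.ndim) {
            mgb_assert(
                    desc.layout.eq_shape(layout), "shape infer error, %s vs %s",
                    desc.layout.to_string().c_str(),
                    layout.to_string().c_str());
        }
        desc.layout = layout;
    }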