GitOrigin-RevId: 2663504470
@@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() {
     return inst_;
 }
 
-void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
+Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     auto info = alloc();
     info->desc.layout = value.layout();
     info->desc.comp_node = value.comp_node();
@@ -39,7 +39,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) {
     return info;
 }
 
-void* ChannelImpl::put(const DeviceTensorND& data) {
+Handle ChannelImpl::put(const DeviceTensorND& data) {
    auto info = alloc();
    info->desc.layout = data.layout();
    info->desc.comp_node = data.comp_node();
@@ -48,12 +48,12 @@ void* ChannelImpl::put(const DeviceTensorND& data) {
     return info;
 }
 
-void ChannelImpl::del(void* handle) {
+void ChannelImpl::del(Handle handle) {
     mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
     m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)});
 }
 
-void ChannelImpl::swap_in(void* handle) {
+void ChannelImpl::swap_in(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -61,7 +61,7 @@ void ChannelImpl::swap_in(void* handle) {
     }
 }
 
-void ChannelImpl::swap_out(void* handle) {
+void ChannelImpl::swap_out(Handle handle) {
     if (m_enable_evict & SWAP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -69,7 +69,7 @@ void ChannelImpl::swap_out(void* handle) {
     }
 }
 
-void ChannelImpl::drop(void* handle) {
+void ChannelImpl::drop(Handle handle) {
     if (m_enable_evict & DROP) {
         mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
                 "invalid handle: %p", handle);
@@ -77,45 +77,91 @@ void ChannelImpl::drop(void* handle) {
     }
 }
 
-SmallVector<void*> ChannelImpl::apply_op(
+void ChannelImpl::dispatch_default_cpu(
         std::shared_ptr<OpDef> op,
-        const SmallVector<void*>& inputs) {
-    for (auto i : inputs) {
-        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
-                "invalid handle: %p", i);
-    }
-    SmallVector<TensorInfo*> input_infos;
-    input_infos.reserve(inputs.size());
-    SmallVector<LogicalTensorDesc> input_descs;
-    input_descs.reserve(inputs.size());
+        const SmallVector<TensorInfo*>& input_infos,
+        const SmallVector<LogicalTensorDesc>& input_descs,
+        SmallVector<Handle>* outputs) {
+    auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
+    SmallVector<DeviceTensorND> input_tensornds;
+    input_tensornds.reserve(input_descs.size());
+    CompNode output_cn;
     {
         MGB_LOCK_GUARD(m_mutex);
-        for (auto i : inputs) {
-            auto info = reinterpret_cast<TensorInfo*>(i);
-            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
-            input_infos.push_back(info);
-            input_descs.push_back(info->desc);
+        for (auto&& info : input_infos) {
+            mgb_assert(info->ptr, "invalid tensor ptr!");
+            if (!output_cn.valid()) {
+                output_cn = info->ptr->comp_node();
+            } else {
+                mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node");
+            }
+            mgb_assert(info->ptr->try_get_value(), "no valid host value");
+            input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu());
+        }
+    }
+
+    outputs->reserve(output_descs.size());
+    SmallVector<DeviceTensorND> output_tensornds;
+    output_tensornds.reserve(output_descs.size());
+    for (auto&& desc : output_descs) {
+        // TODO: may conflict with condtake, which need alloc inside
+        mgb_assert(!desc.layout.is_empty());
+        // use HostTensorND alloc_host for cuda pinned memory
+        output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu());
+    }
+
+    OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds);
+
+    SmallVector<TensorInfo*> output_infos;
+    output_infos.reserve(output_descs.size());
+    for (auto&& tensornd : output_tensornds) {
+        // tensornd -> host_tensornd
+        HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd)
+            .proxy_to_comp_node(output_cn);
+        // tensornd -> desc
+        LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd};
+        // tensornd -> tensor
+        auto info = alloc();
+        info->desc = desc;
+        m_valid_handle.insert(info);
+        output_infos.push_back(info);
+        info->ptr = Tensor::make(host_tensornd, true); // host_only=true
+        info->value_fetched = true;
+        outputs->push_back(info);
+    }
+
+    if (m_enable_evict & DROP) {
+        for (auto out : output_infos) {
+            out->path.op = op;
+            for (auto out_ : output_infos) {
+                out->path.outputs.push_back(m_st.at(out_));
+            }
+            for (auto inp : input_infos) {
+                out->path.inputs.push_back(m_st.at(inp));
+                inp->path.dep_outputs.push_back(m_st.at(out));
+            }
         }
     }
+}
+
+void ChannelImpl::dispatch_kernel(
+        std::shared_ptr<OpDef> op,
+        const SmallVector<TensorInfo*>& input_infos,
+        const SmallVector<LogicalTensorDesc>& input_descs,
+        SmallVector<Handle>* outputs) {
     auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs);
     ApplyOp cmd{std::move(op)};
     cmd.inputs = std::move(input_infos);
     cmd.outputs.reserve(output_descs.size());
-    SmallVector<void*> outputs;
-    // FIXME: remove this check when op check is correct
-    bool validated_bkp = true;
-    for (size_t i = 0;i < output_descs.size();i ++) {
-        auto&& desc = output_descs[i];
-        if (desc.layout.ndim == 0) {
-            validated_bkp = false;
-        }
+    outputs->reserve(output_descs.size());
+    for (auto&& desc : output_descs) {
         auto info = alloc();
         info->desc = desc;
         m_valid_handle.insert(info);
         cmd.outputs.push_back(info);
-        outputs.push_back(info);
+        outputs->push_back(info);
     }
     if (m_enable_evict & DROP) {
         for (auto out : cmd.outputs) {
@@ -130,20 +176,55 @@ SmallVector<void*> ChannelImpl::apply_op(
         }
     }
     m_buffer.enqueue(std::move(cmd));
-    if (!(validated && validated_bkp) && m_async_level == 1) {
+    if (!validated && m_async_level == 1) {
        sync();
    } else if (m_async_level == 0) {
        sync();
        // check device error
-        for (auto&& oup : outputs) {
+        for (auto&& oup : *outputs) {
            auto info = reinterpret_cast<TensorInfo*>(oup);
            info->ptr->comp_node().sync();
        }
    }
+}
+
+SmallVector<Handle> ChannelImpl::apply_op(
+        std::shared_ptr<OpDef> op,
+        const SmallVector<Handle>& inputs) {
+    for (auto i : inputs) {
+        mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(),
+                "invalid handle: %p", i);
+    }
+    SmallVector<TensorInfo*> input_infos;
+    input_infos.reserve(inputs.size());
+    SmallVector<LogicalTensorDesc> input_descs;
+    input_descs.reserve(inputs.size());
+    {
+        MGB_LOCK_GUARD(m_mutex);
+        for (auto i : inputs) {
+            auto info = reinterpret_cast<TensorInfo*>(i);
+            mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!");
+            input_infos.push_back(info);
+            input_descs.push_back(info->desc);
+        }
+    }
+
+    SmallVector<Handle> outputs;
+    switch (OpDef::decide_dispatch_mode(*op, input_descs)) {
+        case DEFAULT_CPU: {
+            dispatch_default_cpu(op, input_infos, input_descs, &outputs);
+            break;
+        }
+        case KERNEL: {
+            dispatch_kernel(op, input_infos, input_descs, &outputs);
+            break;
+        }
+    }
+    mgb_assert(outputs.size() > 0, "Invalid dispatch mode!");
     return outputs;
 }
 
-HostTensorND ChannelImpl::get_value(void* handle) {
+HostTensorND ChannelImpl::get_value(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
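With this hunk, apply_op itself only validates handles and snapshots the input descriptors; the real work is routed by OpDef::decide_dispatch_mode to either dispatch_default_cpu (compute immediately on the host) or dispatch_kernel (enqueue a command for the worker). A reduced, self-contained sketch of that control flow follows; the enum values mirror the DispatchMode added in this patch, while Desc, the element threshold, and the print statements are illustrative.

    #include <cstdio>
    #include <vector>

    enum DispatchMode { DEFAULT_CPU = 0, KERNEL = 1 };

    // Illustrative input description: element count plus whether a host
    // value is already available without a device sync.
    struct Desc {
        size_t nr_elems;
        bool host_value_ready;
    };

    DispatchMode decide_dispatch_mode(const std::vector<Desc>& inputs) {
        for (auto&& inp : inputs) {
            // Only tiny, host-available inputs (shape-like values) are worth
            // computing eagerly; 7 here is modeled on the MAX_NDIM threshold
            // used by the Elemwise trait in this patch.
            if (!inp.host_value_ready || inp.nr_elems > 7) {
                return KERNEL;
            }
        }
        return DEFAULT_CPU;
    }

    void dispatch_default_cpu(const std::vector<Desc>&) {
        std::puts("computed eagerly on the host");
    }

    void dispatch_kernel(const std::vector<Desc>&) {
        std::puts("command enqueued for the worker / device");
    }

    void apply_op(const std::vector<Desc>& inputs) {
        switch (decide_dispatch_mode(inputs)) {
            case DEFAULT_CPU:
                dispatch_default_cpu(inputs);
                break;
            case KERNEL:
                dispatch_kernel(inputs);
                break;
        }
    }

    int main() {
        apply_op({{4, true}});         // shape-like tensor -> DEFAULT_CPU
        apply_op({{1 << 20, false}});  // large device tensor -> KERNEL
        return 0;
    }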
@@ -163,7 +244,7 @@ HostTensorND ChannelImpl::get_value(void* handle) {
     return info->ptr->get_value();
 }
 
-TensorShape ChannelImpl::get_shape(void* handle) {
+TensorShape ChannelImpl::get_shape(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -184,7 +265,7 @@ TensorShape ChannelImpl::get_shape(void* handle) {
     return ret;
 }
 
-DType ChannelImpl::get_dtype(void* handle) {
+DType ChannelImpl::get_dtype(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -193,7 +274,7 @@ DType ChannelImpl::get_dtype(void* handle) {
     return ret;
 }
 
-CompNode ChannelImpl::get_device(void* handle) {
+CompNode ChannelImpl::get_device(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -202,7 +283,7 @@ CompNode ChannelImpl::get_device(void* handle) {
     return ret;
 }
 
-DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) {
+DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) {
     mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(),
             "invalid handle: %p", handle);
     auto info = reinterpret_cast<TensorInfo*>(handle);
@@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() {
 }
 
 void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) {
-    if (notice) {
-        MGB_LOCK_GUARD(m_mutex);
-        dest->value_fetched = ptr->value_fetched();
-        // update tensor desc for static infer
-        // if (dest->desc.layout.ndim) {
-        //     mgb_assert(dest->desc.layout.eq_shape(ptr->layout()));
-        // }
-        dest->desc.layout = ptr->layout();
-        dest->desc.comp_node = ptr->comp_node();
-        dest->ptr = std::move(ptr);
-        if (m_waitee == dest) {
-            m_cv.notify_all();
-        }
-    } else {
-        dest->value_fetched = ptr->value_fetched();
-        // update tensor desc for static infer
-        dest->desc.layout = ptr->layout();
-        dest->desc.comp_node = ptr->comp_node();
-        dest->ptr = std::move(ptr);
+    auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
+                       : std::unique_lock<std::mutex>();
+    dest->value_fetched = ptr->value_fetched();
+    // update tensor desc for static infer
+    dest->desc.layout = ptr->layout();
+    dest->desc.comp_node = ptr->comp_node();
+    dest->ptr = std::move(ptr);
+    if (notice && m_waitee == dest) {
+        m_cv.notify_all();
     }
 }
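The rewritten produce_tensor collapses the duplicated notice/no-notice branches by building a std::unique_lock that either owns m_mutex or is empty, then notifying only when notice is set. The same idiom in a self-contained sketch (the guarded integer stands in for the TensorInfo fields):

    #include <condition_variable>
    #include <mutex>

    std::mutex m_mutex;
    std::condition_variable m_cv;
    int m_value = 0;

    void produce(int v, bool notice) {
        // Owns the mutex when notice is true, is an empty (lock-free) object
        // otherwise; either way it is released automatically at scope exit.
        auto lock = notice ? std::unique_lock<std::mutex>(m_mutex)
                           : std::unique_lock<std::mutex>();
        m_value = v;
        if (notice) {
            m_cv.notify_all();  // wake any thread waiting on the new value
        }
    }

    int main() {
        produce(1, false);  // fast path: no locking, no notification
        produce(2, true);   // publish under the mutex and notify waiters
        return 0;
    }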
@@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) {
     dest->evict_type = SWAP;
     dest->value_fetched = false;
     // TODO: swap in parallel
-    dest->h_value.copy_from(dest->ptr->dev_tensor()).sync();
+    dest->h_value = dest->ptr->get_value();
     dest->ptr.reset();
 }
@@ -198,6 +198,17 @@ private:
     void do_drop(TensorInfo* dest);
     void regenerate(TensorInfo* dest, bool must_drop);
 
+    void dispatch_default_cpu(
+            std::shared_ptr<OpDef> op,
+            const SmallVector<TensorInfo*>& input_infos,
+            const SmallVector<LogicalTensorDesc>& input_descs,
+            SmallVector<Handle>* outputs);
+    void dispatch_kernel(
+            std::shared_ptr<OpDef> op,
+            const SmallVector<TensorInfo*>& input_infos,
+            const SmallVector<LogicalTensorDesc>& input_descs,
+            SmallVector<Handle>* outputs);
+
     std::mutex m_mutex;
     std::condition_variable m_cv;
     MemPool<TensorInfo> m_pool;
@@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node(
     return trait->make_from_op_node(node);
 }
 
+DispatchMode OpDef::decide_dispatch_mode(
+    const OpDef& def,
+    const SmallVector<LogicalTensorDesc>& inputs) {
+    return def.trait()->decide_dispatch_mode(def, inputs);
+}
+
 SmallVector<TensorPtr> OpDef::apply_on_physical_tensor(
     const OpDef& def,
     SmallVector<TensorPtr> inputs) {
     return def.trait()->apply_on_physical_tensor(def, std::move(inputs));
 }
+
+void OpDef::apply_on_device_tensornd(
+    const OpDef& def,
+    const SmallVector<DeviceTensorND>& inputs,
+    SmallVector<DeviceTensorND>* outputs) {
+    def.trait()->apply_on_device_tensornd(def, inputs, outputs);
+    return;
+}
+
 VarNodeArray OpDef::apply_on_var_node(
     const OpDef& def,
     const VarNodeArray& inputs) {
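Both new OpDef entry points follow the existing one-liner pattern of delegating to the per-op trait, and apply_on_device_tensornd additionally uses an out-parameter: the caller constructs the output DeviceTensorNDs and the op fills them in place, which is what lets dispatch_default_cpu back the outputs with host (pinned) storage. A reduced sketch of that caller-allocated-output convention, using plain std::vector as a stand-in for DeviceTensorND:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Stand-in convention mirroring apply_on_device_tensornd: inputs are
    // read-only, outputs are provided by the caller and written in place.
    void apply_add(const std::vector<std::vector<float>>& inputs,
                   std::vector<std::vector<float>>* outputs) {
        assert(inputs.size() == 2 && outputs->size() == 1);
        auto& out = (*outputs)[0];
        out.resize(inputs[0].size());
        for (size_t i = 0; i < out.size(); ++i) {
            out[i] = inputs[0][i] + inputs[1][i];
        }
    }

    int main() {
        std::vector<std::vector<float>> outputs(1);  // caller-provided slot
        apply_add({{1.f, 2.f}, {3.f, 4.f}}, &outputs);
        assert(outputs[0][1] == 6.f);
        return 0;
    }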
@@ -9,12 +9,16 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
+#include <exception>
 #include <sstream>
+#include <stdexcept>
 
+#include "megbrain/imperative/op_def.h"
 #include "megbrain/imperative/ops/opr_attr.h"
+#include "megbrain/imperative/proxy_graph_detail.h"
+#include "megbrain/tensor.h"
 
 #include "./op_trait.h"
-#include "megbrain/imperative/proxy_graph_detail.h"
 
 namespace mgb {
 namespace imperative {
@@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){
     }
 }
 
+DispatchMode fallback_decide_dispatch_mode(
+    const OpDef& def,
+    const SmallVector<LogicalTensorDesc>& inputs) {
+    return KERNEL;
+}
+
 OpTraitRegistry& OpTraitRegistry::fallback() {
     if (trait->apply_on_var_node) {
         // fallback to proxy graph impl
@@ -78,6 +88,9 @@ OpTraitRegistry& OpTraitRegistry::fallback() {
                 proxy_graph_detail::make_backward_graph;
         }
     }
+    if (!trait->decide_dispatch_mode) {
+        trait->decide_dispatch_mode = fallback_decide_dispatch_mode;
+    }
     return *this;
 }
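The fallback() additions keep the change backward compatible: any op trait that never registers its own decide_dispatch_mode is filled in with fallback_decide_dispatch_mode and therefore keeps going through the KERNEL path unchanged. A tiny sketch of that slot-filling pattern (MiniTrait is illustrative):

    #include <functional>

    enum DispatchMode { DEFAULT_CPU = 0, KERNEL = 1 };

    // One optional slot per overridable behaviour, empty until registered.
    struct MiniTrait {
        std::function<DispatchMode()> decide_dispatch_mode;
    };

    DispatchMode fallback_decide_dispatch_mode() {
        return KERNEL;  // conservative default: dispatch as a device kernel
    }

    MiniTrait& fill_fallbacks(MiniTrait& trait) {
        if (!trait.decide_dispatch_mode) {
            trait.decide_dispatch_mode = fallback_decide_dispatch_mode;
        }
        return trait;
    }

    int main() {
        MiniTrait t;  // an op that predates the new slot
        return fill_fallbacks(t).decide_dispatch_mode() == KERNEL ? 0 : 1;
    }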
@@ -60,8 +60,12 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type {
 using OpDefMaker = detail::OpMeth<
         decltype(OpDef::make_from_op_node)>;
+using DecideDispatchMode = detail::OpMeth<
+        decltype(OpDef::decide_dispatch_mode)>;
 using ApplyOnPhysicalTensor = detail::OpMeth<
         decltype(OpDef::apply_on_physical_tensor)>;
+using ApplyOnDeviceTensorND = detail::OpMeth<
+        decltype(OpDef::apply_on_device_tensornd)>;
 using ApplyOnVarNode = detail::OpMeth<
         decltype(OpDef::apply_on_var_node)>;
 using InferOutputAttrsFallible = detail::OpMeth<
@@ -74,7 +78,9 @@ using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>;
 struct OpTrait {
     const char* name;
     OpDefMaker make_from_op_node;
+    DecideDispatchMode decide_dispatch_mode;
     ApplyOnPhysicalTensor apply_on_physical_tensor;
+    ApplyOnDeviceTensorND apply_on_device_tensornd;
     ApplyOnVarNode apply_on_var_node;
     InferOutputAttrsFallible infer_output_attrs_fallible;
     GradMaker make_backward_graph;
@@ -88,7 +94,9 @@ struct OpTrait {
 #define FOR_EACH_OP_METH(cb) \
     cb(make_from_op_node) \
+    cb(decide_dispatch_mode) \
     cb(apply_on_physical_tensor) \
+    cb(apply_on_device_tensornd) \
     cb(apply_on_var_node) \
     cb(infer_output_attrs_fallible) \
     cb(make_backward_graph) \
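Each new capability touches three places that must stay in sync: the static entry point on OpDef, the member in OpTrait, and the FOR_EACH_OP_METH X-macro that generates the registry plumbing. A tiny self-contained example of the same X-macro technique (slot names reused from this patch, the rest is illustrative):

    #include <cstdio>

    // List the slots once; expand the list with different `cb` definitions
    // wherever the full list is needed, so additions happen in one place.
    #define FOR_EACH_SLOT(cb) \
        cb(decide_dispatch_mode) \
        cb(apply_on_device_tensornd)

    struct Slots {
    #define DECLARE_SLOT(name) bool name = false;
        FOR_EACH_SLOT(DECLARE_SLOT)
    #undef DECLARE_SLOT
    };

    int main() {
        Slots s;
        s.decide_dispatch_mode = true;
    #define PRINT_SLOT(name) std::printf(#name " registered: %d\n", (int)s.name);
        FOR_EACH_SLOT(PRINT_SLOT)
    #undef PRINT_SLOT
        return 0;
    }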
@@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
     return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true};
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
+DispatchMode decide_dispatch_mode(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    bool host_computable = true;
+    constexpr int size_threshhold = TensorShape::MAX_NDIM;
+    for (auto&& inp : inputs) {
+        if (inp.value.empty() || inp.value.layout().ndim == 0
+            || inp.value.layout().total_nr_elems() > size_threshhold) {
+            host_computable = false;
+            break;
+        }
+    }
+    return host_computable ? DEFAULT_CPU : KERNEL;
+}
+
+void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<Elemwise>();
     auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode);
     mgb_assert(inputs.size() == trait.arity,
               "%s expects %u inputs; got %zu actually", trait.name,
               trait.arity, inputs.size());
+    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node());
+    opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
-    DeviceTensorND out;
-    SmallVector<DeviceTensorND> dt_inputs(inputs.size());
+    SmallVector<DeviceTensorND> inp_tensornds(inputs.size());
     for (unsigned i = 0; i < inputs.size(); ++i){
-        dt_inputs[i] = inputs[i]->dev_tensor();
+        inp_tensornds[i] = inputs[i]->dev_tensor();
     }
-    auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node());
-    opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr);
-    return {Tensor::make(out)};
+    SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}};
+    apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds);
+    return {Tensor::make(oup_tensornds[0])};
 }
 
 MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{
@@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_
 OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise)
     .make_from_op_node(make_from_op_node)
+    .decide_dispatch_mode(decide_dispatch_mode)
     .apply_on_var_node(apply_on_var_node)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
 
@@ -15,8 +15,8 @@
 #include "../op_trait.h"
 
 namespace mgb::imperative {
-namespace {
+namespace get_var_shape {
 
 cg::OperatorNodeBase* apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
@@ -24,17 +24,38 @@ cg::OperatorNodeBase* apply_on_var_node(
     return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr();
 }
 
-SmallVector<TensorPtr> apply_on_physical_tensor(
+DispatchMode decide_dispatch_mode(
         const OpDef& def,
-        const SmallVector<TensorPtr>& inputs) {
+        const SmallVector<LogicalTensorDesc>& inputs) {
+    bool host_computable = true;
+    for (auto&& inp : inputs) {
+        // FIXME(czh): remove value check after proxy graph's
+        // apply_on_device_tensornd is supported and output Tensor
+        // is made before add_task.
+        // then if layout is valid, ptr->layout must be ready
+        if (inp.value.empty() || inp.value.layout().ndim == 0) {
+            host_computable = false;
+            break;
+        }
+    }
+    return host_computable ? DEFAULT_CPU : KERNEL;
+}
+
+void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs) {
     auto&& op_def = def.cast_final_safe<GetVarShape>();
     mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size());
     auto&& inp = inputs[0];
-    auto&& shp = inp->layout();
+    auto&& shp = inp.layout();
     mgb_assert(shp.ndim != 0, "input shape invalid");
+    mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(),
+        "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs.");
+
     HostTensorND hv;
-    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
-        hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32());
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
+        hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32());
         auto* ptr = hv.ptr<dt_int32>();
         for (size_t i = 0; i < shp.ndim; ++i) {
             ptr[i] = shp.shape[i];
@@ -45,11 +66,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
             axis += shp.ndim;
         }
         mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim);
-        hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32());
+        hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32());
         auto* ptr = hv.ptr<dt_int32>();
         ptr[0] = shp.shape[axis];
     }
-    return {Tensor::make(std::move(hv))};
+    (*outputs)[0] = DeviceTensorND::make_proxy(hv);
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def,
+        const SmallVector<TensorPtr>& inputs) {
+    SmallVector<DeviceTensorND> input_tensornds;
+    input_tensornds.reserve(inputs.size());
+    for (auto&& inp : inputs) {
+        input_tensornds.push_back(inp->dev_tensor());
+    }
+    SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}};
+
+    apply_on_device_tensornd(def, input_tensornds, &output_tensornds);
+
+    // restore to input comp_node
+    HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0])
+        .proxy_to_comp_node(inputs[0]->comp_node());
+    return {Tensor::make(std::move(host_tensornd))};
 }
 
 std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
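GetVarShape's new split computes the shape on the default CPU node and only at the end re-tags the result onto the input's comp node via make_proxy / proxy_to_comp_node, since a shape is pure metadata and needs no device round trip. A loose, self-contained sketch of that compute-on-host-then-retag idea (MiniTensor and the string comp-node names are illustrative; in the real API the proxy shares storage rather than copying):

    #include <cassert>
    #include <string>
    #include <vector>

    // Illustrative stand-in: data plus the name of the comp node it is tagged with.
    struct MiniTensor {
        std::string comp_node;
        std::vector<int> data;
    };

    // Shape is metadata, so it can always be materialized on the host.
    MiniTensor shape_on_default_cpu(const std::vector<size_t>& layout) {
        MiniTensor out{"host", {}};
        for (size_t dim : layout) {
            out.data.push_back(static_cast<int>(dim));
        }
        return out;
    }

    // Re-tag the metadata only; the real make_proxy/proxy_to_comp_node pair
    // avoids a copy, which this toy version does not model.
    MiniTensor proxy_to_comp_node(MiniTensor t, const std::string& cn) {
        t.comp_node = cn;
        return t;
    }

    int main() {
        auto shape = shape_on_default_cpu({2, 3, 4});
        auto result = proxy_to_comp_node(std::move(shape), "gpu0");
        assert(result.data.size() == 3 && result.comp_node == "gpu0");
        return 0;
    }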
@@ -62,7 +101,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false};
     }
     DeviceTensorND value;
-    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){
+    if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) {
         value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32());
         auto* ptr = value.ptr<dt_int32>();
         for (size_t i = 0; i < desc.layout.ndim; ++i) {
@@ -88,11 +127,15 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) {
 OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape)
     .make_from_op_node(make_from_op_node)
+    .decide_dispatch_mode(decide_dispatch_mode)
     .infer_output_attrs_fallible(infer_output_attrs_fallible)
     .apply_on_var_node(apply_on_var_node)
+    .apply_on_device_tensornd(apply_on_device_tensornd)
     .apply_on_physical_tensor(apply_on_physical_tensor)
     .fallback();
+} // get_var_shape
+
+namespace param_pack {
 
 TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) {
     TensorShapeArray ret;
     for (auto&& i:shapes) {
@@ -156,6 +199,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node(
 OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat)
     .apply_on_var_node(param_pack_concat_apply_on_var_node)
     .fallback();
-} // namespace
+} // param_pack
 
 } // namespace mgb::imperative
@@ -20,6 +20,11 @@ namespace imperative {
 class OpDef;
 struct OpTrait;
 
+enum DispatchMode {
+    DEFAULT_CPU = 0,
+    KERNEL = 1
+};
+
 struct BackwardGraphResult {
     std::shared_ptr<OpDef> backward;
     std::vector<bool> save_for_backward;
@@ -36,10 +41,31 @@ public:
     static std::shared_ptr<OpDef> make_from_op_node(
         cg::OperatorNodeBase* node);
 
+    /*!
+     * \brief Decide which dispatch method to use according to the inputs'
+     * host values and sizes.
+     *
+     * \param def Specific :c:expr:`OpDef` to be executed.
+     * \param inputs Input tensor descriptions.
+     * \return The DispatchMode to use, such as `KERNEL` or `DEFAULT_CPU`.
+     */
+    static DispatchMode decide_dispatch_mode(
+        const OpDef& def,
+        const SmallVector<LogicalTensorDesc>& inputs);
+
     static SmallVector<TensorPtr> apply_on_physical_tensor(
         const OpDef& def,
         SmallVector<TensorPtr> inputs);
+
+    /*!
+     * \brief Call the corresponding dnn op to calculate results. Output
+     * tensors' device memory should be allocated outside.
+     */
+    static void apply_on_device_tensornd(
+        const OpDef& def,
+        const SmallVector<DeviceTensorND>& inputs,
+        SmallVector<DeviceTensorND>* outputs);
+
     static cg::VarNodeArray apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs);