GitOrigin-RevId: 2663504470
tags/v1.3.0
| @@ -29,7 +29,7 @@ Interpreter& Interpreter::inst() { | |||
| return inst_; | |||
| } | |||
| void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
| Handle ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
| auto info = alloc(); | |||
| info->desc.layout = value.layout(); | |||
| info->desc.comp_node = value.comp_node(); | |||
| @@ -39,7 +39,7 @@ void* ChannelImpl::put(const HostTensorND& value, bool no_cache) { | |||
| return info; | |||
| } | |||
| void* ChannelImpl::put(const DeviceTensorND& data) { | |||
| Handle ChannelImpl::put(const DeviceTensorND& data) { | |||
| auto info = alloc(); | |||
| info->desc.layout = data.layout(); | |||
| info->desc.comp_node = data.comp_node(); | |||
| @@ -48,12 +48,12 @@ void* ChannelImpl::put(const DeviceTensorND& data) { | |||
| return info; | |||
| } | |||
| void ChannelImpl::del(void* handle) { | |||
| void ChannelImpl::del(Handle handle) { | |||
| mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | |||
| m_buffer.enqueue(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||
| } | |||
| void ChannelImpl::swap_in(void* handle) { | |||
| void ChannelImpl::swap_in(Handle handle) { | |||
| if (m_enable_evict & SWAP) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| @@ -61,7 +61,7 @@ void ChannelImpl::swap_in(void* handle) { | |||
| } | |||
| } | |||
| void ChannelImpl::swap_out(void* handle) { | |||
| void ChannelImpl::swap_out(Handle handle) { | |||
| if (m_enable_evict & SWAP) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| @@ -69,7 +69,7 @@ void ChannelImpl::swap_out(void* handle) { | |||
| } | |||
| } | |||
| void ChannelImpl::drop(void* handle) { | |||
| void ChannelImpl::drop(Handle handle) { | |||
| if (m_enable_evict & DROP) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| @@ -77,45 +77,91 @@ void ChannelImpl::drop(void* handle) { | |||
| } | |||
| } | |||
| SmallVector<void*> ChannelImpl::apply_op( | |||
| void ChannelImpl::dispatch_default_cpu( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<void*>& inputs) { | |||
| for (auto i : inputs) { | |||
| mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||
| "invalid handle: %p", i); | |||
| } | |||
| SmallVector<TensorInfo*> input_infos; | |||
| input_infos.reserve(inputs.size()); | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| input_descs.reserve(inputs.size()); | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs) { | |||
| auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
| SmallVector<DeviceTensorND> input_tensornds; | |||
| input_tensornds.reserve(input_descs.size()); | |||
| CompNode output_cn; | |||
| { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| for (auto i : inputs) { | |||
| auto info = reinterpret_cast<TensorInfo*>(i); | |||
| mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||
| input_infos.push_back(info); | |||
| input_descs.push_back(info->desc); | |||
| for (auto&& info : input_infos) { | |||
| mgb_assert(info->ptr, "invalid tensor ptr!"); | |||
| if (!output_cn.valid()) { | |||
| output_cn = info->ptr->comp_node(); | |||
| } else { | |||
| mgb_assert(output_cn == info->ptr->comp_node(), "cannot decide output comp node"); | |||
| } | |||
| mgb_assert(info->ptr->try_get_value(), "no valid host value"); | |||
| input_tensornds.emplace_back(info->ptr->get_value().proxy_to_default_cpu()); | |||
| } | |||
| } | |||
| outputs->reserve(output_descs.size()); | |||
| SmallVector<DeviceTensorND> output_tensornds; | |||
| output_tensornds.reserve(output_descs.size()); | |||
| for (auto&& desc : output_descs) { | |||
| // TODO: may conflict with condtake, which need alloc inside | |||
| mgb_assert(!desc.layout.is_empty()); | |||
| // use HostTensorND alloc_host for cuda pinned memory | |||
| output_tensornds.emplace_back(HostTensorND(output_cn, desc.layout).proxy_to_default_cpu()); | |||
| } | |||
| OpDef::apply_on_device_tensornd(*op, input_tensornds, &output_tensornds); | |||
| SmallVector<TensorInfo*> output_infos; | |||
| output_infos.reserve(output_descs.size()); | |||
| for (auto&& tensornd : output_tensornds) { | |||
| // tensornd -> host_tensornd | |||
| HostTensorND host_tensornd = HostTensorND::make_proxy(tensornd) | |||
| .proxy_to_comp_node(output_cn); | |||
| // tensornd -> desc | |||
| LogicalTensorDesc desc = {tensornd.layout(), output_cn, tensornd}; | |||
| // tensornd -> tensor | |||
| auto info = alloc(); | |||
| info->desc = desc; | |||
| m_valid_handle.insert(info); | |||
| output_infos.push_back(info); | |||
| info->ptr = Tensor::make(host_tensornd, true); // host_only=true | |||
| info->value_fetched = true; | |||
| outputs->push_back(info); | |||
| } | |||
| if (m_enable_evict & DROP) { | |||
| for (auto out : output_infos) { | |||
| out->path.op = op; | |||
| for (auto out_ : output_infos) { | |||
| out->path.outputs.push_back(m_st.at(out_)); | |||
| } | |||
| for (auto inp : input_infos) { | |||
| out->path.inputs.push_back(m_st.at(inp)); | |||
| inp->path.dep_outputs.push_back(m_st.at(out)); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void ChannelImpl::dispatch_kernel( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs) { | |||
| auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
| ApplyOp cmd{std::move(op)}; | |||
| cmd.inputs = std::move(input_infos); | |||
| cmd.outputs.reserve(output_descs.size()); | |||
| SmallVector<void*> outputs; | |||
| // FIXME: remove this check when op check is correct | |||
| bool validated_bkp = true; | |||
| for (size_t i = 0;i < output_descs.size();i ++) { | |||
| auto&& desc = output_descs[i]; | |||
| if (desc.layout.ndim == 0) { | |||
| validated_bkp = false; | |||
| } | |||
| outputs->reserve(output_descs.size()); | |||
| for (auto&& desc : output_descs) { | |||
| auto info = alloc(); | |||
| info->desc = desc; | |||
| m_valid_handle.insert(info); | |||
| cmd.outputs.push_back(info); | |||
| outputs.push_back(info); | |||
| outputs->push_back(info); | |||
| } | |||
| if (m_enable_evict & DROP) { | |||
| for (auto out : cmd.outputs) { | |||
| @@ -130,20 +176,55 @@ SmallVector<void*> ChannelImpl::apply_op( | |||
| } | |||
| } | |||
| m_buffer.enqueue(std::move(cmd)); | |||
| if (!(validated && validated_bkp) && m_async_level == 1) { | |||
| if (!validated && m_async_level == 1) { | |||
| sync(); | |||
| } else if (m_async_level == 0) { | |||
| sync(); | |||
| // check device error | |||
| for (auto&& oup : outputs) { | |||
| for (auto&& oup : *outputs) { | |||
| auto info = reinterpret_cast<TensorInfo*>(oup); | |||
| info->ptr->comp_node().sync(); | |||
| } | |||
| } | |||
| } | |||
| SmallVector<Handle> ChannelImpl::apply_op( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<Handle>& inputs) { | |||
| for (auto i : inputs) { | |||
| mgb_assert(m_valid_handle.find(i) != m_valid_handle.end(), | |||
| "invalid handle: %p", i); | |||
| } | |||
| SmallVector<TensorInfo*> input_infos; | |||
| input_infos.reserve(inputs.size()); | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| input_descs.reserve(inputs.size()); | |||
| { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| for (auto i : inputs) { | |||
| auto info = reinterpret_cast<TensorInfo*>(i); | |||
| mgb_assert(!info->invalid, "Invalid tensor, unable to apply_op!"); | |||
| input_infos.push_back(info); | |||
| input_descs.push_back(info->desc); | |||
| } | |||
| } | |||
| SmallVector<Handle> outputs; | |||
| switch (OpDef::decide_dispatch_mode(*op, input_descs)) { | |||
| case DEFAULT_CPU: { | |||
| dispatch_default_cpu(op, input_infos, input_descs, &outputs); | |||
| break; | |||
| } | |||
| case KERNEL: { | |||
| dispatch_kernel(op, input_infos, input_descs, &outputs); | |||
| break; | |||
| } | |||
| } | |||
| mgb_assert(outputs.size() > 0, "Invalid dispatch mode!"); | |||
| return outputs; | |||
| } | |||
| HostTensorND ChannelImpl::get_value(void* handle) { | |||
| HostTensorND ChannelImpl::get_value(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -163,7 +244,7 @@ HostTensorND ChannelImpl::get_value(void* handle) { | |||
| return info->ptr->get_value(); | |||
| } | |||
| TensorShape ChannelImpl::get_shape(void* handle) { | |||
| TensorShape ChannelImpl::get_shape(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -184,7 +265,7 @@ TensorShape ChannelImpl::get_shape(void* handle) { | |||
| return ret; | |||
| } | |||
| DType ChannelImpl::get_dtype(void* handle) { | |||
| DType ChannelImpl::get_dtype(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -193,7 +274,7 @@ DType ChannelImpl::get_dtype(void* handle) { | |||
| return ret; | |||
| } | |||
| CompNode ChannelImpl::get_device(void* handle) { | |||
| CompNode ChannelImpl::get_device(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -202,7 +283,7 @@ CompNode ChannelImpl::get_device(void* handle) { | |||
| return ret; | |||
| } | |||
| DeviceTensorND ChannelImpl::get_dev_tensor(void* handle) { | |||
| DeviceTensorND ChannelImpl::get_dev_tensor(Handle handle) { | |||
| mgb_assert(m_valid_handle.find(handle) != m_valid_handle.end(), | |||
| "invalid handle: %p", handle); | |||
| auto info = reinterpret_cast<TensorInfo*>(handle); | |||
| @@ -262,25 +343,15 @@ ChannelImpl::~ChannelImpl() { | |||
| } | |||
| void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr, bool notice = true) { | |||
| if (notice) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| // if (dest->desc.layout.ndim) { | |||
| // mgb_assert(dest->desc.layout.eq_shape(ptr->layout())); | |||
| // } | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| if (m_waitee == dest) { | |||
| m_cv.notify_all(); | |||
| } | |||
| } else { | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| auto lock = notice ? std::unique_lock<std::mutex>(m_mutex) | |||
| : std::unique_lock<std::mutex>(); | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| if (notice && m_waitee == dest) { | |||
| m_cv.notify_all(); | |||
| } | |||
| } | |||
| @@ -295,7 +366,7 @@ void ChannelImpl::do_swap_out(TensorInfo* dest) { | |||
| dest->evict_type = SWAP; | |||
| dest->value_fetched = false; | |||
| // TODO: swap in parallel | |||
| dest->h_value.copy_from(dest->ptr->dev_tensor()).sync(); | |||
| dest->h_value = dest->ptr->get_value(); | |||
| dest->ptr.reset(); | |||
| } | |||
| @@ -198,6 +198,17 @@ private: | |||
| void do_drop(TensorInfo* dest); | |||
| void regenerate(TensorInfo* dest, bool must_drop); | |||
| void dispatch_default_cpu( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs); | |||
| void dispatch_kernel( | |||
| std::shared_ptr<OpDef> op, | |||
| const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| SmallVector<Handle>* outputs); | |||
| std::mutex m_mutex; | |||
| std::condition_variable m_cv; | |||
| MemPool<TensorInfo> m_pool; | |||
| @@ -30,12 +30,26 @@ std::shared_ptr<OpDef> OpDef::make_from_op_node( | |||
| return trait->make_from_op_node(node); | |||
| } | |||
| DispatchMode OpDef::decide_dispatch_mode( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| return def.trait()->decide_dispatch_mode(def, inputs); | |||
| } | |||
| SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||
| const OpDef& def, | |||
| SmallVector<TensorPtr> inputs) { | |||
| return def.trait()->apply_on_physical_tensor(def, std::move(inputs)); | |||
| } | |||
| void OpDef::apply_on_device_tensornd( | |||
| const OpDef& def, | |||
| const SmallVector<DeviceTensorND>& inputs, | |||
| SmallVector<DeviceTensorND>* outputs) { | |||
| def.trait()->apply_on_device_tensornd(def, inputs, outputs); | |||
| return; | |||
| } | |||
| VarNodeArray OpDef::apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| @@ -9,12 +9,16 @@ | |||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| */ | |||
| #include <exception> | |||
| #include <sstream> | |||
| #include <stdexcept> | |||
| #include "megbrain/imperative/op_def.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| #include "megbrain/tensor.h" | |||
| #include "./op_trait.h" | |||
| #include "megbrain/imperative/proxy_graph_detail.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| @@ -62,6 +66,12 @@ void OpTrait::for_each_trait(thin_function<void(OpTrait&)> visitor){ | |||
| } | |||
| } | |||
| DispatchMode fallback_decide_dispatch_mode( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| return KERNEL; | |||
| } | |||
| OpTraitRegistry& OpTraitRegistry::fallback() { | |||
| if (trait->apply_on_var_node) { | |||
| // fallback to proxy graph impl | |||
| @@ -78,6 +88,9 @@ OpTraitRegistry& OpTraitRegistry::fallback() { | |||
| proxy_graph_detail::make_backward_graph; | |||
| } | |||
| } | |||
| if (!trait->decide_dispatch_mode) { | |||
| trait->decide_dispatch_mode = fallback_decide_dispatch_mode; | |||
| } | |||
| return *this; | |||
| } | |||
| @@ -60,8 +60,12 @@ struct ToVarNodeArray<cg::OperatorNodeBase*>: std::true_type { | |||
| using OpDefMaker = detail::OpMeth< | |||
| decltype(OpDef::make_from_op_node)>; | |||
| using DecideDispatchMode = detail::OpMeth< | |||
| decltype(OpDef::decide_dispatch_mode)>; | |||
| using ApplyOnPhysicalTensor = detail::OpMeth< | |||
| decltype(OpDef::apply_on_physical_tensor)>; | |||
| using ApplyOnDeviceTensorND = detail::OpMeth< | |||
| decltype(OpDef::apply_on_device_tensornd)>; | |||
| using ApplyOnVarNode = detail::OpMeth< | |||
| decltype(OpDef::apply_on_var_node)>; | |||
| using InferOutputAttrsFallible = detail::OpMeth< | |||
| @@ -74,7 +78,9 @@ using IsSame = detail::OpMeth<bool(const OpDef&, const OpDef&)>; | |||
| struct OpTrait { | |||
| const char* name; | |||
| OpDefMaker make_from_op_node; | |||
| DecideDispatchMode decide_dispatch_mode; | |||
| ApplyOnPhysicalTensor apply_on_physical_tensor; | |||
| ApplyOnDeviceTensorND apply_on_device_tensornd; | |||
| ApplyOnVarNode apply_on_var_node; | |||
| InferOutputAttrsFallible infer_output_attrs_fallible; | |||
| GradMaker make_backward_graph; | |||
| @@ -88,7 +94,9 @@ struct OpTrait { | |||
| #define FOR_EACH_OP_METH(cb) \ | |||
| cb(make_from_op_node) \ | |||
| cb(decide_dispatch_mode) \ | |||
| cb(apply_on_physical_tensor) \ | |||
| cb(apply_on_device_tensornd) \ | |||
| cb(apply_on_var_node) \ | |||
| cb(infer_output_attrs_fallible) \ | |||
| cb(make_backward_graph) \ | |||
| @@ -68,23 +68,46 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| DispatchMode decide_dispatch_mode( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| bool host_computable = true; | |||
| constexpr int size_threshhold = TensorShape::MAX_NDIM; | |||
| for (auto&& inp : inputs) { | |||
| if (inp.value.empty() || inp.value.layout().ndim == 0 | |||
| || inp.value.layout().total_nr_elems() > size_threshhold) { | |||
| host_computable = false; | |||
| break; | |||
| } | |||
| } | |||
| return host_computable ? DEFAULT_CPU : KERNEL; | |||
| } | |||
| void apply_on_device_tensornd( | |||
| const OpDef& def, | |||
| const SmallVector<DeviceTensorND>& inputs, | |||
| SmallVector<DeviceTensorND>* outputs) { | |||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
| auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | |||
| mgb_assert(inputs.size() == trait.arity, | |||
| "%s expects %u inputs; got %zu actually", trait.name, | |||
| trait.arity, inputs.size()); | |||
| auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0].comp_node()); | |||
| opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr); | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| DeviceTensorND out; | |||
| SmallVector<DeviceTensorND> dt_inputs(inputs.size()); | |||
| SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | |||
| for (unsigned i = 0; i < inputs.size(); ++i){ | |||
| dt_inputs[i] = inputs[i]->dev_tensor(); | |||
| inp_tensornds[i] = inputs[i]->dev_tensor(); | |||
| } | |||
| auto&& dnn_opr = opr::intl::create_megdnn_opr<megdnn::Elemwise>(inputs[0]->comp_node()); | |||
| opr::Elemwise::perform(op_def.mode, out, dt_inputs, dnn_opr); | |||
| return {Tensor::make(out)}; | |||
| SmallVector<DeviceTensorND> oup_tensornds = {{inp_tensornds[0].comp_node(), inp_tensornds[0].dtype()}}; | |||
| apply_on_device_tensornd(def, inp_tensornds, &oup_tensornds); | |||
| return {Tensor::make(oup_tensornds[0])}; | |||
| } | |||
| MGB_DEFINE_OPR_CLASS(ForceInplaceElemwise, cg::SingleCNOperatorNodeBaseT<opr::mixin::MegDNNOprHolder>) //{ | |||
| @@ -214,8 +237,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_ | |||
| OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
| .make_from_op_node(make_from_op_node) | |||
| .decide_dispatch_mode(decide_dispatch_mode) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_device_tensornd(apply_on_device_tensornd) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .fallback(); | |||
| @@ -15,8 +15,8 @@ | |||
| #include "../op_trait.h" | |||
| namespace mgb::imperative { | |||
| namespace { | |||
| namespace get_var_shape { | |||
| cg::OperatorNodeBase* apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs) { | |||
| @@ -24,17 +24,38 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| return opr::GetVarShape::make(inputs, op_def.param()).node()->owner_opr(); | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| DispatchMode decide_dispatch_mode( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| bool host_computable = true; | |||
| for (auto&& inp : inputs) { | |||
| // FIXME(czh): remove value chech after proxy graph's | |||
| // apply_on_device_tensornd is supported and output Tensor | |||
| // is made before add_task. | |||
| // then if layout is valid, ptr->layout must be ready | |||
| if (inp.value.empty() || inp.value.layout().ndim == 0) { | |||
| host_computable = false; | |||
| break; | |||
| } | |||
| } | |||
| return host_computable ? DEFAULT_CPU : KERNEL; | |||
| } | |||
| void apply_on_device_tensornd( | |||
| const OpDef& def, | |||
| const SmallVector<DeviceTensorND>& inputs, | |||
| SmallVector<DeviceTensorND>* outputs) { | |||
| auto&& op_def = def.cast_final_safe<GetVarShape>(); | |||
| mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | |||
| auto&& inp = inputs[0]; | |||
| auto&& shp = inp->layout(); | |||
| auto&& shp = inp.layout(); | |||
| mgb_assert(shp.ndim != 0, "input shape invalid"); | |||
| mgb_assert((*outputs)[0].comp_node() == CompNode::default_cpu(), | |||
| "GetVarShape's apply_on_device_tensornd should receive default_cpu outputs."); | |||
| HostTensorND hv; | |||
| if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | |||
| hv = HostTensorND(inp->comp_node(), {shp.ndim}, dtype::Int32()); | |||
| if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { | |||
| hv = HostTensorND(CompNode::default_cpu(), {shp.ndim}, dtype::Int32()); | |||
| auto* ptr = hv.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < shp.ndim; ++i) { | |||
| ptr[i] = shp.shape[i]; | |||
| @@ -45,11 +66,29 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| axis += shp.ndim; | |||
| } | |||
| mgb_assert(axis >= 0 && axis < (int32_t)shp.ndim); | |||
| hv = HostTensorND(inp->comp_node(), {1}, dtype::Int32()); | |||
| hv = HostTensorND(CompNode::default_cpu(), {1}, dtype::Int32()); | |||
| auto* ptr = hv.ptr<dt_int32>(); | |||
| ptr[0] = shp.shape[axis]; | |||
| } | |||
| return {Tensor::make(std::move(hv))}; | |||
| (*outputs)[0] = DeviceTensorND::make_proxy(hv); | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs) { | |||
| SmallVector<DeviceTensorND> input_tensornds; | |||
| input_tensornds.reserve(inputs.size()); | |||
| for (auto&& inp : inputs) { | |||
| input_tensornds.push_back(inp->dev_tensor()); | |||
| } | |||
| SmallVector<DeviceTensorND> output_tensornds = {{CompNode::default_cpu(), dtype::Int32()}}; | |||
| apply_on_device_tensornd(def, input_tensornds, &output_tensornds); | |||
| // restore to input comp_node | |||
| HostTensorND host_tensornd = HostTensorND::make_proxy(output_tensornds[0]) | |||
| .proxy_to_comp_node(inputs[0]->comp_node()); | |||
| return {Tensor::make(std::move(host_tensornd))}; | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| @@ -62,7 +101,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, false}; | |||
| } | |||
| DeviceTensorND value; | |||
| if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS){ | |||
| if (op_def.axis == opr::GetVarShape::Param::INVALID_AXIS) { | |||
| value = DeviceTensorND(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); | |||
| auto* ptr = value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < desc.layout.ndim; ++i) { | |||
| @@ -88,11 +127,15 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape) | |||
| .make_from_op_node(make_from_op_node) | |||
| .decide_dispatch_mode(decide_dispatch_mode) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .apply_on_device_tensornd(apply_on_device_tensornd) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .fallback(); | |||
| } // get_var_shape | |||
| namespace param_pack { | |||
| TensorShapeArray get_shapes(const std::vector<std::vector<size_t>>& shapes) { | |||
| TensorShapeArray ret; | |||
| for (auto&& i:shapes) { | |||
| @@ -156,6 +199,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( | |||
| OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | |||
| .apply_on_var_node(param_pack_concat_apply_on_var_node) | |||
| .fallback(); | |||
| } // namespace | |||
| } // param_pack | |||
| } // namespace mgb::imperative | |||
| @@ -20,6 +20,11 @@ namespace imperative { | |||
| class OpDef; | |||
| struct OpTrait; | |||
| enum DispatchMode { | |||
| DEFAULT_CPU = 0, | |||
| KERNEL = 1 | |||
| }; | |||
| struct BackwardGraphResult { | |||
| std::shared_ptr<OpDef> backward; | |||
| std::vector<bool> save_for_backward; | |||
| @@ -36,10 +41,31 @@ public: | |||
| static std::shared_ptr<OpDef> make_from_op_node( | |||
| cg::OperatorNodeBase* node); | |||
| /*! | |||
| * \brief Decide which dispatch method to be used according to the inputs' | |||
| * host value and size. | |||
| * | |||
| * \param def Specific :c:expr:`OpDef` to be executed. | |||
| * \param inputs Input tensor descriptions. | |||
| * \return Which DispatchMode to be used, such as `CUDA` or `DEFAULT_CPU`. | |||
| */ | |||
| static DispatchMode decide_dispatch_mode( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs); | |||
| static SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, | |||
| SmallVector<TensorPtr> inputs); | |||
| /*! | |||
| * \brief Call the corresponding dnn op to calculate results. Output | |||
| * tensors' device memory should be allocated outside. | |||
| */ | |||
| static void apply_on_device_tensornd( | |||
| const OpDef& def, | |||
| const SmallVector<DeviceTensorND>& inputs, | |||
| SmallVector<DeviceTensorND>* outputs); | |||
| static cg::VarNodeArray apply_on_var_node( | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs); | |||