GitOrigin-RevId: bff62b33a0
tags/v1.9.0
| @@ -12,13 +12,3 @@ from contextlib import contextmanager | |||
| from ._imperative_rt.core2 import get_option, set_option | |||
| from .tensor.megbrain_graph import Graph | |||
| @contextmanager | |||
| def option(key, value): | |||
| value = int(value) | |||
| old = get_option(key) | |||
| set_option(key, value) | |||
| yield | |||
| assert get_option(key) == value | |||
| set_option(key, old) | |||
| @@ -76,10 +76,11 @@ def test_drop_basic(): | |||
| def test_finalize(): | |||
| prog = """ | |||
| import megengine | |||
| with megengine.core.option("enable_host_compute", 0): | |||
| x = megengine.tensor(0) | |||
| y = x + 1 | |||
| y.numpy() | |||
| megengine.core.set_option("enable_host_compute", 0) | |||
| x = megengine.tensor(0) | |||
| y = x + 1 | |||
| y.numpy() | |||
| megengine.core.set_option("enable_host_compute", 1) | |||
| """ | |||
| subprocess.check_call([sys.executable, "-c", prog]) | |||
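The hunk above removes the `option` context manager from `megengine.core` and rewrites the test to call `set_option` directly. For reference, the removed helper corresponds to the following minimal sketch, using only the `get_option`/`set_option` functions that the test itself relies on; the `try/finally` is an addition here that restores the previous value even if the body raises, which the deleted version did not guarantee:

```python
from contextlib import contextmanager

# get_option/set_option are the public helpers used by the test above.
from megengine.core import get_option, set_option


@contextmanager
def option(key, value):
    """Temporarily override an interpreter option, restoring the old value on exit."""
    old = get_option(key)
    set_option(key, int(value))
    try:
        yield
    finally:
        set_option(key, old)


# Usage mirrors the old test body:
#     with option("enable_host_compute", 0):
#         ...
```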
| @@ -15,7 +15,6 @@ import pytest | |||
| from megengine import Parameter | |||
| from megengine import distributed as dist | |||
| from megengine import tensor | |||
| from megengine.core import option | |||
| from megengine.jit import trace | |||
| from megengine.module import Module | |||
| from megengine.utils.profiler import Profiler, scope | |||
| @@ -155,7 +155,6 @@ TensorInfo* ChannelImpl::put_impl(const HostTensorND& value, bool no_cache) { | |||
| info->h_value = value; | |||
| info->desc.value = value.proxy_to_default_cpu(); | |||
| } | |||
| info->mem_desc.id = StorageIdentifier::make(++m_storage_id); | |||
| m_worker.add_task( | |||
| {Profiler::next_id(), Put{info, value, no_cache}, | |||
| get_channel_state().stack_manager.dump()}); | |||
| @@ -180,7 +179,6 @@ TensorInfo* ChannelImpl::put_impl( | |||
| auto info = alloc(); | |||
| MGB_RECORD_EVENT(TensorCommandEvent, info->id, TensorCommandKind::Put); | |||
| init(info, {data.layout(), data.comp_node()}); | |||
| info->mem_desc.id = StorageIdentifier::make(++m_storage_id); | |||
| info->ptr = Tensor::make(data, hvalue); | |||
| MGB_RECORD_EVENT( | |||
| TensorProduceEvent, info->id, info->desc.layout, info->desc.comp_node, | |||
| @@ -536,9 +534,6 @@ void ChannelImpl::init(TensorInfo* info, LogicalTensorDesc desc) { | |||
| MGB_RECORD_EVENT(TensorDeclareEvent, info->id, info->name); | |||
| info->status = TensorInfo::Allocated; | |||
| info->desc = std::move(desc); | |||
| info->mem_desc.layout = info->desc.layout; | |||
| info->mem_desc.cn = info->desc.comp_node; | |||
| info->mem_desc.offset = 0; | |||
| } | |||
| void ChannelImpl::do_drop(TensorInfo* ptr, bool user = false) { | |||
| @@ -667,18 +662,14 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||
| bool profiling_device = | |||
| Profiler::is_profiling() && Profiler::get_option("profile_device", 0); | |||
| uint64_t apply_id = cmd.id; | |||
| struct TensorWithDesc { | |||
| TensorPtr tensor; | |||
| MemoryDesc desc; | |||
| }; | |||
| SmallVector<TensorWithDesc> inputs; | |||
| SmallVector<TensorPtr> inputs; | |||
| inputs.reserve(cmd.inputs.size()); | |||
| // refcnt == 1, owners: [TensorInfo::ptr] | |||
| for (auto i : cmd.inputs) { | |||
| mgb_assert(i->ptr, "Invalid input tensor ptr!"); | |||
| // refcnt ++, owners: [i->ptr, tensor_inputs] | |||
| // tensor_inputs.push_back(i->ptr); | |||
| inputs.push_back({i->ptr, i->mem_desc}); | |||
| inputs.push_back(i->ptr); | |||
| } | |||
| if (state.options.enable_dtr_auto_drop && | |||
| state.options.dtr_eviction_threshold > 0) { | |||
| @@ -686,56 +677,28 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||
| } | |||
| auto apply_on_physical_tensor = | |||
| [&](auto&& self, const OpDef& def, | |||
| SmallVector<TensorWithDesc> inputs) -> SmallVector<TensorWithDesc> { | |||
| SmallVector<TensorPtr> inputs) -> SmallVector<TensorPtr> { | |||
| auto apply_functor = [&](std::shared_ptr<OpDef> op, | |||
| SmallVector<TensorWithDesc> inputs, | |||
| size_t nr_outputs) -> SmallVector<TensorWithDesc> { | |||
| SmallVector<TensorPtr> inputs, | |||
| size_t nr_outputs) -> SmallVector<TensorPtr> { | |||
| auto opname = op->trait()->make_name(*op); | |||
| imperative_log_profile_begin(opname.c_str()); | |||
| auto outputs = self(self, *op, inputs); | |||
| imperative_log_profile_end(opname.c_str()); | |||
| return outputs; | |||
| }; | |||
| auto const_functor = [&](TensorPtr value) -> TensorWithDesc { | |||
| return {value, MemoryDesc{ | |||
| value->layout(), 0, value->comp_node(), | |||
| StorageIdentifier::make()}}; | |||
| }; | |||
| auto const_functor = [&](TensorPtr value) -> TensorPtr { return value; }; | |||
| if (def.trait()->make_forward_graph) { | |||
| // apply recursively | |||
| SmallVector<LogicalTensorDesc> input_descs; | |||
| for (auto&& input : inputs) { | |||
| input_descs.push_back( | |||
| {{{}, input.tensor->dtype()}, input.tensor->comp_node()}); | |||
| input_descs.push_back({{{}, input->dtype()}, input->comp_node()}); | |||
| } | |||
| auto forward_graph = OpDef::make_forward_graph(def, input_descs); | |||
| auto outputs = forward_graph.apply(inputs, apply_functor, const_functor); | |||
| return outputs; | |||
| } | |||
| SmallVector<TensorPtr> input_tensors; | |||
| SmallVector<MemoryDesc> input_descs; | |||
| for (auto&& input : inputs) { | |||
| input_tensors.push_back(input.tensor); | |||
| input_descs.push_back(input.desc); | |||
| } | |||
| auto [output_descs, output_tensors, workspaces] = | |||
| init_output_and_workspace(def, input_tensors, input_descs); | |||
| if (!output_descs.empty()) { | |||
| OpDef::execute(def, input_tensors, output_tensors, workspaces); | |||
| } else { | |||
| output_tensors = OpDef::apply_on_physical_tensor(def, input_tensors); | |||
| for (auto&& output_tensor : output_tensors) { | |||
| output_descs.push_back(MemoryDesc{ | |||
| output_tensor->layout(), 0, output_tensor->comp_node(), | |||
| StorageIdentifier::make()}); | |||
| } | |||
| } | |||
| SmallVector<TensorWithDesc> outputs; | |||
| for (auto&& [output_tensor, output_desc] : | |||
| ranges::zip_view(output_tensors, output_descs)) { | |||
| outputs.push_back({output_tensor, output_desc}); | |||
| } | |||
| return outputs; | |||
| return OpDef::apply_on_physical_tensor(def, inputs); | |||
| }; | |||
| MGB_RECORD_EVENT(OpExecuteEvent, apply_id, {}, reason); | |||
| // Begin profiling operator | |||
| @@ -787,8 +750,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||
| MGB_RECORD_EVENT(OpOutputFinishEvent, output->id); | |||
| } else { | |||
| MGB_RECORD_EVENT(OpOutputEvent, output->id); | |||
| produce_tensor(output, outputs[i].tensor); | |||
| output->mem_desc = outputs[i].desc; | |||
| produce_tensor(output, outputs[i]); | |||
| MGB_RECORD_EVENT(OpOutputFinishEvent, output->id); | |||
| sample_on_device(output->desc.comp_node, false); | |||
| } | |||
| @@ -800,7 +762,7 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||
| estimate_compute_time += i->memory; | |||
| } | |||
| for (auto i : outputs) { | |||
| estimate_compute_time += i.tensor->blob()->size(); | |||
| estimate_compute_time += i->blob()->size(); | |||
| } | |||
| m_dtr.estimate_timestamp += estimate_compute_time / 1e8; | |||
| for (auto i : cmd.outputs) { | |||
| @@ -1012,52 +974,6 @@ void ChannelImpl::alloc_tensor_with_evict(Blob* x) { | |||
| set_log_level(pre_level); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> | |||
| ChannelImpl::init_output_and_workspace( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, | |||
| SmallVector<MemoryDesc> inputs_mem_desc) { | |||
| auto [outputs_desc, workspaces_desc] = | |||
| OpDef::infer_output_mem_desc(def, inputs, inputs_mem_desc); | |||
| if (!outputs_desc.size()) { | |||
| // failed to infer memplan | |||
| return {{}, {}, {}}; | |||
| } | |||
| // refine storage id to make it unique | |||
| for (auto&& desc : outputs_desc) { | |||
| if (desc.id->is_sys_alloc()) { | |||
| // TODO: there may be some outputs sharing the same storage id | |||
| desc.id->id = ++m_storage_id; | |||
| } | |||
| } | |||
| auto& state = get_worker_state(); | |||
| auto alloc_storage = [&](SmallVector<MemoryDesc>& desc) { | |||
| SmallVector<TensorPtr> tensors; | |||
| for (size_t i = 0; i < desc.size(); i++) { | |||
| if (desc[i].id->is_sys_alloc()) { | |||
| tensors.push_back(Tensor::make(desc[i].layout, desc[i].cn)); | |||
| if (state.options.enable_dtr_auto_drop && !desc[i].layout.is_empty()) { | |||
| alloc_tensor_with_evict(tensors.back()->blob().get()); | |||
| } | |||
| } else if (desc[i].id->is_from_other()) { | |||
| for (size_t j = 0; j < inputs_mem_desc.size(); j++) { | |||
| if (inputs_mem_desc[j].id->desc == desc[i].id->desc) { | |||
| tensors.push_back( | |||
| inputs[j]->sub(desc[i].offset, desc[i].layout)); | |||
| break; | |||
| } | |||
| } | |||
| } else if (desc[i].id->is_device_ptr()) { | |||
| tensors.push_back(desc[i].id->ptr); | |||
| } else { | |||
| mgb_assert(0, "not implemented"); | |||
| } | |||
| } | |||
| return tensors; | |||
| }; | |||
| return {outputs_desc, alloc_storage(outputs_desc), alloc_storage(workspaces_desc)}; | |||
| } | |||
| void ChannelImpl::process_one_task(Command& icmd) { | |||
| using namespace ranges; | |||
| using namespace ranges::views; | |||
| @@ -105,11 +105,6 @@ private: | |||
| void flush_apply_stack(); | |||
| void do_apply_op(const ApplyOp& cmd, std::string reason); | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<TensorPtr>, SmallVector<TensorPtr>> | |||
| init_output_and_workspace( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, | |||
| SmallVector<MemoryDesc> inputs_mem_desc); | |||
| void dispatch_default_cpu( | |||
| std::shared_ptr<OpDef> op, const SmallVector<TensorInfo*>& input_infos, | |||
| const SmallVector<LogicalTensorDesc>& input_descs, | |||
| @@ -296,6 +291,8 @@ private: | |||
| op_blacklist.end(); | |||
| } | |||
| // operators that cannot be re-computed, including: | |||
| // distributed operators, inplace operators, random generator operators | |||
| std::vector<std::string> op_blacklist = { | |||
| "CollectiveComm", "InplaceAdd", "ParamPackSplit", "ParamPackConcat", | |||
| "GaussianRNG", "UniformRNG", "GammaRNG", "PermutationRNG", | |||
| @@ -59,7 +59,6 @@ struct TensorInfo { | |||
| // Lock interpreter when visiting `ptr`. | |||
| TensorPtr ptr; | |||
| LogicalTensorDesc desc; | |||
| MemoryDesc mem_desc; | |||
| double compute_time; | |||
| size_t memory; | |||
| @@ -41,20 +41,6 @@ SmallVector<TensorPtr> OpDef::apply_on_physical_tensor( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs) { | |||
| return def.trait()->apply_on_physical_tensor(def, std::move(inputs)); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> OpDef:: | |||
| infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| return def.trait()->infer_output_mem_desc(def, inputs_tensors, inputs_mems); | |||
| } | |||
| void OpDef::execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| def.trait()->execute(def, std::move(inputs), outputs, std::move(workspace)); | |||
| } | |||
| void OpDef::apply_on_device_tensornd( | |||
| const OpDef& def, const SmallVector<DeviceTensorND>& inputs, | |||
| SmallVector<DeviceTensorND>* outputs) { | |||
| @@ -43,13 +43,6 @@ void OpMethFallbackByProxyGraph::impl( | |||
| ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor) { | |||
| func.Base::operator=(proxy_graph_detail::apply_on_physical_tensor); | |||
| } | |||
| void OpMethFallbackByProxyGraph::impl(Execute& func, op_meth_tag::Execute) { | |||
| func.Base::operator=(proxy_graph_detail::execute); | |||
| } | |||
| void OpMethFallbackByProxyGraph::impl( | |||
| InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) { | |||
| func.Base::operator=(proxy_graph_detail::infer_output_mem_desc); | |||
| } | |||
| void OpMethFallbackByProxyGraph::impl( | |||
| InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible) { | |||
| func.Base::operator=(proxy_graph_detail::infer_output_attrs_fallible); | |||
| @@ -62,10 +55,6 @@ void OpMethFallbackFromSubgraph::impl( | |||
| ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor) { | |||
| func.Base::operator=(subgraph_detail::apply_on_physical_tensor); | |||
| } | |||
| void OpMethFallbackFromSubgraph::impl( | |||
| InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc) { | |||
| func.Base::operator=(subgraph_detail::infer_output_mem_desc); | |||
| } | |||
| void OpMethFallbackFromSubgraph::impl( | |||
| ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode) { | |||
| func.Base::operator=(subgraph_detail::apply_on_var_node); | |||
| @@ -64,12 +64,6 @@ OpMethType(DecideDispatchMode, | |||
| OpMethType(ApplyOnPhysicalTensor, | |||
| decltype(OpDef::apply_on_physical_tensor)); | |||
| OpMethType(InferOutputMemDesc, | |||
| decltype(OpDef::infer_output_mem_desc)); | |||
| OpMethType(Execute, | |||
| decltype(OpDef::execute)); | |||
| OpMethType(ApplyOnDeviceTensorND, | |||
| decltype(OpDef::apply_on_device_tensornd)); | |||
| @@ -123,8 +117,6 @@ struct OpMethFallback : OpMethImplBase { | |||
| struct OpMethFallbackByProxyGraph : OpMethImplBase { | |||
| using OpMethImplBase::impl; | |||
| static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor); | |||
| static void impl(Execute& func, op_meth_tag::Execute); | |||
| static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc); | |||
| static void impl( | |||
| InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible); | |||
| static void impl(GradMaker& func, op_meth_tag::GradMaker); | |||
| @@ -133,7 +125,6 @@ struct OpMethFallbackByProxyGraph : OpMethImplBase { | |||
| struct OpMethFallbackFromSubgraph : OpMethImplBase { | |||
| using OpMethImplBase::impl; | |||
| static void impl(ApplyOnPhysicalTensor& func, op_meth_tag::ApplyOnPhysicalTensor); | |||
| static void impl(InferOutputMemDesc& func, op_meth_tag::InferOutputMemDesc); | |||
| static void impl(ApplyOnVarNode& func, op_meth_tag::ApplyOnVarNode); | |||
| static void impl( | |||
| InferOutputAttrsFallible& func, op_meth_tag::InferOutputAttrsFallible); | |||
| @@ -185,8 +176,6 @@ struct OpTrait { | |||
| OpDefMaker make_from_op_node; | |||
| DecideDispatchMode decide_dispatch_mode; | |||
| ApplyOnPhysicalTensor apply_on_physical_tensor; | |||
| InferOutputMemDesc infer_output_mem_desc; | |||
| Execute execute; | |||
| ApplyOnDeviceTensorND apply_on_device_tensornd; | |||
| ApplyOnVarNode apply_on_var_node; | |||
| InferOutputAttrsFallible infer_output_attrs_fallible; | |||
| @@ -207,8 +196,6 @@ struct OpTrait { | |||
| cb(make_from_op_node) \ | |||
| cb(decide_dispatch_mode) \ | |||
| cb(apply_on_physical_tensor) \ | |||
| cb(infer_output_mem_desc) \ | |||
| cb(execute) \ | |||
| cb(apply_on_device_tensornd) \ | |||
| cb(apply_on_var_node) \ | |||
| cb(infer_output_attrs_fallible) \ | |||
| @@ -81,50 +81,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto& input = inputs_tensors[0]; | |||
| TensorShape target_shape; | |||
| cg::copy_tensor_value_to_shape( | |||
| target_shape, inputs_tensors[1]->get_value().proxy_to_default_cpu()); | |||
| // TODO: memory forward | |||
| // if (input->shape().eq_shape(target_shape)) { | |||
| // return {{{input->layout(), 0, input->comp_node(), | |||
| // StorageIdentifier::make(&inputs_mems[0])}}, {}}; | |||
| // } | |||
| return {{{{target_shape, input->dtype()}, | |||
| 0, | |||
| input->comp_node(), | |||
| StorageIdentifier::make(0)}}, | |||
| {}}; | |||
| } | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| if (outputs[0]->layout().is_empty()) { | |||
| return; | |||
| } | |||
| if (inputs[0]->shape().eq_shape(outputs[0]->shape())) { | |||
| mgb_assert(inputs[0]->layout().eq_layout(outputs[0]->layout())); | |||
| // TODO: memory forward | |||
| // mgb_assert(inputs[0]->offset() == outputs[0]->offset()); | |||
| // mgb_assert(inputs[0]->blob() == outputs[0]->blob()); | |||
| outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor()); | |||
| } else { | |||
| TensorLayout input_layout = inputs[0]->layout().broadcast(outputs[0]->shape()); | |||
| outputs[0]->dev_tensor().copy_from_fixlayout(inputs[0]->dev_tensor().sub( | |||
| SubTensorSpec::make_from_layout(input_layout))); | |||
| } | |||
| } | |||
| OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .execute(execute) | |||
| .fallback(); | |||
| } // namespace broadcast | |||
| @@ -187,41 +147,9 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return {{{TensorLayout(out_shape, src.layout.dtype), src.comp_node}}, true}; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto&& op_def = def.cast_final_safe<Reshape>(); | |||
| size_t nr_inp = inputs.size(); | |||
| mgb_assert(nr_inp == 2, "Reshape expects 2 inputs; got %lu actually", nr_inp); | |||
| auto&& src = inputs[0]; | |||
| auto&& tshp_nd = inputs[1]; | |||
| auto slayout = src->layout(); | |||
| TensorShape tshp; | |||
| cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu()); | |||
| if (op_def.axis != opr::Reshape::Param::INVALID_AXIS) { | |||
| mgb_assert(tshp[op_def.axis] == -1); | |||
| tshp[op_def.axis] = 1; | |||
| tshp[op_def.axis] = src->layout().total_nr_elems() / tshp.total_nr_elems(); | |||
| } | |||
| TensorLayout tlayout = slayout.reshape(tshp); | |||
| // memory forward | |||
| return {{{tlayout, 0, src->comp_node(), StorageIdentifier::make(&inputs_mems[0])}}, | |||
| {}}; | |||
| } | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| mgb_assert(inputs[0]->offset() == outputs[0]->offset()); | |||
| mgb_assert(inputs[0]->blob() == outputs[0]->blob()); | |||
| } | |||
| OP_TRAIT_REG(Reshape, Reshape) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .execute(execute) | |||
| .fallback(); | |||
| } // namespace reshape | |||
| @@ -78,25 +78,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| false}; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| return {{}, {}}; | |||
| } | |||
| void execute( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs, | |||
| const SmallVector<TensorPtr>& workspace) { | |||
| mgb_assert(0); | |||
| } | |||
| OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .execute(execute) | |||
| .fallback(); | |||
| } // namespace | |||
| @@ -234,12 +234,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return op.infer_output_attrs(inputs); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| return {{}, {}}; | |||
| } | |||
| size_t hash(const OpDef& def) { | |||
| auto&& op = static_cast<const CustomOpDef&>(def); | |||
| const custom::Param& param = op.param(); | |||
| @@ -279,7 +273,6 @@ OP_TRAIT_REG(CustomOpDef, CustomOpDef) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .apply_on_device_tensornd(apply_on_device_tensornd) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .hash(hash) | |||
| .is_same_st(is_same_st) | |||
| .props(props) | |||
| @@ -110,35 +110,6 @@ void apply_on_device_tensornd( | |||
| opr::Elemwise::perform(op_def.mode, (*outputs)[0], inputs, dnn_opr); | |||
| } | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| mgb_assert(outputs.size() == 1); | |||
| SmallVector<DeviceTensorND> inp_tensornds(inputs.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| inp_tensornds[i] = inputs[i]->dev_tensor(); | |||
| } | |||
| SmallVector<DeviceTensorND> out_tensornds = {outputs[0]->dev_tensor()}; | |||
| apply_on_device_tensornd(def, inp_tensornds, &out_tensornds); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
| TensorShapeArray inp_shapes(inputs_tensors.size()); | |||
| for (size_t i = 0; i < inputs_tensors.size(); ++i) { | |||
| inp_shapes[i] = inputs_tensors[i]->layout(); | |||
| } | |||
| TensorShape shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes); | |||
| SmallVector<MemoryDesc> outputs = { | |||
| {{shape, inputs_tensors[0]->dtype()}, | |||
| 0, | |||
| inputs_tensors[0]->comp_node(), | |||
| StorageIdentifier::make(1)}}; | |||
| return {outputs, {}}; | |||
| } | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
| @@ -251,7 +222,7 @@ cg::OperatorNodeBase* apply_inplace_add_on_var_node( | |||
| SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| mgb_assert( | |||
| inputs[0]->blob().use_count() == 2 && inputs[0]->blob()->storage().unique(), | |||
| inputs[0]->blob().use_count() == 1 && inputs[0]->blob()->storage().unique(), | |||
| "This inplace modification may change the elements of other tensors. " | |||
| "Please set MEGENGINE_INPLACE_UPDATE to 0 to ensure the program runs " | |||
| "correctly."); | |||
| @@ -265,23 +236,6 @@ SmallVector<TensorPtr> apply_inplace_add_on_physical_tensor( | |||
| return {std::make_shared<Tensor>(dest->blob(), dest->offset(), dest->layout())}; | |||
| } | |||
| void execute_inplace( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| apply_inplace_add_on_physical_tensor(def, inputs); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> | |||
| infer_inplace_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto dest = inputs_tensors[0]; | |||
| SmallVector<MemoryDesc> outputs = { | |||
| {dest->layout(), 0, dest->comp_node(), | |||
| StorageIdentifier::make(&inputs_mems[0])}}; | |||
| return {outputs, {}}; | |||
| } | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_inplace_add_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) { | |||
| mgb_assert(inputs.size() == 4, "invalid input number for inplace_add"); | |||
| @@ -319,16 +273,12 @@ OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .apply_on_device_tensornd(apply_on_device_tensornd) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .execute(execute) | |||
| .fallback(); | |||
| OP_TRAIT_REG(InplaceAdd, InplaceAdd, opr::AddUpdate) | |||
| .apply_on_var_node(apply_inplace_add_on_var_node) | |||
| .apply_on_physical_tensor(apply_inplace_add_on_physical_tensor) | |||
| .infer_output_attrs_fallible(infer_inplace_add_output_attrs_fallible) | |||
| .infer_output_mem_desc(infer_inplace_output_mem_desc) | |||
| .execute(execute_inplace) | |||
| .fallback(); | |||
| } // anonymous namespace | |||
| @@ -75,16 +75,11 @@ SmallVector<LogicalTensorDesc> infer_output_attrs( | |||
| dests[size].layout = TensorLayout(TensorShape({1}), dtype::Int32()); | |||
| return dests; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| return {{}, {}}; | |||
| } | |||
| OP_TRAIT_REG(CheckNonFinite, CheckNonFinite) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .fallback(); | |||
| } // namespace check_non_finite | |||
| @@ -36,6 +36,7 @@ std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| return Reduce::make(node->param()); | |||
| } | |||
| // TODO: use this for apply_on_physical_tensor | |||
| bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) { | |||
| auto&& reduce = static_cast<const Reduce&>(def); | |||
| if (reduce.mode != Reduce::Mode::SUM_SQR && inputs.size() == 2) { | |||
| @@ -49,31 +50,9 @@ bool memory_forward_success(const OpDef& def, SmallVector<TensorPtr> inputs) { | |||
| return false; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| if (memory_forward_success(def, inputs_tensors)) { | |||
| auto& src_desc = inputs_mems[0]; | |||
| return {{{src_desc.layout, 0, src_desc.cn, StorageIdentifier::make(&src_desc)}}, | |||
| {}}; | |||
| } | |||
| return proxy_graph_detail::infer_output_mem_desc(def, inputs_tensors, inputs_mems); | |||
| } | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| if (memory_forward_success(def, inputs)) { | |||
| return; | |||
| } | |||
| return proxy_graph_detail::execute(def, inputs, outputs, workspace); | |||
| } | |||
| OP_TRAIT_REG(Reduce, Reduce, opr::Reduce) | |||
| .make_from_op_node(make_from_op_node) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .execute(execute) | |||
| .fallback(); | |||
| } // namespace reduce | |||
| } // namespace | |||
| @@ -517,20 +517,6 @@ SmallVector<LogicalTensorDesc> infer_output_attrs<Dropout>( | |||
| return dests; | |||
| } | |||
| template <typename Op> | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto&& dests = infer_output_attrs<Op>(def, inputs_tensors); | |||
| SmallVector<MemoryDesc> outputs; | |||
| for (size_t i = 0; i < dests.size(); ++i) { | |||
| outputs.push_back( | |||
| {dests[i].layout, 0, dests[i].comp_node, | |||
| StorageIdentifier::make(i + 1)}); | |||
| } | |||
| return {outputs, {}}; | |||
| } | |||
| template <typename Op> | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| @@ -543,13 +529,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return outputs; | |||
| } | |||
| template <typename Op> | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| exec<Op>(def, inputs, outputs, {}); | |||
| } | |||
| template <typename Op, typename Output> | |||
| Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| size_t nr_inp = inputs.size(); | |||
| @@ -641,8 +620,6 @@ CompNode get_rng_handle_compnode(Handle handle) { | |||
| .apply_on_var_node(apply_on_var_node<NAME, Output>) \ | |||
| .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \ | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \ | |||
| .infer_output_mem_desc(infer_output_mem_desc<NAME>) \ | |||
| .execute(execute<NAME>) \ | |||
| .fallback(); \ | |||
| } | |||
| @@ -141,39 +141,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| HostTensorND tensor = get_var_shape_host_tensor(def, inputs); | |||
| SmallVector<MemoryDesc> ret; | |||
| auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor); | |||
| if (blob) { | |||
| ret.push_back( | |||
| {tensor.layout(), 0, inputs[0]->comp_node(), | |||
| StorageIdentifier::make(Tensor::make( | |||
| std::forward<decltype(blob)>(blob), tensor.layout(), | |||
| tensor))}); | |||
| } else { | |||
| ret.push_back( | |||
| {tensor.layout(), 0, inputs[0]->comp_node(), | |||
| StorageIdentifier::make(1)}); | |||
| } | |||
| return {ret, {}}; | |||
| } | |||
| void execute( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs, | |||
| const SmallVector<TensorPtr>& workspace) { | |||
| HostTensorND tensor = get_var_shape_host_tensor(def, inputs); | |||
| SmallVector<MemoryDesc> ret; | |||
| auto&& blob = MultiCNConstTensorCache::inst().lookup(tensor); | |||
| if (!blob || blob->storage() != outputs[0]->blob()->storage()) { | |||
| outputs[0]->dev_tensor().copy_from_fixlayout(tensor); | |||
| AsyncReleaser::inst()->add(tensor); | |||
| } | |||
| } | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| auto* node = &node_->cast_final_safe<opr::GetVarShape>(); | |||
| return GetVarShape::make(node->param()); | |||
| @@ -186,8 +153,6 @@ OP_TRAIT_REG(GetVarShape, GetVarShape, opr::GetVarShape) | |||
| .apply_on_var_node(apply_on_var_node) | |||
| .apply_on_device_tensornd(apply_on_device_tensornd) | |||
| .apply_on_physical_tensor(apply_on_physical_tensor) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .execute(execute) | |||
| .fallback(); | |||
| } // namespace get_var_shape | |||
| @@ -215,38 +180,6 @@ cg::OperatorNodeBase* param_pack_split_apply_on_var_node( | |||
| return opr; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> | |||
| param_pack_split_infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto&& param = def.cast_final_safe<ParamPackSplit>(); | |||
| mgb_assert( | |||
| inputs.size() == 1, "ParamPackSplit takes 1 input, got %lu", inputs.size()); | |||
| auto&& inp = inputs[0]; | |||
| auto&& shp = inp->layout(); | |||
| mgb_assert(shp.ndim == 1, "ParamPackSplit input shape invalid, ndim should be 1"); | |||
| mgb_assert(param.shapes.size() * 2 == param.offsets.size()); | |||
| SmallVector<MemoryDesc> ret; | |||
| auto&& shapes = get_shapes(param.shapes); | |||
| size_t dtype_size = inputs[0]->layout().dtype.size(); | |||
| for (size_t i = 0; i < shapes.size(); ++i) { | |||
| // memory forward | |||
| ret.push_back( | |||
| {{shapes[i], inputs[0]->dtype()}, | |||
| param.offsets[i * 2] * dtype_size, | |||
| inp->comp_node(), | |||
| StorageIdentifier::make(&inputs_mems[0])}); | |||
| } | |||
| return {ret, {}}; | |||
| } | |||
| void param_pack_split_execute( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs, | |||
| const SmallVector<TensorPtr>& workspace) { | |||
| // do nothing | |||
| } | |||
| SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| auto&& param = def.cast_final_safe<ParamPackSplit>(); | |||
| @@ -268,8 +201,6 @@ SmallVector<TensorPtr> param_pack_split_apply_on_physical_tensor( | |||
| OP_TRAIT_REG(ParamPackSplit, ParamPackSplit, mgb::opr::ParamPackSplit) | |||
| .apply_on_var_node(param_pack_split_apply_on_var_node) | |||
| .infer_output_mem_desc(param_pack_split_infer_output_mem_desc) | |||
| .execute(param_pack_split_execute) | |||
| .apply_on_physical_tensor(param_pack_split_apply_on_physical_tensor) | |||
| .fallback(); | |||
| @@ -286,75 +217,6 @@ cg::OperatorNodeBase* param_pack_concat_apply_on_var_node( | |||
| return opr; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> | |||
| param_pack_concat_infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| def.cast_final_safe<ParamPackConcat>(); | |||
| mgb_assert(inputs.size() > 1, "param_pack should have at least one input"); | |||
| auto comp_node = inputs.front()->comp_node(); | |||
| auto dtype = inputs.front()->dtype(); | |||
| size_t nr_inputs = inputs.size() - 1; | |||
| size_t nr_elems = 0; | |||
| for (size_t i = 0; i < nr_inputs; ++i) { | |||
| auto& input = inputs[i]; | |||
| mgb_assert( | |||
| comp_node == input->comp_node(), | |||
| "inputs for param_pack_concat must in same comp_node"); | |||
| mgb_assert( | |||
| dtype == input->dtype(), | |||
| "inputs for param_pack_concat must have same dtype"); | |||
| nr_elems += input->layout().total_nr_elems(); | |||
| } | |||
| auto dest_layout = TensorLayout({nr_elems}, dtype); | |||
| auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node); | |||
| size_t ws_size; | |||
| { | |||
| TensorShapeArray src_shapes; | |||
| for (size_t i = 0; i < nr_inputs; ++i) { | |||
| src_shapes.push_back(inputs[i]->shape()); | |||
| } | |||
| ws_size = caller.op->get_workspace_in_bytes( | |||
| src_shapes, inputs.back()->shape(), TensorShape{}); | |||
| } | |||
| SmallVector<MemoryDesc> outputs = { | |||
| {dest_layout, 0, comp_node, StorageIdentifier::make(1)}}; | |||
| MemoryDesc workspace = { | |||
| {{ws_size}, dtype::Byte()}, 0, comp_node, StorageIdentifier::make(2)}; | |||
| return {outputs, {workspace}}; | |||
| } | |||
| void param_pack_concat_execute( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs, | |||
| const SmallVector<TensorPtr>& workspace) { | |||
| def.cast_final_safe<ParamPackConcat>(); | |||
| mgb_assert(inputs.size() > 1, "param_pack should have at least one input"); | |||
| auto comp_node = inputs.front()->comp_node(); | |||
| size_t nr_inputs = inputs.size() - 1; | |||
| auto caller = DnnOprCaller<megdnn::ParamPackConcat>(comp_node); | |||
| size_t srcs_size = sizeof(void*) * nr_inputs; | |||
| void** srcs_raw_ptr = (void**)comp_node.alloc_host(srcs_size); | |||
| std::shared_ptr<dt_byte> srcs_ptr = { | |||
| (dt_byte*)srcs_raw_ptr, | |||
| [comp_node](dt_byte* ptr) { comp_node.free_host(ptr); }}; | |||
| TensorLayout srcs_layout = TensorLayout{{nr_inputs}, dtype::Int32()}; | |||
| for (size_t i = 0; i < nr_inputs; ++i) { | |||
| srcs_raw_ptr[i] = inputs[i]->dev_tensor().as_megdnn().raw_ptr(); | |||
| } | |||
| HostTensorStorage srcs_storage; | |||
| srcs_storage.reset(comp_node, srcs_size, srcs_ptr); | |||
| megdnn::Workspace dnn_wk( | |||
| workspace[0]->blob()->storage().get(), workspace[0]->blob()->size()); | |||
| caller.op->exec( | |||
| {srcs_raw_ptr, srcs_layout}, inputs.back()->dev_tensor().as_megdnn(), | |||
| outputs[0]->dev_tensor().as_megdnn(), dnn_wk); | |||
| AsyncReleaser::inst()->add( | |||
| HostTensorND{comp_node, srcs_layout}.storage(srcs_storage)); | |||
| } | |||
| SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| def.cast_final_safe<ParamPackConcat>(); | |||
| @@ -407,8 +269,6 @@ SmallVector<TensorPtr> param_pack_concat_apply_on_physical_tensor( | |||
| OP_TRAIT_REG(ParamPackConcat, ParamPackConcat, mgb::opr::ParamPackConcat) | |||
| .apply_on_var_node(param_pack_concat_apply_on_var_node) | |||
| .infer_output_mem_desc(param_pack_concat_infer_output_mem_desc) | |||
| .execute(param_pack_concat_execute) | |||
| .apply_on_physical_tensor(param_pack_concat_apply_on_physical_tensor) | |||
| .fallback(); | |||
| } // namespace param_pack | |||
| @@ -445,12 +445,6 @@ auto make_name(const OpDef& def) { | |||
| return ssprintf("CompiledOp[%s]", op.op->make_name().c_str()); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| return {}; | |||
| } | |||
| EncodedSubgraph make_backward_graph( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -498,7 +492,6 @@ OP_TRAIT_REG(CompiledOp, CompiledOp) | |||
| .infer_output_attrs_fallible(infer_output_attrs_fallible) | |||
| .make_backward_graph(make_backward_graph) | |||
| .make_name(make_name) | |||
| .infer_output_mem_desc(infer_output_mem_desc) | |||
| .props(props) | |||
| .hash(hash) | |||
| .is_same_st(is_same_st) | |||
| @@ -634,36 +634,6 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph:: | |||
| mgb_assert(0); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> ProxyGraph:: | |||
| infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<Tensor*>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto opr = get_proxy_opr(def, inputs_tensors); | |||
| CUR_OPR_GUARD(opr); | |||
| ::mgb::opr::intl::WorkspaceLimitHook::set_impl( | |||
| m_graph.get(), ProxyGraph::get_workspace_limit); | |||
| do_shape_infer(true); | |||
| SmallVector<MemoryDesc> outputs; | |||
| SmallVector<MemoryDesc> workspaces; | |||
| size_t cur_id = 0; | |||
| for (auto&& i : opr->output()) { | |||
| if (i->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) { | |||
| workspaces.push_back( | |||
| {{i->shape(), i->dtype(), i->format()}, | |||
| 0, | |||
| i->comp_node(), | |||
| StorageIdentifier::make(++cur_id)}); | |||
| } else { | |||
| outputs.push_back( | |||
| {{i->shape(), i->dtype()}, | |||
| 0, | |||
| i->comp_node(), | |||
| StorageIdentifier::make(++cur_id)}); | |||
| } | |||
| } | |||
| return {outputs, workspaces}; | |||
| } | |||
| struct ProxyGraph::GradGraph { | |||
| cg::VarNodeArray inputs; | |||
| cg::VarNodeArray outputs; | |||
| @@ -812,7 +782,6 @@ EncodedSubgraph ProxyGraph::make_backward_graph( | |||
| return result; | |||
| } | |||
| VarNodeArray ProxyGraph::make_input_place_holders( | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| VarNodeArray vinputs(inputs.size()); | |||
| @@ -47,10 +47,6 @@ public: | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad); | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<Tensor*>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems); | |||
| /********************** Logical Tensor API **********************/ | |||
| size_t get_opr_output_size( | |||
| @@ -83,25 +83,6 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return outputs; | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| auto&& graph = ProxyGraph::get_default_graph(); | |||
| return graph->infer_output_mem_desc( | |||
| def, to_raw_ptr_array(inputs_tensors), inputs_mems); | |||
| } | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace) { | |||
| exec(def, inputs, outputs, workspace); | |||
| auto async_error = ProxyGraph::get_async_error(); | |||
| if (async_error) { | |||
| throw *async_error; | |||
| } | |||
| return; | |||
| } | |||
| // std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const | |||
| // OpDef& def, | |||
| // const SmallVector<LogicalTensorDesc>& inputs) { | |||
| @@ -162,12 +162,6 @@ EncodedSubgraph make_backward_graph( | |||
| inputs, input_requires_grad, output_has_grad, forward_graph); | |||
| } | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems) { | |||
| return {{}, {}}; | |||
| } | |||
| } // namespace subgraph_detail | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -53,10 +53,6 @@ public: | |||
| static SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs); | |||
| static void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, | |||
| SmallVector<TensorPtr> outputs, SmallVector<TensorPtr> workspace); | |||
| /*! | |||
| * \brief Call the corresponding dnn op to calculate results. Output | |||
| * tensors' device memory should be allocated outside. | |||
| @@ -71,11 +67,6 @@ public: | |||
| static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs); | |||
| static std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> | |||
| infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems); | |||
| static EncodedSubgraph make_backward_graph( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs, | |||
| const SmallVector<bool>& input_requires_grad, | |||
| @@ -288,36 +288,6 @@ struct LogicalTensorDesc { | |||
| CompNode comp_node; | |||
| DeviceTensorND value; // cpu:default | |||
| }; | |||
| struct StorageIdentifier; | |||
| struct MemoryDesc { | |||
| TensorLayout layout; | |||
| size_t offset; | |||
| CompNode cn; | |||
| std::shared_ptr<StorageIdentifier> id; | |||
| }; | |||
| struct StorageIdentifier { | |||
| enum { INVALID, SYS_ALLOC, FROM_OTHER, DEVICE_PTR } tag; | |||
| union { | |||
| size_t id; | |||
| MemoryDesc* desc; | |||
| }; | |||
| TensorPtr ptr; | |||
| StorageIdentifier() = default; | |||
| StorageIdentifier(size_t id) : tag(SYS_ALLOC), id(id) {} | |||
| StorageIdentifier(const MemoryDesc* desc) : tag(FROM_OTHER), desc(desc->id->desc) {} | |||
| StorageIdentifier(TensorPtr dev_ptr) : tag(DEVICE_PTR), ptr(dev_ptr) {} | |||
| template <typename... Args> | |||
| static std::shared_ptr<StorageIdentifier> make(Args&&... args) { | |||
| return std::make_shared<StorageIdentifier>(std::forward<Args>(args)...); | |||
| } | |||
| bool is_sys_alloc() { return tag == SYS_ALLOC; } | |||
| bool is_from_other() { return tag == FROM_OTHER; } | |||
| bool is_device_ptr() { return tag == DEVICE_PTR; } | |||
| bool is_invalid() { return tag == INVALID; } | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -20,17 +20,9 @@ namespace proxy_graph_detail { | |||
| SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs); | |||
| void execute( | |||
| const OpDef& def, SmallVector<TensorPtr> inputs, SmallVector<TensorPtr> outputs, | |||
| SmallVector<TensorPtr> workspace); | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs); | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems); | |||
| void exec( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs, | |||
| const SmallVector<TensorPtr>& outputs); | |||
| @@ -35,10 +35,6 @@ EncodedSubgraph make_backward_graph( | |||
| const SmallVector<bool>& input_requires_grad, | |||
| const SmallVector<bool>& output_has_grad); | |||
| std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs_tensors, | |||
| const SmallVector<MemoryDesc>& inputs_mems); | |||
| } // namespace subgraph_detail | |||
| } // namespace imperative | |||
| } // namespace mgb | |||