GitOrigin-RevId: b2b32774ee
tags/v1.1.0
| @@ -63,15 +63,17 @@ SmallVector<void*> ChannelImpl::apply_op( | |||
| input_infos.push_back(info); | |||
| input_descs.push_back(info->desc); | |||
| } | |||
| auto output_descs = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
| auto [output_descs, validated] = OpDef::infer_output_attrs_fallible(*op, input_descs); | |||
| ApplyOp cmd{std::move(op)}; | |||
| cmd.inputs = std::move(input_infos); | |||
| cmd.outputs.reserve(output_descs.size()); | |||
| SmallVector<void*> outputs; | |||
| bool is_fallible = false; | |||
| // FIXME: remove this check when op check is correct | |||
| bool validated_bkp = true; | |||
| for (auto&& desc : output_descs) { | |||
| if (desc.layout.ndim == 0) { | |||
| is_fallible = true; | |||
| validated_bkp = false; | |||
| } | |||
| auto info = alloc(); | |||
| info->desc = desc; | |||
| @@ -80,8 +82,14 @@ SmallVector<void*> ChannelImpl::apply_op( | |||
| outputs.push_back(info); | |||
| } | |||
| m_worker.add_task(std::move(cmd)); | |||
| if (is_fallible && m_async_level <= 1) { | |||
| if (!(validated && validated_bkp) && m_async_level == 1) { | |||
| sync(); | |||
| } else if (m_async_level == 0) { | |||
| sync(); | |||
| // check device error | |||
| for (auto&& oup : cmd.outputs) { | |||
| oup->ptr->comp_node().sync(); | |||
| } | |||
| } | |||
| return outputs; | |||
| } | |||
| @@ -194,6 +202,9 @@ ChannelImpl::~ChannelImpl() { | |||
| void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| dest->value_fetched = ptr->value_fetched(); | |||
| // update tensor desc for static infer | |||
| dest->desc.layout = ptr->layout(); | |||
| dest->desc.comp_node = ptr->comp_node(); | |||
| dest->ptr = std::move(ptr); | |||
| if (m_waitee == dest) { | |||
| m_cv.notify_all(); | |||
| @@ -42,7 +42,7 @@ cg::OperatorNodeBase* OpDef::apply_on_var_node( | |||
| return def.trait()->apply_on_var_node(def, inputs); | |||
| } | |||
| SmallVector<LogicalTensorDesc> OpDef::infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> OpDef::infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| return def.trait()->infer_output_attrs_fallible(def, inputs); | |||
| @@ -24,12 +24,12 @@ BackwardGraph::InternalGraph::apply( | |||
| inputs); | |||
| } | |||
| SmallVector<LogicalTensorDesc> | |||
| BackwardGraph::InternalGraph::infer_attrs( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs( | |||
| const SmallVector<LogicalTensorDesc>& inputs) const { | |||
| using TensorAttr = LogicalTensorDesc; | |||
| ThinHashMap<size_t, TensorAttr> node2attr; | |||
| auto&& input_nodes = this->inputs; | |||
| auto&& output_nodes = this->outputs; | |||
| mgb_assert(inputs.size() == input_nodes.size()); | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| node2attr[input_nodes[i]] = inputs[i]; | |||
| @@ -41,25 +41,29 @@ BackwardGraph::InternalGraph::infer_attrs( | |||
| i.second->layout(), i.second->comp_node(), | |||
| value->proxy_to_default_cpu()}; | |||
| } | |||
| bool validated = true; | |||
| for (size_t i = 0; i < exprs.size(); ++ i) { | |||
| auto&& expr = exprs[i]; | |||
| SmallVector<TensorAttr> inputs; | |||
| for (auto &&in : std::get<1>(expr)) { | |||
| inputs.push_back(node2attr.at(in)); | |||
| auto&& [expr_op, expr_inps, expr_oups] = exprs[i]; | |||
| SmallVector<TensorAttr> expr_input_descs; | |||
| for (auto &&inp : expr_inps) { | |||
| expr_input_descs.push_back(node2attr.at(inp)); | |||
| } | |||
| auto outputs = OpDef::infer_output_attrs_fallible( | |||
| *std::get<0>(expr), inputs); | |||
| auto output_nodes = std::get<2>(expr); | |||
| mgb_assert(outputs.size() == output_nodes.size()); | |||
| for (size_t i = 0; i < outputs.size(); ++ i) { | |||
| node2attr[output_nodes[i]] = outputs[i]; | |||
| auto[expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible( | |||
| *expr_op, expr_input_descs); | |||
| validated = validated && expr_validated; | |||
| mgb_assert(expr_output_descs.size() == expr_oups.size()); | |||
| for (size_t i = 0; i < expr_output_descs.size(); ++ i) { | |||
| node2attr[expr_oups[i]] = expr_output_descs[i]; | |||
| } | |||
| } | |||
| SmallVector<TensorAttr> ret; | |||
| for (auto &&i : outputs) { | |||
| for (auto &&i : output_nodes) { | |||
| ret.push_back(node2attr.at(i)); | |||
| } | |||
| return ret; | |||
| return {ret, validated}; | |||
| } | |||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph); | |||
| @@ -72,11 +76,11 @@ SmallVector<TensorPtr> backward_impl( | |||
| .graph().apply(tensors); | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_tensor_attrs( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs( | |||
| const OpDef& backward_graph, | |||
| const SmallVector<LogicalTensorDesc> inputs) { | |||
| return backward_graph.cast_final_safe<BackwardGraph>() | |||
| .graph().infer_attrs(inputs); | |||
| .graph().infer_attrs(inputs); | |||
| } | |||
| OP_TRAIT_REG(BackwardGraph, BackwardGraph) | |||
| @@ -44,7 +44,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| } | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& op_def = def.cast_final_safe<BatchNorm>(); | |||
| @@ -66,7 +66,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| out_shapes[i] = {i1.layout, i1.comp_node}; | |||
| } | |||
| out_shapes[nr_out-1] = {i0.layout, i0.comp_node}; | |||
| return out_shapes; | |||
| return {out_shapes, true}; | |||
| } | |||
| OP_TRAIT_REG(BatchNorm, BatchNorm, opr::BatchNorm) | |||
| @@ -47,7 +47,7 @@ bool valid_broadcast(const TensorShape& src_shape, | |||
| return true; | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| def.cast_final_safe<Broadcast>(); | |||
| @@ -59,7 +59,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| TensorLayout out_layout = src.layout; | |||
| if (tshp.layout.ndim == 0 || tshp.value.empty()) { | |||
| out_layout.ndim = 0; | |||
| return {{out_layout, src.comp_node}}; | |||
| return {{{out_layout, src.comp_node}}, true}; | |||
| } | |||
| mgb_assert( | |||
| tshp.layout.ndim == 1, | |||
| @@ -77,7 +77,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| src.layout.TensorShape::to_string().c_str(), | |||
| out_layout.TensorShape::to_string().c_str()); | |||
| return {{out_layout, src.comp_node}}; | |||
| return {{{out_layout, src.comp_node}}, true}; | |||
| } | |||
| OP_TRAIT_REG(Broadcast, Broadcast, opr::Broadcast) | |||
| @@ -25,7 +25,7 @@ namespace { | |||
| class MegDNNDynOutMallocImpl final: public megdnn::DynOutMallocPolicy { | |||
| using Output = std::array<TensorPtr, 2>; | |||
| CompNode m_cn; | |||
| Output m_out; | |||
| @@ -110,14 +110,14 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return out; | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto cn = inputs[0].comp_node; | |||
| return { | |||
| return {{ | |||
| {TensorLayout(inputs[0].layout.dtype), cn}, | |||
| {TensorLayout(dtype::Int32()), cn} | |||
| }; | |||
| }, true}; | |||
| } | |||
| OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | |||
| @@ -128,4 +128,4 @@ OP_TRAIT_REG(CondTake, CondTake, opr::CondTake) | |||
| } // namespace | |||
| } // namespace mgb::imperative | |||
| } // namespace mgb::imperative | |||
| @@ -29,7 +29,7 @@ cg::OperatorNodeBase* apply_on_var_node( | |||
| return opr::Elemwise::make(inputs, elemwise_opr.mode).node()->owner_opr(); | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
| @@ -40,7 +40,7 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| TensorShapeArray inp_shapes; | |||
| DType out_dt; | |||
| CompNode out_cn; | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| for (size_t i = 0; i < inputs.size(); ++ i) { | |||
| auto &&t = inputs[i]; | |||
| if (!i) { | |||
| out_cn = t.comp_node; | |||
| @@ -55,12 +55,12 @@ SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| TensorLayout out_layout; | |||
| out_layout.ndim = 0; | |||
| out_layout.dtype = out_dt; | |||
| return {{out_layout, out_cn}}; | |||
| return {{{out_layout, out_cn}}, true}; | |||
| } | |||
| } | |||
| auto&& out_shape = opr::Elemwise::get_output_var_shape(op_def.mode, inp_shapes); | |||
| return {{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}; | |||
| return {{{TensorLayout(out_shape, out_dt, inputs[0].layout.format), out_cn}}, true}; | |||
| } | |||
| OP_TRAIT_REG(Elemwise, Elemwise, opr::Elemwise) | |||
| @@ -40,21 +40,21 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return {Tensor::make(std::move(hv))}; | |||
| } | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| def.cast_final_safe<GetVarShape>(); | |||
| mgb_assert(inputs.size() == 1, "GetVarShape take 1 input, got %lu", inputs.size()); | |||
| auto&& desc = inputs[0]; | |||
| if (!desc.layout.ndim) { | |||
| return {{TensorLayout(dtype::Int32()), desc.comp_node}}; | |||
| return {{{TensorLayout(dtype::Int32()), desc.comp_node}}, true}; | |||
| } | |||
| DeviceTensorND value(CompNode::default_cpu(), {desc.layout.ndim}, dtype::Int32()); | |||
| auto* ptr = value.ptr<dt_int32>(); | |||
| for (size_t i = 0; i < desc.layout.ndim; ++i) { | |||
| ptr[i] = desc.layout[i]; | |||
| } | |||
| return {{value.layout(), desc.comp_node, std::move(value)}}; | |||
| return {{{value.layout(), desc.comp_node, std::move(value)}}, true}; | |||
| } | |||
| std::shared_ptr<OpDef> make_from_op_node(cg::OperatorNodeBase* node_) { | |||
| @@ -28,12 +28,13 @@ namespace { | |||
| CompNode::UnorderedSet collect_comp_nodes( | |||
| const OpDef& def, const SmallVector<TensorPtr>& inputs) { | |||
| CompNode::UnorderedSet comp_nodes; | |||
| SmallVector<LogicalTensorDesc> descs; | |||
| SmallVector<LogicalTensorDesc> inp_descs; | |||
| for (auto&& i : inputs) { | |||
| comp_nodes.insert(i->comp_node()); | |||
| descs.push_back({i->layout(), i->comp_node(), {}}); | |||
| inp_descs.push_back({i->layout(), i->comp_node(), {}}); | |||
| } | |||
| for (auto&& output_attr : def.infer_output_attrs_fallible(def, descs)) { | |||
| SmallVector<LogicalTensorDesc> oup_descs = std::get<0>(def.infer_output_attrs_fallible(def, inp_descs)); | |||
| for (auto&& output_attr : oup_descs) { | |||
| comp_nodes.insert(output_attr.comp_node); | |||
| } | |||
| return comp_nodes; | |||
| @@ -14,6 +14,7 @@ | |||
| #include "megbrain/graph/static_infer.h" | |||
| #include "megbrain/graph/operator_node.h" | |||
| #include "megbrain/opr/io.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| #include "megbrain/opr/utility.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/imperative/ops/backward_graph.h" | |||
| @@ -590,10 +591,9 @@ cg::OperatorNodeBase* ProxyGraph::get_proxy_opr( | |||
| vinputs[i] = InputPlaceholder::make(*m_graph, *inputs[i]).node(); | |||
| } | |||
| auto opr = OpDef::apply_on_var_node(opdef, vinputs); | |||
| mgb_assert(opr->dyn_typeinfo() != InputPlaceholder::typeinfo()); | |||
| mgb_assert(!opr->same_type<InputPlaceholder>()); | |||
| for (auto &&i : opr->input()) { | |||
| mgb_assert(i->owner_opr()->dyn_typeinfo() == | |||
| InputPlaceholder::typeinfo()); | |||
| mgb_assert(i->owner_opr()->same_type<InputPlaceholder>()); | |||
| } | |||
| return opr; | |||
| } | |||
| @@ -605,17 +605,18 @@ size_t ProxyGraph::get_opr_output_size(const OpDef& opdef, | |||
| return get_proxy_opr(opdef, inputs)->usable_output().size(); | |||
| } | |||
| SmallVector<LogicalTensorDesc> ProxyGraph::infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> ProxyGraph::infer_output_attrs_fallible( | |||
| const OpDef& opdef, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto opr = get_proxy_opr(opdef, inputs); | |||
| CUR_OPR_GUARD(opr); | |||
| do_shape_infer(false); | |||
| SmallVector<LogicalTensorDesc> ret; | |||
| SmallVector<LogicalTensorDesc> outputs; | |||
| bool validated = do_shape_infer(false); | |||
| for (auto&& i : opr->usable_output()) { | |||
| ret.push_back({{i->shape(), i->dtype()}, i->comp_node()}); | |||
| outputs.push_back({{i->shape(), i->dtype()}, i->comp_node()}); | |||
| } | |||
| return ret; | |||
| bool need_check = opr->same_type<opr::Reshape>(); | |||
| return {outputs, validated && !need_check}; | |||
| } | |||
| struct ProxyGraph::GradGraph { | |||
| @@ -811,16 +812,20 @@ VarNodeArray ProxyGraph::make_input_place_holders(const SmallVector<LogicalTenso | |||
| /*********************** Common Impl ***********************/ | |||
| void ProxyGraph::do_shape_infer(bool sync_value) { | |||
| bool ProxyGraph::do_shape_infer(bool sync_value) { | |||
| m_static_infer_manager->update(); | |||
| bool validated = true; | |||
| for (auto* var : m_cur_opr->output()) { | |||
| if (sync_value) { | |||
| var->shape(m_static_infer_manager->infer_shape(var)); | |||
| } else if (auto* shape = m_static_infer_manager->infer_shape_fallible(var)) { | |||
| var->shape(*shape); | |||
| var->shape(*shape); | |||
| } else { | |||
| validated = false; | |||
| } | |||
| } | |||
| return validated; | |||
| } | |||
| TensorPtr ProxyGraph::as_tensor(cg::OperatorNodeBase* opr, bool share) { | |||
| @@ -48,7 +48,7 @@ public: | |||
| const OpDef& opdef, | |||
| const SmallVector<LogicalTensorDesc>& inputs); | |||
| SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& opdef, | |||
| const SmallVector<LogicalTensorDesc>& inputs); | |||
| @@ -88,7 +88,7 @@ private: | |||
| /********************** Common Helper **********************/ | |||
| void do_shape_infer(bool sync_value); | |||
| bool do_shape_infer(bool sync_value); | |||
| TensorPtr as_tensor(cg::OperatorNodeBase* opr, bool share=true); | |||
| @@ -80,8 +80,7 @@ apply_on_physical_tensor(const OpDef& def, | |||
| return outputs; | |||
| } | |||
| SmallVector<LogicalTensorDesc> | |||
| infer_output_attrs_fallible(const OpDef& def, | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs) { | |||
| auto&& graph = ProxyGraph::get_default_graph(); | |||
| return graph->infer_output_attrs_fallible(def, inputs); | |||
| @@ -136,4 +135,4 @@ make_backward_graph(const OpDef& def, | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -21,8 +21,7 @@ SmallVector<TensorPtr> | |||
| apply_on_physical_tensor(const OpDef& def, | |||
| const SmallVector<TensorPtr>& inputs); | |||
| SmallVector<LogicalTensorDesc> | |||
| infer_output_attrs_fallible(const OpDef& def, | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs); | |||
| BackwardGraphResult | |||
| @@ -35,4 +34,4 @@ make_backward_graph(const OpDef& def, | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||
| @@ -44,7 +44,7 @@ public: | |||
| const OpDef& def, | |||
| const VarNodeArray& inputs); | |||
| static SmallVector<LogicalTensorDesc> infer_output_attrs_fallible( | |||
| static std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| const OpDef& def, | |||
| const SmallVector<LogicalTensorDesc>& inputs); | |||
| @@ -38,8 +38,8 @@ public: | |||
| SmallVector<TensorPtr> | |||
| apply(const SmallVector<TensorPtr>& inputs) const; | |||
| SmallVector<LogicalTensorDesc> | |||
| infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const; | |||
| std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs( | |||
| const SmallVector<LogicalTensorDesc>& inputs) const; | |||
| template <typename T, typename F, typename C> | |||
| SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const { | |||