GitOrigin-RevId: 6e23456250
tags/v1.6.0-rc1
@@ -227,19 +227,19 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         gopt_level = None  # disable jit and compile

     binary_ops = {
-        "+": builtin.Elemwise(mode="add"),
-        "-": builtin.Elemwise(mode="sub"),
-        "*": builtin.Elemwise(mode="mul"),
-        "/": builtin.Elemwise(mode="true_div"),
-        "//": builtin.Elemwise(mode="floor_div"),
-        "**": builtin.Elemwise(mode="pow"),
-        "√": builtin.Elemwise(mode="expm1"),
-        "max": builtin.Elemwise(mode="max"),
-        "additive": builtin.Elemwise(mode="add"),
+        "+": lambda: builtin.Elemwise(mode="add"),
+        "-": lambda: builtin.Elemwise(mode="sub"),
+        "*": lambda: builtin.Elemwise(mode="mul"),
+        "/": lambda: builtin.Elemwise(mode="true_div"),
+        "//": lambda: builtin.Elemwise(mode="floor_div"),
+        "**": lambda: builtin.Elemwise(mode="pow"),
+        "√": lambda: builtin.Elemwise(mode="expm1"),
+        "max": lambda: builtin.Elemwise(mode="max"),
+        "additive": lambda: builtin.Elemwise(mode="add"),
     }
     unary_ops = {
-        "-": builtin.Elemwise(mode="negate"),
+        "-": lambda: builtin.Elemwise(mode="negate"),
     }

     def decorator(func):
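One reading of this hunk: the operator tables now hold zero-argument factories rather than pre-built builtin.Elemwise instances, so an op object is only constructed at the moment an expression needs it, and every lookup yields a fresh instance instead of a shared module-level one. A minimal, self-contained sketch of that pattern (the Elemwise stand-in below is hypothetical, not the real MegEngine class):

    # Hypothetical stand-in for builtin.Elemwise, for illustration only.
    class Elemwise:
        def __init__(self, mode):
            self.mode = mode

    # Each table entry is a factory; nothing is constructed until it is called.
    binary_ops = {
        "+": lambda: Elemwise(mode="add"),
        "*": lambda: Elemwise(mode="mul"),
    }

    add1 = binary_ops["+"]()
    add2 = binary_ops["+"]()
    assert add1 is not add2  # every call site gets its own op object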
@@ -248,9 +248,9 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         def apply_expr(op, *args):
             if isinstance(op, str):
                 if len(args) == 2:
-                    op = binary_ops[op]
+                    op = binary_ops[op]()
                 elif len(args) == 1:
-                    op = unary_ops[op]
+                    op = unary_ops[op]()
             return builder.apply(op, args, 1)[0]

         def apply_const(value, dtype=dtype, device=device):
@@ -261,8 +261,8 @@ def subgraph(name, dtype, device, nr_inputs, gopt_level=None):
         builder.outputs(outputs)
         builder.outputs_has_grad(outputs_has_grad)
         if gopt_level is None:
-            return builder.get()
+            return lambda: builder.get()
         else:
-            return builder.compile(gopt_level)
+            return lambda: builder.compile(gopt_level)

     return decorator
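With the second hunk above, a subgraph-decorated function no longer returns a built OpDef; it returns a thunk (lambda: builder.get() or lambda: builder.compile(gopt_level)). The hunks further down then invoke these cached factories with () at each use site, which appears to let the expensive graph construction stay cached while the applied op is re-materialized per call. A rough, hypothetical sketch of that calling convention:

    from functools import lru_cache

    @lru_cache(maxsize=None)
    def _get_demo_op(mode):
        # The cached value is a factory rather than the op itself,
        # so callers rebuild the op at every use site.
        return lambda: ("DemoOp", mode)

    op_factory = _get_demo_op("add")  # expensive part, cached across calls
    op = op_factory()                 # cheap part, fresh object per call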
@@ -767,6 +767,19 @@ def matinv(inp: Tensor) -> Tensor:
     return result


+class _Hashable:
+    def __init__(self, value) -> None:
+        self.value = value
+
+    def __hash__(self) -> int:
+        return hash(str(self.value))
+
+    def __eq__(self, o: object) -> bool:
+        if not isinstance(o, _Hashable):
+            return False
+        return self.value == o.value
+
+
 @lru_cache(maxsize=None)
 def _get_extentedMatrixMulOp(
     device, dtype, dim1, dim2, transpose_a, transpose_b, compute_mode, format, strategy,
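The _Hashable wrapper above appears to exist so that arbitrary (possibly unhashable or unstably-hashed) values can be passed through functools.lru_cache, which requires every argument to be hashable and to compare equal across calls; hashing str(value) gives a stable key, and the later hunks unwrap the payload with .value inside the cached factory. A self-contained sketch of the same wrapper pattern (the make_op helper is hypothetical):

    from functools import lru_cache

    class _Hashable:
        def __init__(self, value) -> None:
            self.value = value

        def __hash__(self) -> int:
            return hash(str(self.value))

        def __eq__(self, o: object) -> bool:
            return isinstance(o, _Hashable) and self.value == o.value

    @lru_cache(maxsize=None)
    def make_op(strategy):
        # The raw payload is recovered with .value inside the cached function,
        # mirroring the strategy=strategy.value change in the hunks below.
        return ("MatrixMul", str(strategy.value))

    # A list is not hashable on its own, but the wrapper makes it cacheable.
    assert make_op(_Hashable([1, 2])) is make_op(_Hashable([1, 2]))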
@@ -833,7 +846,7 @@ def _get_extentedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
         result_shape = f(GetVarShape(), result)
@@ -954,7 +967,7 @@ def _get_extentedBatchedMatrixMulOp(
             transposeB=transpose_b,
             compute_mode=compute_mode,
             format=format,
-            strategy=strategy,
+            strategy=strategy.value,
         )
         result = f(op, inp1, inp2)
@@ -1051,9 +1064,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedMatrixMulOp(), inp1, inp2)
         return result
     else:  # dispath to BatchedMatrixMul
         extentedBatchedMatrixMulOp = _get_extentedBatchedMatrixMulOp(
@@ -1065,9 +1078,9 @@ def matmul(
             transpose_b,
             compute_mode,
             format,
-            strategy=get_execution_strategy(),
+            strategy=_Hashable(get_execution_strategy()),
         )
-        (result,) = apply(extentedBatchedMatrixMulOp, inp1, inp2)
+        (result,) = apply(extentedBatchedMatrixMulOp(), inp1, inp2)
         return result
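In matmul, the execution strategy is wrapped in _Hashable before being passed to the lru_cache-decorated factory (and unwrapped with .value inside it), and the cached factory is now called, extentedMatrixMulOp(), so the subgraph op is freshly instantiated for each apply. A tiny illustration of the wrap/unwrap round trip, assuming the _Hashable class from the hunk above and a stand-in strategy value:

    strategy = "HEURISTIC"            # stand-in for get_execution_strategy()
    wrapped = _Hashable(strategy)     # hashable, so usable as an lru_cache key
    assert wrapped.value == strategy  # .value recovers the original object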
@@ -1328,7 +1328,7 @@ def sync_batch_norm(
         syncbn_split_stats,
     ) = _get_sync_bn_ops(_device, _dtype, eps_mode, _ndim, _channels)

-    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0, inp)
+    reduce_shape, reduce_size, channel_x1s, channel_x2s = apply(syncbn_stage0(), inp)

     eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device)
@@ -1338,19 +1338,28 @@ def sync_batch_norm(
     if training:
         if is_distributed():
             # reduce all nodes' data to calculate mean and variance
-            (stat,) = apply(syncbn_concat_stats, reduce_size, channel_x1s, channel_x2s)
+            (stat,) = apply(
+                syncbn_concat_stats(), reduce_size, channel_x1s, channel_x2s
+            )
             stat = all_reduce_sum(stat, group)
-            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats, stat)
+            reduce_size, channel_x1s, channel_x2s = apply(syncbn_split_stats(), stat)

         outvar, channel_mean, *_ = apply(
-            syncbn_stage1, inp, reduce_size, channel_x1s, channel_x2s, eps, weight, bias
+            syncbn_stage1(),
+            inp,
+            reduce_size,
+            channel_x1s,
+            channel_x2s,
+            eps,
+            weight,
+            bias,
         )
     else:
         assert running_var is not None and running_mean is not None
         channel_mean = running_mean
         channel_var = running_var
         outvar, *_ = apply(
-            syncbn_stage1_inference, inp, channel_mean, channel_var, eps, weight, bias
+            syncbn_stage1_inference(), inp, channel_mean, channel_var, eps, weight, bias
         )

     # outvar = output * weight + bias
@@ -1362,7 +1371,7 @@ def sync_batch_norm(
     if training and running_var is not None and running_mean is not None:
         momentum = convert_single_value(momentum, dtype=inp.dtype, device=inp.device)
         running_mean[...], running_var[...] = apply(
-            syncbn_stage2,
+            syncbn_stage2(),
             running_mean,
             running_var,
             momentum,
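The sync_batch_norm hunks follow the same convention: the ops returned by _get_sync_bn_ops are now factories, so each stage (syncbn_stage0(), syncbn_concat_stats(), syncbn_stage1(), syncbn_stage2(), and so on) is instantiated at the point it is applied instead of being a single cached OpDef instance shared across calls.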
@@ -482,9 +482,15 @@ void init_ops(py::module m) {
     struct PySubgraphBuilder {
         explicit PySubgraphBuilder(std::string name) : name{name}{}
         std::string name;
-        Subgraph graph;
+        std::shared_ptr<Subgraph> graph_storage = std::make_shared<Subgraph>();
+        std::shared_ptr<UniqueKey> graph_key = std::make_shared<UniqueKey>();
+        Subgraph& graph = *graph_storage;
         mgb::SmallVector<bool> output_grad_mask;
         Subgraph::var_t next_var = 1;
+
+        std::shared_ptr<OpDef> build() const {
+            return SubgraphOp::make(name, graph_storage, output_grad_mask, graph_key);
+        }
     };

     py::class_<PySubgraphBuilder>(m, "SubgraphBuilder")
@@ -518,10 +524,9 @@ void init_ops(py::module m) {
             self.output_grad_mask = outputs_has_grad;
         })
         .def("get", [](PySubgraphBuilder& self){
-            return (std::shared_ptr<OpDef>)SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
+            return (std::shared_ptr<OpDef>)self.build();
         })
         .def("compile", [](PySubgraphBuilder& self, int gopt_level){
-            auto op = SubgraphOp::make(self.name, self.graph, self.output_grad_mask);
-            return (std::shared_ptr<OpDef>)CompiledOp::make(op, gopt_level);
+            return (std::shared_ptr<OpDef>)CompiledOp::make(self.build(), gopt_level);
         });
 }
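On the binding side, PySubgraphBuilder now keeps its Subgraph in a shared_ptr together with a per-builder UniqueKey, and both get and compile funnel through build(). The practical effect suggested by the later hash/is_same_st hunks is that every OpDef produced from the same builder refers to the same underlying graph storage and hashes and compares through the same key, so repeated get()/compile() calls yield ops that caches can treat as identical.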
@@ -181,7 +181,7 @@ OP_TRAIT_REG(Identity, Identity)
 namespace { namespace subgraph {

 EncodedSubraph make_forward_graph(const OpDef& def, SmallVector<LogicalTensorDesc> inputs) {
-    return EncodedSubraph::make(def.cast_final_safe<SubgraphOp>().graph);
+    return EncodedSubraph::make(*def.cast_final_safe<SubgraphOp>().graph);
 }

 EncodedSubraph make_backward_graph(
@@ -197,16 +197,19 @@ EncodedSubraph make_backward_graph(
         }
     }
     auto bgraph = subgraph_detail::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
-    return EncodedSubraph::make_single(SubgraphOp::make(op.name+"Grad", bgraph.graph), bgraph.input_mask, bgraph.output_mask);
+    return EncodedSubraph::make_single(
+            SubgraphOp::make(op.name + "Grad",
+                             std::make_shared<Subgraph>(bgraph.graph)),
+            bgraph.input_mask, bgraph.output_mask);
 }

 std::vector<std::pair<const char*, std::string>> props(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     return {
         {"name", op.name},
-        {"inputs", mgb::imperative::to_string(op.graph.inputs)},
-        {"exprs", mgb::imperative::to_string(op.graph.exprs)},
-        {"outputs", mgb::imperative::to_string(op.graph.outputs)},
+        {"inputs", mgb::imperative::to_string(op.graph->inputs)},
+        {"exprs", mgb::imperative::to_string(op.graph->exprs)},
+        {"outputs", mgb::imperative::to_string(op.graph->outputs)},
     };
 }
@@ -222,7 +225,7 @@ std::string make_name(const OpDef& def) {
 auto hash(const OpDef& def) {
     auto& op = def.cast_final_safe<SubgraphOp>();
     if (!op.graph_key) {
-        return (size_t)reinterpret_cast<uintptr_t>(&op.graph);
+        return (size_t)reinterpret_cast<uintptr_t>(op.graph.get());
     }
     return op.graph_key->hash();
 }
@@ -238,7 +241,7 @@ auto is_same_st(const OpDef& def, const OpDef& another) {
     if (has_graph_key) {
         graph_same = rhs.graph_key && lhs.graph_key->is_same(*rhs.graph_key);
     } else {
-        graph_same = !rhs.graph_key && &lhs.graph == &rhs.graph;
+        graph_same = !rhs.graph_key && lhs.graph.get() == rhs.graph.get();
     }
     return graph_same;
 }
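Taken together, the hash and is_same_st hunks preserve the previous identity semantics under the new representation: a SubgraphOp with a graph_key hashes and compares through that key, while one without a key falls back to pointer identity of the shared Subgraph, now obtained via graph.get() instead of the address of an embedded member.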
@@ -354,7 +357,9 @@ auto apply_on_physical_tensor(
 auto apply_on_var_node(
         const OpDef& def,
         const VarNodeArray& inputs) {
-    return OpDef::apply_on_var_node(*def.cast_final_safe<CompiledOp>().op, inputs);
+    auto& op = def.cast_final_safe<CompiledOp>();
+    op.op->set_scope(op.scope());
+    return OpDef::apply_on_var_node(*op.op, inputs);
 }

 auto infer_output_attrs_fallible(
@@ -397,7 +402,9 @@ EncodedSubraph make_backward_graph(
     if (backward_graph.graph.is_single()) {
         bgraph_op = backward_graph.graph.as_single();
     } else {
-        bgraph_op = SubgraphOp::make(name+"Grad", backward_graph.graph, grad_outputs_has_grad, key);
+        bgraph_op = SubgraphOp::make(
+                name + "Grad", std::make_shared<Subgraph>(backward_graph.graph),
+                grad_outputs_has_grad, key);
     }
     auto compiled_op = CompiledOp::make(bgraph_op, op.gopt_level);
     auto encoded_graph = EncodedSubraph::make_single(compiled_op, backward_graph.input_mask, backward_graph.output_mask);
@@ -431,6 +438,8 @@ OP_TRAIT_REG(CompiledOp, CompiledOp)
     .fallback();
 }}

+MGB_DYN_TYPE_OBJ_FINAL_IMPL(UniqueKey);
+
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(SubgraphOp);

 MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardOpKey);
@@ -28,7 +28,8 @@ VarNodeArray apply_on_var_node(
     for (auto&& input: inputs) {
         input_descs.push_back({TensorLayout{input->dtype()}, input->comp_node()});
     }
-    auto apply_functor = [](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+    auto apply_functor = [&](const std::shared_ptr<OpDef>& op, const VarNodeArray& inputs, size_t nr_outputs){
+        op->set_scope(def.scope());
         return OpDef::apply_on_var_node(*op, inputs);
     };
     auto const_functor = [&](const TensorPtr& value) {
@@ -48,16 +48,28 @@ struct ShapeInfer final : OpDefImplBase<ShapeInfer> {
     MGB_DYN_TYPE_OBJ_FINAL_DECL;
 };

+struct UniqueKey final: Hashable {
+public:
+    size_t hash() const override {
+        return reinterpret_cast<uintptr_t>(this);
+    }
+protected:
+    bool is_same_st(const Hashable& rhs) const override {
+        return this == &rhs.cast_final_safe<UniqueKey>();
+    }
+    MGB_DYN_TYPE_OBJ_FINAL_DECL;
+};
+
 struct SubgraphOp final: OpDefImplBase<SubgraphOp> {
     std::string name;
-    Subgraph graph;
+    std::shared_ptr<Subgraph> graph;
     SmallVector<bool> output_grad_mask;
     std::shared_ptr<Hashable> graph_key;
     SubgraphOp() = default;
-    SubgraphOp(std::string name, Subgraph graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
+    SubgraphOp(std::string name, std::shared_ptr<Subgraph> graph, SmallVector<bool> output_grad_mask={}, std::shared_ptr<Hashable> key=nullptr)
         : name{name}, graph{graph}, output_grad_mask{output_grad_mask}, graph_key{std::move(key)}{
         if (this->output_grad_mask.empty()) {
-            this->output_grad_mask.resize(graph.outputs.size(), true);
+            this->output_grad_mask.resize(graph->outputs.size(), true);
         }
     }
     MGB_DYN_TYPE_OBJ_FINAL_DECL;