GitOrigin-RevId: 27dde05cff
tags/v1.9.0
| @@ -19,6 +19,7 @@ | |||
| #include "range/v3/all.hpp" | |||
| #include "./helper.h" | |||
| #include "./transformation.h" | |||
| namespace py = pybind11; | |||
| @@ -30,9 +31,7 @@ namespace { | |||
| std::unordered_map<std::shared_ptr<GradKey>, GradKeyWrapper*> grad_key_map; | |||
| } | |||
| GradKeyWrapper::GradKeyWrapper() : m_key(std::make_shared<GradKey>()) { | |||
| grad_key_map[m_key] = this; | |||
| } | |||
| GradKeyWrapper::GradKeyWrapper() {} | |||
| void GradKeyWrapper::attach(PyObject* const* args, size_t nargs) { | |||
| if (nargs != 2) { | |||
| @@ -77,8 +76,8 @@ pybind11::function GradKeyWrapper::get_backward_closure( | |||
| for (auto&& tensor : tensors) { | |||
| args.push_back(TensorWrapper::try_cast(tensor.ptr())->m_tensor->data()); | |||
| } | |||
| auto closure = imperative::apply(GetBackwardColsure(self->m_key), args)[0] | |||
| .as<FunctionValue>(); | |||
| auto closure_value = imperative::apply(GetBackwardColsure(self->m_key), args)[0]; | |||
| auto closure = closure_value.as_ref<FunctionValue>(); | |||
| auto py_function = [closure](std::vector<TensorWrapper*> tensors) { | |||
| std::vector<ValueRef> args; | |||
| for (auto* tw : tensors) { | |||
| @@ -90,11 +89,14 @@ pybind11::function GradKeyWrapper::get_backward_closure( | |||
| } | |||
| PyObject* GradKeyWrapper::get_name() { | |||
| return py::cast(m_key->name()).release().ptr(); | |||
| return py::cast(m_name).release().ptr(); | |||
| } | |||
| void GradKeyWrapper::set_name(py::handle name) { | |||
| m_key->name(py::cast<std::string>(name)); | |||
| m_name = py::cast<std::string>(name); | |||
| if (m_key) { | |||
| m_key->name(m_name); | |||
| } | |||
| } | |||
| PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||
| @@ -115,7 +117,10 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject* const* args, size_t nargs) { | |||
| } | |||
| void GradKeyWrapper::enter() { | |||
| m_transformation = std::make_shared<GradTransformation>(m_key); | |||
| m_transformation = std::make_shared<GradTransformation>(); | |||
| m_key = m_transformation->key(); | |||
| m_key->name(m_name); | |||
| grad_key_map[m_key] = this; | |||
| TransformationManager::get_instance().register_at<TransformationManager::Grad>( | |||
| m_transformation); | |||
| } | |||
| @@ -123,6 +128,8 @@ void GradKeyWrapper::enter() { | |||
| void GradKeyWrapper::exit() { | |||
| TransformationManager::get_instance().unregister<TransformationManager::Grad>( | |||
| m_transformation); | |||
| grad_key_map.erase(m_key); | |||
| m_key = {}; | |||
| m_transformation.reset(); | |||
| } | |||
| @@ -138,8 +145,6 @@ GradKeyWrapper* GradKeyWrapper::get(std::shared_ptr<GradKey> key) { | |||
| return grad_key_map.at(key); | |||
| } | |||
| GradKeyWrapper::~GradKeyWrapper() { | |||
| grad_key_map.erase(m_key); | |||
| } | |||
| GradKeyWrapper::~GradKeyWrapper() {} | |||
| } // namespace mgb::imperative::python | |||
| @@ -26,6 +26,7 @@ struct GradKeyWrapper : NonCopyableObj { | |||
| using wrap_t = pyext17::wrap<GradKeyWrapper>; | |||
| static constexpr auto tp_name = pybind11::detail::_("GradKey"); | |||
| std::string m_name; | |||
| std::shared_ptr<GradKey> m_key; | |||
| std::shared_ptr<GradTransformation> m_transformation; | |||
| @@ -117,7 +117,7 @@ std::optional<ValueRefList> elemwise_grad_rule( | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(2); | |||
| SmallVector<ValueRef> ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| @@ -147,7 +147,7 @@ std::optional<ValueRefList> reshape_grad_rule( | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(2); | |||
| SmallVector<ValueRef> ret(2); | |||
| if (!grad) { | |||
| return ret; | |||
| } | |||
| @@ -180,7 +180,7 @@ std::optional<ValueRefList> subtensor_grad_rule( | |||
| grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(1); | |||
| SmallVector<ValueRef> ret(1); | |||
| if (grad && inputs[0]) { | |||
| ValueRefList args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
| @@ -215,7 +215,7 @@ std::optional<ValueRefList> indexingMultiAxisVec_grad_rule( | |||
| grad_op_ = std::move(grad_op)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(1); | |||
| SmallVector<ValueRef> ret(1); | |||
| if (grad && inputs[0]) { | |||
| ValueRefList args_(inputs.size() + 1); | |||
| auto&& zeros = make_empty_tensor(grad.device(), inputs[0], grad.dtype()); | |||
| @@ -251,7 +251,7 @@ std::optional<ValueRefList> reduce_grad_rule( | |||
| maker.backward([shapes = std::move(input_shapes)](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(1); | |||
| SmallVector<ValueRef> ret(1); | |||
| if (grad && shapes[0]) { | |||
| ret[0] = broadcast_to(grad, shapes[0]); | |||
| } | |||
| @@ -274,7 +274,7 @@ std::optional<ValueRefList> addAxis_grad_rule( | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(1); | |||
| SmallVector<ValueRef> ret(1); | |||
| if (grad && flag_) { | |||
| ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
| } | |||
| @@ -297,7 +297,7 @@ std::optional<ValueRefList> removeAxis_grad_rule( | |||
| maker.backward([grad_op_ = std::move(grad_op), flag_ = flag](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(1); | |||
| SmallVector<ValueRef> ret(1); | |||
| if (grad && flag_) { | |||
| ret[0] = imperative::apply(*grad_op_, grad)[0]; | |||
| } | |||
| @@ -316,7 +316,7 @@ std::optional<ValueRefList> fastpathcopy_grad_rule( | |||
| maker.backward([](Span<ValueRef> grads) { | |||
| mgb_assert(grads.size() == 1); | |||
| ValueRef grad = grads[0]; | |||
| ValueRefList ret(1); | |||
| SmallVector<ValueRef> ret(1); | |||
| if (grad) { | |||
| ret[0] = grad; | |||
| } | |||
| @@ -56,42 +56,44 @@ WeakKeyMap<ValueWeakRef, py::object> module_trace_info_map; | |||
| struct SymbolVarContext { | |||
| TransformationContext context; | |||
| cg::ComputingGraph* graph; | |||
| std::shared_ptr<SymbolTransformation> symbol_tsf; | |||
| std::shared_ptr<ScalarTransformation> scalar_tsf; | |||
| SymbolVarContext(cg::ComputingGraph* graph) : graph(graph) { | |||
| SymbolVarContext(cg::ComputingGraph* graph) { | |||
| symbol_tsf = std::make_shared<SymbolTransformation>(graph); | |||
| scalar_tsf = std::make_shared<ScalarTransformation>(); | |||
| Transformation::swap_context(context); | |||
| } | |||
| void init() { | |||
| std::make_shared<SymbolTransformation>(graph)->register_at( | |||
| Transformation::top()); | |||
| std::make_shared<ScalarTransformation>()->register_at(Transformation::top()); | |||
| symbol_tsf->register_at(Transformation::top()); | |||
| scalar_tsf->register_at(Transformation::top()); | |||
| } | |||
| ~SymbolVarContext() { Transformation::swap_context(context); } | |||
| }; | |||
| ValueRef symvar2val(py::handle py_symbol_var) { | |||
| auto* symbol_var = py_symbol_var.cast<PySymbolVar*>(); | |||
| ValueRef value = symbol_tsf->value_type().make(symbol_var->m_node); | |||
| if (symbol_var->is_scalar) { | |||
| value = scalar_tsf->value_type().make(value); | |||
| } | |||
| return value; | |||
| } | |||
| ValueRef symvar2val(py::handle py_symbol_var) { | |||
| auto* symbol_var = py_symbol_var.cast<PySymbolVar*>(); | |||
| ValueRef value = SymbolValue::make(symbol_var->m_node); | |||
| if (symbol_var->is_scalar) { | |||
| value = ScalarValue::make(value); | |||
| py::object val2symvar(py::handle typeobj, ValueRef value) { | |||
| bool is_scalar = false; | |||
| if (auto* scalar_value = value.as(scalar_tsf->value_type())) { | |||
| value = scalar_value->value(); | |||
| is_scalar = true; | |||
| } | |||
| auto* node = value.cast(symbol_tsf->value_type()).node(); | |||
| auto py_symbol_var = | |||
| typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); | |||
| py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar; | |||
| return py_symbol_var; | |||
| } | |||
| return value; | |||
| } | |||
| py::object val2symvar(py::handle typeobj, ValueRef value) { | |||
| bool is_scalar = false; | |||
| if (auto* scalar_value = value.as<ScalarValue>()) { | |||
| value = scalar_value->value(); | |||
| is_scalar = true; | |||
| } | |||
| auto* node = value.cast<SymbolValue>().node(); | |||
| auto py_symbol_var = | |||
| typeobj(pybind11::cast(node, pybind11::return_value_policy::automatic)); | |||
| py_symbol_var.cast<PySymbolVar*>()->is_scalar = is_scalar; | |||
| return py_symbol_var; | |||
| } | |||
| ~SymbolVarContext() { Transformation::swap_context(context); } | |||
| }; | |||
| } // namespace | |||
| @@ -130,19 +132,21 @@ PyObject* py_apply( | |||
| auto op = py::handle(py_op).cast<std::shared_ptr<OpDef>>(); | |||
| SmallVector<ValueRef, 8> tensors(nargs); | |||
| if (py::isinstance<PySymbolVar>(py::handle(args[0]))) { | |||
| bool is_symbol_var = (!TensorWrapper::try_cast(args[0])) && | |||
| py::isinstance<PySymbolVar>(py::handle(args[0])); | |||
| if (is_symbol_var) { | |||
| // swap to a special context to reuse scalar handle | |||
| SymbolVarContext context( | |||
| py::handle(args[0]).cast<PySymbolVar*>()->m_node->owner_graph()); | |||
| context.init(); | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| tensors[i] = symvar2val(args[i]); | |||
| tensors[i] = context.symvar2val(args[i]); | |||
| } | |||
| auto outputs = imperative::apply(*op, tensors); | |||
| auto ret = pybind11::tuple(outputs.size()); | |||
| auto typeobj = py::handle(args[0]).get_type(); | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| ret[i] = val2symvar(typeobj, outputs[i]); | |||
| ret[i] = context.val2symvar(typeobj, outputs[i]); | |||
| } | |||
| return ret.release().ptr(); | |||
| } | |||
| @@ -161,7 +165,7 @@ PyObject* py_apply( | |||
| } | |||
| } | |||
| auto outputs = imperative::apply(*op, tensors); | |||
| auto outputs = [&] { return imperative::apply(*op, tensors); }(); | |||
| size_t nout = outputs.size(); | |||
| auto ret = py::tuple(nout); | |||
| for (size_t i = 0; i < nout; ++i) { | |||
| @@ -1573,9 +1577,9 @@ void init_tensor(py::module m) { | |||
| SymbolVarContext context(graph); | |||
| context.init(); | |||
| auto output = reduce_to_scalar( | |||
| *op.cast<std::shared_ptr<OpDef>>(), symvar2val(tensor)); | |||
| *op.cast<std::shared_ptr<OpDef>>(), context.symvar2val(tensor)); | |||
| auto typeobj = tensor.get_type(); | |||
| return val2symvar(typeobj, output); | |||
| return context.val2symvar(typeobj, output); | |||
| } else { | |||
| auto* tw = TensorWrapper::try_cast(tensor.ptr()); | |||
| auto output = reduce_to_scalar( | |||
| @@ -67,10 +67,9 @@ struct TransformationManager { | |||
| } | |||
| }; | |||
| class PyValue final | |||
| : public MixinValueImpl<PyValue, ValueKind::Object, pybind11::object> { | |||
| class PyValue final : public PrimitiveValue<PyValue, pybind11::object> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const { | |||
| return pybind11::str((const pybind11::object&)*this).cast<std::string>(); | |||
| @@ -63,7 +63,7 @@ auto CreateTensor::parse(Span<ValueRef> inputs) const -> Args { | |||
| MegBrainError, | |||
| "unknown input type, expects HostStorage or DeviceStorage, got " | |||
| "%s", | |||
| input.name()->c_str()); | |||
| input.to_string().c_str()); | |||
| } | |||
| } | |||
| mgb_assert( | |||
| @@ -12,7 +12,7 @@ std::string CompNodeValue::to_string() const { | |||
| } | |||
| std::string BoolValue::to_string() const { | |||
| return (*m_value) ? "true" : "false"; | |||
| return (*this) ? "true" : "false"; | |||
| } | |||
| std::string HostStorage::to_string() const { | |||
| @@ -26,10 +26,10 @@ std::string DeviceStorage::to_string() const { | |||
| std::string HostValue::to_string() const { | |||
| return ssprintf( | |||
| "HostValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), | |||
| m_dtype.name(), m_shape.to_string().c_str()); | |||
| dtype().name(), shape().to_string().c_str()); | |||
| } | |||
| HostTensorND HostValue::as_nd(bool allow_scalar) const { | |||
| HostTensorND HostTensor::as_nd(bool allow_scalar) const { | |||
| HostTensorND nd; | |||
| TensorShape tensor_shape; | |||
| if (m_shape.is_scalar()) { | |||
| @@ -45,10 +45,10 @@ HostTensorND HostValue::as_nd(bool allow_scalar) const { | |||
| std::string DeviceValue::to_string() const { | |||
| return ssprintf( | |||
| "DeviceValue{device=%s, dtype=%s, shape=%s}", device().to_string().c_str(), | |||
| m_dtype.name(), m_shape.to_string().c_str()); | |||
| dtype().name(), shape().to_string().c_str()); | |||
| } | |||
| DeviceTensorND DeviceValue::as_nd(bool allow_scalar) const { | |||
| DeviceTensorND DeviceTensor::as_nd(bool allow_scalar) const { | |||
| DeviceTensorND nd; | |||
| TensorShape tensor_shape; | |||
| if (m_shape.is_scalar()) { | |||
| @@ -19,46 +19,18 @@ | |||
| namespace mgb { | |||
| namespace imperative { | |||
| namespace { | |||
| MGB_NOINLINE void copy_outputs( | |||
| ForwardAllocator<ValueRef>& allocator, ValueRefList& outputs) { | |||
| size_t nr_outputs = outputs.size(); | |||
| if (mgb_likely(nr_outputs == 1)) { | |||
| ValueRef output_copy; | |||
| output_copy = outputs[0]; | |||
| allocator.clear(); | |||
| outputs = ValueRefList({output_copy}); | |||
| } else if (!outputs.empty()) { | |||
| SmallVector<ValueRef> outputs_copy(nr_outputs); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| outputs_copy[i] = outputs[i]; | |||
| } | |||
| outputs.clear(); | |||
| allocator.clear(); | |||
| outputs = {outputs_copy.begin(), outputs_copy.end()}; | |||
| } else { | |||
| allocator.clear(); | |||
| } | |||
| } | |||
| } // namespace | |||
| ValueRefList apply(const Operator& op, Span<ValueRef> inputs) { | |||
| auto& context = Transformation::get_context(); | |||
| size_t& depth = context.next_transformation; | |||
| bool top = depth == 0; | |||
| auto outputs = ([&] { | |||
| if (mgb_unlikely(depth >= context.transformations.size())) { | |||
| return op.fallback(inputs); | |||
| } else { | |||
| auto& transformation = *context.transformations[depth++]; | |||
| CleanupGuard _{[&] { --depth; }}; | |||
| return transformation.apply_transformation(op, inputs); | |||
| } | |||
| })(); | |||
| if (mgb_unlikely(top)) { | |||
| copy_outputs(context.allocator, outputs); | |||
| // TODO: add fallback transformation | |||
| bool fallback = depth >= context.transformations.size(); | |||
| if (mgb_unlikely(fallback)) { | |||
| return op.fallback(inputs); | |||
| } else { | |||
| auto& transformation = *context.transformations[depth++]; | |||
| CleanupGuard _{[&] { --depth; }}; | |||
| return transformation.apply_transformation(op, inputs); | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) { | |||
| @@ -66,12 +38,7 @@ ValueRefList apply(const OpDef& def, Span<ValueRef> inputs) { | |||
| } | |||
| ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) { | |||
| SmallVector<ValueRef> inputs_storage; | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| inputs_storage.push_back(inputs[i]); | |||
| } | |||
| auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<ValueRef> inputs, | |||
| size_t) { | |||
| auto apply_functor = [](std::shared_ptr<OpDef> op, Span<ValueRef> inputs, size_t) { | |||
| auto outputs = imperative::apply(*op, inputs); | |||
| return SmallVector<ValueRef>(outputs.begin(), outputs.end()); | |||
| }; | |||
| @@ -93,7 +60,7 @@ ValueRefList apply(const Subgraph& graph, Span<ValueRef> inputs) { | |||
| HostStorage::make(host_value.storage()), | |||
| DeviceStorage::make(device_value.storage()))[0]; | |||
| }; | |||
| auto outputs = graph.apply(inputs_storage, apply_functor, make_const); | |||
| auto outputs = graph.apply(inputs, apply_functor, make_const); | |||
| return ValueRefList{outputs.begin(), outputs.end()}; | |||
| } | |||
| @@ -331,6 +331,7 @@ void ChannelImpl::dispatch_kernel( | |||
| cmd.inputs = std::move(input_infos); | |||
| cmd.outputs.reserve(output_descs.size()); | |||
| outputs->reserve(output_descs.size()); | |||
| for (int i = 0; i < output_descs.size(); ++i) { | |||
| auto&& desc = output_descs[i]; | |||
| auto info = alloc(); | |||
| @@ -730,7 +731,8 @@ void ChannelImpl::do_apply_op(const ApplyOp& cmd, std::string reason) { | |||
| input_descs.push_back({{{}, input->dtype()}, input->comp_node()}); | |||
| } | |||
| auto forward_graph = OpDef::make_forward_graph(def, input_descs); | |||
| auto outputs = forward_graph.apply(inputs, apply_functor, const_functor); | |||
| auto outputs = forward_graph.apply<TensorPtr>( | |||
| inputs, apply_functor, const_functor); | |||
| return outputs; | |||
| } | |||
| return OpDef::apply_on_physical_tensor(def, inputs); | |||
| @@ -11,6 +11,7 @@ | |||
| #include "megbrain/imperative/opr_utility.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/utils/stats.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/utility.h" | |||
| @@ -101,7 +102,7 @@ void apply_on_device_tensornd( | |||
| const OpDef& def, const SmallVector<DeviceTensorND>& inputs, | |||
| SmallVector<DeviceTensorND>* outputs) { | |||
| auto&& op_def = def.cast_final_safe<Elemwise>(); | |||
| auto trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | |||
| auto&& trait = megdnn::Elemwise::ModeTrait::from_mode(op_def.mode); | |||
| mgb_assert( | |||
| inputs.size() == trait.arity, "%s expects %u inputs; got %zu actually", | |||
| trait.name, trait.arity, inputs.size()); | |||
| @@ -36,7 +36,7 @@ VarNodeArray apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| .node(); | |||
| }; | |||
| auto subgraph = def.trait()->make_forward_graph(def, input_descs); | |||
| auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||
| auto outputs = subgraph.apply<VarNode*>(inputs, apply_functor, const_functor); | |||
| return outputs; | |||
| } | |||
| @@ -56,7 +56,8 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||
| value->layout(), value->comp_node(), | |||
| value->get_value().proxy_to_default_cpu()}; | |||
| }; | |||
| auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||
| auto outputs = | |||
| subgraph.apply<LogicalTensorDesc>(inputs, apply_functor, const_functor); | |||
| return {outputs, all_validated}; | |||
| } | |||
| @@ -72,7 +73,7 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||
| return OpDef::apply_on_physical_tensor(*op, inputs); | |||
| }; | |||
| auto const_functor = [&](const TensorPtr& value) { return value; }; | |||
| auto outputs = subgraph.apply(inputs, apply_functor, const_functor); | |||
| auto outputs = subgraph.apply<TensorPtr>(inputs, apply_functor, const_functor); | |||
| return outputs; | |||
| } | |||
| @@ -94,7 +95,7 @@ static EncodedSubgraph make_backward_graph_from_forward( | |||
| }; | |||
| GradContext<var_t> grad_context{accum_grad}; | |||
| auto input_vars = builder.write_inputs(inputs); | |||
| auto outputs = forward_graph.apply( | |||
| auto outputs = forward_graph.apply<var_t>( | |||
| input_vars, std::bind(&decltype(builder)::write_expr, &builder, _1, _2, _3), | |||
| [&](TensorPtr constant) { | |||
| return builder.write_constant( | |||
| @@ -102,7 +103,7 @@ static EncodedSubgraph make_backward_graph_from_forward( | |||
| }); | |||
| size_t nr_outputs = outputs.size(); | |||
| auto apply_mask = [](auto&& values, SmallVector<bool> mask) { | |||
| mgb_assert(mask.size() == values.size(), ""); | |||
| mgb_assert(mask.size() == values.size()); | |||
| std::decay_t<decltype(values)> results; | |||
| for (size_t i = 0; i < mask.size(); ++i) { | |||
| if (mask[i]) { | |||
| @@ -143,7 +144,7 @@ static EncodedSubgraph make_backward_graph_from_forward( | |||
| return builder.write_constant( | |||
| constant, {constant->layout(), constant->comp_node()}); | |||
| }; | |||
| return bg.apply(grad_inputs, apply_functor, const_functor); | |||
| return bg.apply<var_t>(grad_inputs, apply_functor, const_functor); | |||
| }); | |||
| builder.add_outputs(grad_context.get_grads(input_vars)); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| @@ -10,20 +10,19 @@ | |||
| */ | |||
| #include "megbrain/imperative/transformations/eval.h" | |||
| #include "megbrain/imperative/transformations/grad.h" | |||
| #include "megbrain/imperative/utils/stats.h" | |||
| namespace mgb { | |||
| namespace imperative { | |||
| DTypeValue::ref_t InterpreterInfo::dtype() const { | |||
| DTypeValue::ref_t InterpreterValue::dtype() const { | |||
| if (!m_dtype) { | |||
| m_dtype = DTypeValue::make(handle()->channel()->get_dtype(handle()->handle())); | |||
| } | |||
| return m_dtype; | |||
| } | |||
| CompNodeValue::ref_t InterpreterInfo::comp_node() const { | |||
| CompNodeValue::ref_t InterpreterValue::comp_node() const { | |||
| if (!m_comp_node) { | |||
| m_comp_node = CompNodeValue::make( | |||
| handle()->channel()->get_device(handle()->handle())); | |||
| @@ -31,7 +30,7 @@ CompNodeValue::ref_t InterpreterInfo::comp_node() const { | |||
| return m_comp_node; | |||
| } | |||
| ShapeValue::ref_t InterpreterInfo::shape() const { | |||
| ShapeValue::ref_t InterpreterValue::shape() const { | |||
| if (!m_shape) { | |||
| m_shape = ShapeValue::make( | |||
| ValueShape::from(handle()->channel()->get_shape(handle()->handle()))); | |||
| @@ -51,21 +50,22 @@ ValueRefList InterpreterTransformation::apply_op( | |||
| } | |||
| }}; | |||
| for (auto input : inputs) { | |||
| input_handles.push_back(input.cast<InterpreterValue>().handle()->handle()); | |||
| input_handles.push_back(input.cast(m_value_type).handle()->handle()); | |||
| } | |||
| output_handles = | |||
| m_channel->apply_op(apply_op.op().shared_from_this(), input_handles); | |||
| ValueRefList outputs(output_handles.size()); | |||
| for (size_t i = 0; i < output_handles.size(); ++i) { | |||
| outputs[i] = InterpreterValue::make(share_handle(output_handles[i])); | |||
| outputs[i] = m_value_type.make(share_handle(output_handles[i])); | |||
| output_handles[i] = nullptr; | |||
| } | |||
| output_handles.clear(); | |||
| return outputs; | |||
| } | |||
| ValueRefList InterpreterTransformation::apply_get_attr( | |||
| const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
| auto& input = inputs.item().cast<InterpreterValue>(); | |||
| auto& input = inputs.item().cast(m_value_type); | |||
| ValueRef output; | |||
| switch (get_attr.attr()) { | |||
| case GetAttr::DType: | |||
| @@ -98,10 +98,10 @@ ValueRefList InterpreterTransformation::apply_create_tensor( | |||
| if (!args.device) { | |||
| // implies H2D | |||
| mgb_assert(args.host, "neither host and device value is valid"); | |||
| return {InterpreterValue::make(share_handle( | |||
| return {m_value_type.make(share_handle( | |||
| m_channel->put(*args.host, args.kind == CreateTensor::Unique)))}; | |||
| } else { | |||
| return {InterpreterValue::make(share_handle(m_channel->put( | |||
| return {m_value_type.make(share_handle(m_channel->put( | |||
| *args.device, args.host ? *args.host : HostTensorND())))}; | |||
| } | |||
| } | |||
| @@ -119,7 +119,7 @@ ValueRefList InterpreterTransformation::apply_transformation( | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| return apply_create_tensor(*create_tensor, inputs); | |||
| } else if (auto* dtr_command = op.as<DTRCommand>()) { | |||
| auto handle = inputs[0].cast<InterpreterValue>().handle()->handle(); | |||
| auto handle = inputs[0].cast(m_value_type).handle()->handle(); | |||
| switch (dtr_command->kind()) { | |||
| case DTRCommand::Drop: | |||
| m_channel->drop(handle); | |||
| @@ -129,10 +129,10 @@ ValueRefList InterpreterTransformation::apply_transformation( | |||
| } | |||
| return {}; | |||
| } else if (auto* rename_value = op.as<RenameValue>()) { | |||
| auto& input = inputs[0].cast<InterpreterValue>(); | |||
| return {InterpreterValue::make(input.handle(), rename_value->name())}; | |||
| auto& input = inputs[0].cast(m_value_type); | |||
| return {m_value_type.make(input.handle(), rename_value->name())}; | |||
| } else if (op.is<GetName>()) { | |||
| auto name = inputs[0].cast<InterpreterValue>().name(); | |||
| auto name = inputs[0].cast(m_value_type).name(); | |||
| if (!name.empty()) { | |||
| return {StringValue::make(name)}; | |||
| } else { | |||
| @@ -68,7 +68,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
| size_t count = std::count_if( | |||
| save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); | |||
| if (!backward_graph->precomp.empty()) { | |||
| ValueRefList inputs_and_outputs(inputs.size() + outputs.size()); | |||
| SmallVector<ValueRef> inputs_and_outputs(inputs.size() + outputs.size()); | |||
| auto it = inputs_and_outputs.begin(); | |||
| for (auto&& input : inputs) { | |||
| *it++ = input; | |||
| @@ -94,7 +94,7 @@ BackwardGraphWithClosure::BackwardGraphWithClosure( | |||
| } | |||
| } | |||
| void BackwardGraphWithClosure::operator()( | |||
| ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| ValueRef args[closure.size() + grads.size()]; | |||
| size_t nargs = 0; | |||
| for (auto&& value : closure) { | |||
| @@ -114,7 +114,9 @@ void BackwardGraphWithClosure::operator()( | |||
| if (null_grad) { | |||
| return; | |||
| } | |||
| auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs)); | |||
| auto igrads_ = imperative::apply(backward_graph->backward, Span(args, nargs)); | |||
| SmallVector<ValueRef> igrads = {igrads_.begin(), igrads_.end()}; | |||
| igrads_.clear(); | |||
| auto&& iter = igrads.begin(); | |||
| for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) { | |||
| if (p) { | |||
| @@ -125,7 +127,7 @@ void BackwardGraphWithClosure::operator()( | |||
| } | |||
| void CustomBackward::operator()( | |||
| ValueRefList grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) { | |||
| size_t nargs = grads.size(); | |||
| ValueRef args[nargs]; | |||
| for (size_t i = 0; i < nargs; ++i) { | |||
| @@ -206,7 +208,7 @@ void GradKey::backward() { | |||
| mgb_throw(AssertionError, "invalid backward"); | |||
| } else { | |||
| mgb_assert(grad_fn->m_slots.size() > 0); | |||
| ValueRefList grads (grad_fn->m_slots.size()); | |||
| SmallVector<ValueRef> grads (grad_fn->m_slots.size()); | |||
| auto iter = grads.begin(); | |||
| for (auto&& slot : grad_fn->m_slots) { | |||
| *iter++ = slot.m_grad; | |||
| @@ -231,11 +233,9 @@ void GradKey::backward() { | |||
| GradValue::ref_t GradKey::attach( | |||
| ValueRef tensor, std::function<void(ValueRef)> callback) { | |||
| auto grad_value = tensor.as_ref<GradValue>(); | |||
| if (grad_value && grad_value->has_key(shared_from_this())) { | |||
| mgb_assert( | |||
| !tensor.cast<GradValue>().slot_for(shared_from_this())->callback, | |||
| "callback exists"); | |||
| auto grad_value = tensor.as_ref(m_value_type); | |||
| if (grad_value) { | |||
| mgb_assert(!tensor.cast(m_value_type).slot()->callback, "callback exists"); | |||
| } else { | |||
| GradSlotPtr grad_slot; | |||
| auto& grad_fn = grad_slot.m_fn; | |||
| @@ -243,9 +243,9 @@ GradValue::ref_t GradKey::attach( | |||
| grad_fn->m_key = shared_from_this(); | |||
| grad_fn->m_slots.resize(1); | |||
| grad_slot.m_index = 0; | |||
| grad_value = GradValue::make(tensor, shared_from_this(), grad_slot); | |||
| grad_value = m_value_type.make(tensor, shared_from_this(), grad_slot); | |||
| } | |||
| grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback; | |||
| grad_value->slot().m_fn->m_slots[0].callback = callback; | |||
| return grad_value; | |||
| } | |||
| @@ -263,7 +263,7 @@ void GradKey::freeze() { | |||
| ValueRefList GradTransformation::apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) { | |||
| auto fallback = [&] { | |||
| ValueRefList unwrapped_inputs(inputs.size()); | |||
| SmallVector<ValueRef> unwrapped_inputs(inputs.size()); | |||
| { | |||
| // overhead | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| @@ -367,7 +367,7 @@ ValueRefList GradTransformation::apply_transformation( | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (backward.input_has_grad(i) && require_grads[i]) { | |||
| auto& input_grad_slot = | |||
| inputs[i].cast<GradValue>().slot_for(m_key); | |||
| inputs[i].cast(m_value_type).slot(); | |||
| grad_fn->m_dests.emplace_back(input_grad_slot); | |||
| grad_fn->m_dests.back().m_producer_record.insert_after( | |||
| input_grad_slot->m_producer_head); | |||
| @@ -378,7 +378,7 @@ ValueRefList GradTransformation::apply_transformation( | |||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||
| if (backward.output_requires_grad(i)) { | |||
| // little overhead: Value::make | |||
| auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i}); | |||
| auto grad_value = m_value_type.make(outputs[i], m_key, GradSlotPtr{grad_fn, i}); | |||
| outputs[i] = record_grad(grad_value); | |||
| } | |||
| } | |||
| @@ -435,7 +435,10 @@ ValueRefList GradTransformation::apply_transformation( | |||
| backward.m_input_has_grad = SmallVector(nr_inputs, true); | |||
| backward.m_output_attrs = | |||
| SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true}); | |||
| backward.m_backward = set_grad->grad_fn(); | |||
| backward.m_backward = [fn = set_grad->grad_fn()](Span<ValueRef> inputs) { | |||
| auto result = fn(inputs); | |||
| return SmallVector<ValueRef>(result.begin(), result.end()); | |||
| }; | |||
| ValueRefList outputs(nr_outputs); | |||
| grad_fn->m_key = m_key; | |||
| grad_fn->m_slots.resize(nr_outputs); | |||
| @@ -454,10 +457,10 @@ ValueRefList GradTransformation::apply_transformation( | |||
| auto& output = outputs_[i]; | |||
| auto grad_value = as_grad_value(output); | |||
| if (grad_value) { | |||
| grad_value = GradValue::make( | |||
| grad_value = m_value_type.make( | |||
| grad_value->m_value, m_key, GradSlotPtr(grad_fn, i)); | |||
| } else { | |||
| grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i)); | |||
| grad_value = m_value_type.make(output, m_key, GradSlotPtr(grad_fn, i)); | |||
| } | |||
| outputs[i] = record_grad(grad_value); | |||
| } | |||
| @@ -485,8 +488,7 @@ ValueRefList GradTransformation::apply_transformation( | |||
| mgb_assert(inputs.size() == 1); | |||
| if (auto&& grad_value = as_grad_value(inputs[0])) { | |||
| auto output = imperative::apply(op, grad_value->m_value)[0]; | |||
| auto grad_output = GradValue::make( | |||
| output, grad_value->key(), grad_value->slot_for(m_key)); | |||
| auto grad_output = m_value_type.make(output, m_key, grad_value->slot()); | |||
| return {record_grad(grad_output)}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| @@ -502,7 +504,7 @@ GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) { | |||
| std::vector<GradSlotPtr> y_slots; | |||
| for (auto&& y : ys) { | |||
| if (auto&& grad_value = as_grad_value(y)) { | |||
| y_slots.push_back(grad_value->slot_for(grad_key)); | |||
| y_slots.push_back(grad_value->slot()); | |||
| } else { | |||
| y_slots.emplace_back(); | |||
| } | |||
| @@ -32,7 +32,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||
| bool require_link = mm_io_ops.count(op_val->op().dyn_typeinfo()); | |||
| VarNodeArray input_nodes; | |||
| for (auto&& input : inputs) { | |||
| if (auto* input_node = input.as<LazyEvalValue>()) { | |||
| if (auto* input_node = input.as(m_value_type)) { | |||
| input_nodes.push_back(input_node->node()); | |||
| } else { | |||
| // ImmutableTensor has empty shape issues | |||
| @@ -112,7 +112,7 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||
| return {record_var(node)}; | |||
| } | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) { | |||
| if (auto* lazy_val = inputs.item().as(m_value_type)) { | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::DType: | |||
| return {DTypeValue::make(lazy_val->node()->dtype())}; | |||
| @@ -167,14 +167,14 @@ ValueRefList LazyEvalTransformation::apply_transformation( | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (auto* rename_value = op.as<RenameValue>()) { | |||
| if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) { | |||
| if (auto* lazy_val = inputs.item().as(m_value_type)) { | |||
| return {record_var( | |||
| lazy_val->node(), lazy_val->bound_data(), rename_value->name())}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| } else if (op.is<GetName>()) { | |||
| if (auto* lazy_val = inputs.item().as<LazyEvalValue>()) { | |||
| if (auto* lazy_val = inputs.item().as(m_value_type)) { | |||
| auto name = lazy_val->name(); | |||
| if (!name.empty()) { | |||
| return {StringValue::make(lazy_val->name())}; | |||
| @@ -255,7 +255,7 @@ void LazyEvalTransformation::on_unregister() noexcept { | |||
| DeviceStorage::make(data.storage()))[0]); | |||
| } | |||
| for (auto&& lazy_val : lazy_vals) { | |||
| if (lazy_val.is<LazyEvalValue>()) { | |||
| if (lazy_val.is(m_value_type)) { | |||
| std::string repr = | |||
| ssprintf("lazy eval failed for %s", lazy_val->to_string().c_str()); | |||
| mgb_log_debug("%s", repr.c_str()); | |||
| @@ -20,7 +20,8 @@ namespace imperative { | |||
| namespace { | |||
| using ScalarRule = ValueRefList (*)(const OpDef&, Span<ValueRef>, Span<bool>); | |||
| using ScalarRule = ValueRefList (*)( | |||
| const OpDef&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&); | |||
| static std::unordered_map<Typeinfo*, ScalarRule> scalar_rules; | |||
| ValueRef make_scalar_shape(CompNode device) { | |||
| @@ -41,17 +42,22 @@ bool is_scalar_shape(ValueRef shape) { | |||
| return *shape_of_shape == ValueShape{0}; | |||
| } | |||
| template <typename T, ValueRefList (*rule)(const T&, Span<ValueRef>, Span<bool>)> | |||
| template < | |||
| typename T, | |||
| ValueRefList (*rule)( | |||
| const T&, Span<ValueRef>, Span<bool>, const Type<ScalarValue>&)> | |||
| void register_scalar_rule() { | |||
| scalar_rules[T::typeinfo()] = [](const OpDef& def, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask); | |||
| Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& value_type) { | |||
| return (*rule)(def.cast_final_safe<T>(), inputs, inputs_mask, value_type); | |||
| }; | |||
| } | |||
| template <typename TOpDef, size_t nr_inputs> | |||
| ValueRefList elemwise_rule( | |||
| const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const TOpDef& op_def, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| if constexpr (nr_inputs != 0) { | |||
| mgb_assert(inputs.size() == inputs.size(), "inputs size mismatch"); | |||
| } | |||
| @@ -63,27 +69,29 @@ ValueRefList elemwise_rule( | |||
| } | |||
| auto outputs = imperative::apply(op_def, inputs); | |||
| if (all_scalar) { | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| outputs[0] = scalar_type.make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList remove_axis_rule( | |||
| const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const RemoveAxis& remove_axis, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(!inputs_mask.item()); | |||
| bool is_scalar = inputs.item().shape()->ndim == remove_axis.axis.size(); | |||
| if (is_scalar && remove_axis.axis.size() == 1) { | |||
| return {ScalarValue::make(inputs.item())}; | |||
| return {scalar_type.make(inputs.item())}; | |||
| } | |||
| auto outputs = imperative::apply(remove_axis, inputs); | |||
| if (is_scalar) { | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| outputs[0] = scalar_type.make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList reduce_rule( | |||
| const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const Reduce& reduce, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| if (inputs.size() == 1) { | |||
| return imperative::apply(reduce, inputs); | |||
| } | |||
| @@ -91,7 +99,7 @@ ValueRefList reduce_rule( | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| if (is_scalar) { | |||
| CompNode device = *inputs[0].device(); | |||
| return {ScalarValue::make( | |||
| return {scalar_type.make( | |||
| imperative::apply(reduce, inputs[0], make_scalar_shape(device))[0])}; | |||
| } | |||
| return imperative::apply(reduce, inputs); | |||
| @@ -99,7 +107,7 @@ ValueRefList reduce_rule( | |||
| ValueRefList collective_comm_rule( | |||
| const CollectiveComm& collective_comm, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() == 1); | |||
| static std::unordered_set<CollectiveComm::Mode> modes = { | |||
| CollectiveComm::Mode::ALL_REDUCE_MAX, CollectiveComm::Mode::ALL_REDUCE_MIN, | |||
| @@ -110,7 +118,7 @@ ValueRefList collective_comm_rule( | |||
| return imperative::apply(collective_comm, inputs); | |||
| } | |||
| if (inputs_mask.item()) { | |||
| return {ScalarValue::make(imperative::apply(collective_comm, inputs[0])[0])}; | |||
| return {scalar_type.make(imperative::apply(collective_comm, inputs[0])[0])}; | |||
| } else { | |||
| return imperative::apply(collective_comm, inputs); | |||
| } | |||
| @@ -118,24 +126,27 @@ ValueRefList collective_comm_rule( | |||
| ValueRefList param_pack_split_rule( | |||
| const ParamPackSplit& param_pack_split, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) { | |||
| auto outputs = imperative::apply(param_pack_split, inputs); | |||
| size_t nr_outputs = outputs.size(); | |||
| mgb_assert(nr_outputs == param_pack_split.shapes.size()); | |||
| for (size_t i = 0; i < nr_outputs; ++i) { | |||
| if (param_pack_split.shapes[i].empty()) { | |||
| outputs[i] = ScalarValue::make(outputs[i]); | |||
| outputs[i] = scalar_type.make(outputs[i]); | |||
| } | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList dot_rule(const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| return {ScalarValue::make(imperative::apply(dot, inputs)[0])}; | |||
| ValueRefList dot_rule( | |||
| const Dot& dot, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| return {scalar_type.make(imperative::apply(dot, inputs)[0])}; | |||
| } | |||
| ValueRefList add_axis_rule( | |||
| const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const AddAxis& add_axis, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() == 1); | |||
| if (inputs_mask.item()) { | |||
| mgb_assert(add_axis.axis[0] == 0); | |||
| @@ -151,7 +162,8 @@ ValueRefList add_axis_rule( | |||
| } | |||
| ValueRefList remote_recv_rule( | |||
| const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const RemoteRecv& remote_recv, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| if (remote_recv.shape.empty()) { | |||
| std::vector<int32_t> shape = {1}; | |||
| auto remote_recv_no_scalar = RemoteRecv::make( | |||
| @@ -167,20 +179,21 @@ ValueRefList remote_recv_rule( | |||
| ValueRefList check_no_finite_rule( | |||
| const CheckNonFinite& check_no_finite, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| Span<bool> inputs_mask, const Type<ScalarValue>& scalar_type) { | |||
| auto outputs = imperative::apply(check_no_finite, inputs); | |||
| mgb_assert(outputs.size() == inputs.size() + 1, "output size mismatch"); | |||
| outputs.back() = ScalarValue::make(outputs.back()); | |||
| outputs.back() = scalar_type.make(outputs.back()); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (inputs_mask[i]) { | |||
| outputs[i] = ScalarValue::make(outputs[i]); | |||
| outputs[i] = scalar_type.make(outputs[i]); | |||
| } | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList subtensor_rule( | |||
| const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const Subtensor& subtensor, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() >= 1); | |||
| auto input = inputs[0]; | |||
| bool is_scalar; | |||
| @@ -199,14 +212,14 @@ ValueRefList subtensor_rule( | |||
| } | |||
| auto outputs = imperative::apply(subtensor, inputs); | |||
| if (is_scalar) { | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| outputs[0] = scalar_type.make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } | |||
| ValueRefList get_var_shape_rule( | |||
| const GetVarShape& get_var_shape, Span<ValueRef> inputs, | |||
| Span<bool> inputs_mask) { | |||
| const GetVarShape& get_var_shape, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| bool all_scalar = true; | |||
| mgb_assert(inputs.size() >= 1); | |||
| for (auto&& input_mask : inputs_mask) { | |||
| @@ -228,11 +241,12 @@ ValueRefList get_var_shape_rule( | |||
| } | |||
| ValueRefList reshape_rule( | |||
| const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const Reshape& reshape, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(imperative::apply( | |||
| return {scalar_type.make(imperative::apply( | |||
| reshape, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
| } else { | |||
| return imperative::apply(reshape, inputs); | |||
| @@ -240,11 +254,12 @@ ValueRefList reshape_rule( | |||
| } | |||
| ValueRefList broadcast_rule( | |||
| const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask) { | |||
| const Broadcast& broadcast, Span<ValueRef> inputs, Span<bool> inputs_mask, | |||
| const Type<ScalarValue>& scalar_type) { | |||
| mgb_assert(inputs.size() == 2); | |||
| bool is_scalar = is_scalar_shape(inputs[1]); | |||
| if (is_scalar) { | |||
| return {ScalarValue::make(imperative::apply( | |||
| return {scalar_type.make(imperative::apply( | |||
| broadcast, inputs[0], make_scalar_shape(*inputs[0].device()))[0])}; | |||
| } else { | |||
| return imperative::apply(broadcast, inputs); | |||
| @@ -299,11 +314,11 @@ struct ScalarRuleRegistry { | |||
| ValueRefList ScalarTransformation::apply_get_attr( | |||
| const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
| auto&& input = inputs.item(); | |||
| bool is_scalar = input.is<ScalarValue>(); | |||
| bool is_scalar = input.is(m_value_type); | |||
| if (!is_scalar) { | |||
| return imperative::apply(get_attr, input); | |||
| } | |||
| auto unwrapped_input = input.cast<ScalarValue>().value(); | |||
| auto unwrapped_input = input.cast(m_value_type).value(); | |||
| if (get_attr.attr() == GetAttr::Shape) { | |||
| if (!m_empty_shape) { | |||
| m_empty_shape = ShapeValue::make(); | |||
| @@ -352,7 +367,7 @@ ValueRefList ScalarTransformation::apply_transformation( | |||
| ValueRefList unwrapped_inputs(nr_inputs); | |||
| SmallVector<bool> inputs_mask(nr_inputs); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| if (auto&& scalar_value = inputs[i].as_ref<ScalarValue>()) { | |||
| if (auto&& scalar_value = inputs[i].as_ref(m_value_type)) { | |||
| unwrapped_inputs[i] = scalar_value->value(); | |||
| inputs_mask[i] = true; | |||
| } else { | |||
| @@ -364,7 +379,8 @@ ValueRefList ScalarTransformation::apply_transformation( | |||
| if (auto apply_op = op.as<ApplyOp>()) { | |||
| auto iter = scalar_rules.find(apply_op->op().dyn_typeinfo()); | |||
| if (iter != scalar_rules.end()) { | |||
| return iter->second(apply_op->op(), unwrapped_inputs, inputs_mask); | |||
| return iter->second( | |||
| apply_op->op(), unwrapped_inputs, inputs_mask, m_value_type); | |||
| } else { | |||
| // TODO: repeat op | |||
| return fallback(); | |||
| @@ -375,7 +391,7 @@ ValueRefList ScalarTransformation::apply_transformation( | |||
| CreateTensor scalar_op( | |||
| create_tensor->kind(), create_tensor->device(), | |||
| create_tensor->dtype(), scalar_shape); | |||
| return {ScalarValue::make(imperative::apply(scalar_op, inputs)[0])}; | |||
| return {m_value_type.make(imperative::apply(scalar_op, inputs)[0])}; | |||
| } else { | |||
| return imperative::apply(op, inputs); | |||
| } | |||
| @@ -387,7 +403,7 @@ ValueRefList ScalarTransformation::apply_transformation( | |||
| bool is_scalar = inputs_mask[0]; | |||
| auto outputs = fallback(); | |||
| if (is_scalar) { | |||
| outputs[0] = ScalarValue::make(outputs[0]); | |||
| outputs[0] = m_value_type.make(outputs[0]); | |||
| } | |||
| return outputs; | |||
| } else { | |||
| @@ -160,7 +160,7 @@ ValueRefList TracingTransformation::apply_transformation( | |||
| SmallVector<TracingValue::ref_t> wrapped_inputs; | |||
| SmallVector<size_t> input_ids; | |||
| for (auto input : inputs) { | |||
| auto tracing_value = input.as_ref<TracingValue>(); | |||
| auto tracing_value = input.as_ref(m_value_type); | |||
| if (!tracing_value) { | |||
| tracing_value = | |||
| record_var(input, m_capture_as_const, VarKind::External); | |||
| @@ -208,7 +208,7 @@ ValueRefList TracingTransformation::apply_transformation( | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| auto unwrapped_input = unwrap_var(inputs[0]); | |||
| auto outputs = imperative::apply(op, unwrapped_input); | |||
| if (auto* tracing_value = inputs[0].as<TracingValue>()) { | |||
| if (auto* tracing_value = inputs[0].as(m_value_type)) { | |||
| auto& var_info = m_vars[tracing_value->id()]; | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::Shape: | |||
| @@ -228,7 +228,7 @@ ValueRefList TracingTransformation::apply_transformation( | |||
| } else if (auto* trace_mark_var = op.as<TraceMarkVar>()) { | |||
| mgb_assert(inputs.size() == 1, "TraceMarkVar expects exactly one input"); | |||
| auto input = inputs[0]; | |||
| auto tracing_var = input.as_ref<TracingValue>(); | |||
| auto tracing_var = input.as_ref(m_value_type); | |||
| if (!tracing_var) { | |||
| bool is_input = trace_mark_var->mark().substr(0, 4) == "arg_" || | |||
| trace_mark_var->mark().substr(0, 6) == "kwarg_"; | |||
| @@ -247,7 +247,7 @@ ValueRefList TracingTransformation::apply_transformation( | |||
| } else if (auto* trace_name_var = op.as<RenameValue>()) { | |||
| mgb_assert(inputs.size() == 1, "RenameValue expects exactly one input"); | |||
| auto input = inputs[0]; | |||
| auto tracing_var = input.as_ref<TracingValue>(); | |||
| auto tracing_var = input.as_ref(m_value_type); | |||
| if (!tracing_var) { | |||
| tracing_var = record_var(input, m_capture_as_const, VarKind::External); | |||
| } else { | |||
| @@ -260,7 +260,7 @@ ValueRefList TracingTransformation::apply_transformation( | |||
| } else if (op.is<GetName>()) { | |||
| mgb_assert(inputs.size() == 1, "GetName expects exactly one input"); | |||
| auto input = inputs[0]; | |||
| if (auto tracing_var = input.as_ref<TracingValue>()) { | |||
| if (auto tracing_var = input.as_ref(m_value_type)) { | |||
| auto name = m_vars[tracing_var->id()].name; | |||
| if (!name.empty()) { | |||
| return {StringValue::make(name)}; | |||
| @@ -425,26 +425,12 @@ void CompiledTransformation::compile() { | |||
| } | |||
| auto& node = var_accessors[input].node; | |||
| if (input_vars.empty() && require_link && mm_io_link.node()) { | |||
| /*mgb_assert( | |||
| !input_vars.empty(), | |||
| "io-mm operator should have at least one input");*/ | |||
| auto comp_node = mm_io_link.node()->comp_node(); | |||
| // auto comp_node = input_vars[0]->comp_node(); | |||
| node = opr::VirtualDep::make({SymbolVar(node), mm_io_link}, comp_node) | |||
| .node(); | |||
| } | |||
| input_vars.push_back(node); | |||
| } | |||
| /*if (require_link && mm_io_link.node()) { | |||
| mgb_assert( | |||
| !input_vars.empty(), | |||
| "io-mm operator should have at least one input"); | |||
| auto comp_node = mm_io_link.node()->comp_node(); | |||
| // auto comp_node = input_vars[0]->comp_node(); | |||
| input_vars[0] = opr::VirtualDep::make( | |||
| {SymbolVar(input_vars[0]), mm_io_link}, comp_node) | |||
| .node(); | |||
| }*/ | |||
| VarNodeArray output_vars; | |||
| if (item.op) { | |||
| output_vars = OpDef::apply_on_var_node(*item.op, input_vars); | |||
| @@ -520,7 +506,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { | |||
| switch (var.kind) { | |||
| case VarKind::External: { | |||
| trace_assert( | |||
| !value.is<TracedValue>(), "expect external node, got internal"); | |||
| !value.is(m_value_type), "expect external node, got internal"); | |||
| if (var.bound_data) { | |||
| assert_tensor_equal(var.bound_data, value); | |||
| } else { | |||
| @@ -545,8 +531,8 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { | |||
| } | |||
| case VarKind::Internal: { | |||
| trace_assert( | |||
| value.is<TracedValue>(), "expect internal node, got external"); | |||
| auto& traced_value = value.cast<TracedValue>(); | |||
| value.is(m_value_type), "expect internal node, got external"); | |||
| auto& traced_value = value.cast(m_value_type); | |||
| trace_assert(traced_value.id() == id, "input id mismatch"); | |||
| break; | |||
| } | |||
| @@ -559,7 +545,7 @@ void CompiledTransformation::trace_input(size_t id, ValueRef value) { | |||
| } | |||
| auto CompiledTransformation::trace_output(size_t id) -> TracedValue::ref_t { | |||
| auto traced_value = TracedValue::make(id, &m_vars[id], &m_var_accessors[id]); | |||
| auto traced_value = m_value_type.make(id, &m_vars[id], &m_var_accessors[id]); | |||
| m_weak_values.push_back(traced_value); | |||
| return traced_value; | |||
| } | |||
| @@ -569,7 +555,7 @@ TraceResult::SeqItem& CompiledTransformation::next_instruction() { | |||
| return m_seq[m_pc++]; | |||
| } | |||
| ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { | |||
| ShapeValue::ref_t CompiledTransformation::TracedValue::shape() const { | |||
| if (!m_shape) { | |||
| trace_assert(m_accessor->shape_getter, "shape unreadable"); | |||
| m_shape = ShapeValue::make(ValueShape::from(m_accessor->shape_getter())); | |||
| @@ -577,14 +563,14 @@ ShapeValue::ref_t CompiledTransformation::TracedInfo::shape() const { | |||
| return m_shape; | |||
| } | |||
| DTypeValue::ref_t CompiledTransformation::TracedInfo::dtype() const { | |||
| DTypeValue::ref_t CompiledTransformation::TracedValue::dtype() const { | |||
| return m_var->dtype; | |||
| } | |||
| CompNodeValue::ref_t CompiledTransformation::TracedInfo::comp_node() const { | |||
| CompNodeValue::ref_t CompiledTransformation::TracedValue::comp_node() const { | |||
| return m_var->device; | |||
| } | |||
| auto CompiledTransformation::TracedInfo::accessor() const -> const VarAccessor& { | |||
| auto CompiledTransformation::TracedValue::accessor() const -> const VarAccessor& { | |||
| return *m_accessor; | |||
| } | |||
| @@ -605,7 +591,7 @@ ValueRefList CompiledTransformation::apply_op( | |||
| ValueRefList CompiledTransformation::apply_get_attr( | |||
| const GetAttr& get_attr, Span<ValueRef> inputs) { | |||
| if (auto* traced_value = inputs[0].as<TracedValue>()) { | |||
| if (auto* traced_value = inputs[0].as(m_value_type)) { | |||
| ValueRef output; | |||
| auto& var_accessor = traced_value->accessor(); | |||
| switch (get_attr.attr()) { | |||
| @@ -718,15 +704,11 @@ void CompiledTransformation::on_unregister() noexcept { | |||
| void CompiledTransformation::execute() { | |||
| mgb_assert(m_executable != nullptr); | |||
| m_graph_executor = std::thread([&] { | |||
| try { | |||
| m_executable->execute(); | |||
| m_executable->wait(); | |||
| } catch (...) { | |||
| auto exc = std::current_exception(); | |||
| set_exception(exc); | |||
| } | |||
| }); | |||
| { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| m_graph_status = 1; | |||
| } | |||
| m_cv.notify_all(); | |||
| } | |||
| void CompiledTransformation::wait() { | |||
| @@ -735,8 +717,9 @@ void CompiledTransformation::wait() { | |||
| } catch (...) { | |||
| } | |||
| mgb_assert(m_executable != nullptr); | |||
| m_graph_executor.join(); | |||
| m_graph_executor = {}; | |||
| std::unique_lock lock{m_mutex}; | |||
| m_cv.wait(lock, [&] { return m_graph_status == 0; }); | |||
| lock.unlock(); | |||
| for (auto&& box : m_boxes) { | |||
| box->reset(); | |||
| } | |||
| @@ -25,16 +25,16 @@ ValueRef::storage_t& ValueRef::storage() const { | |||
| return m_storage; | |||
| } | |||
| const Value* ValueRef::as(size_t typecode) const { | |||
| const Value* ValueRef::as(const IType& type) const { | |||
| auto&& storage = this->storage(); | |||
| if (storage->m_typecode != typecode) { | |||
| if (storage->type() != type) { | |||
| return nullptr; | |||
| } | |||
| return static_cast<Value*>(storage.get()); | |||
| } | |||
| bool ValueRef::is(size_t typecode) const { | |||
| return this->storage()->m_typecode == typecode; | |||
| bool ValueRef::is(const IType& type) const { | |||
| return this->storage()->type() == type; | |||
| } | |||
| TypedValueRef<DeviceValue> ValueRef::dev_tensor() const { | |||
| @@ -106,9 +106,7 @@ std::string ValueRef::raw_type() const { | |||
| if (!m_storage) { | |||
| return "null"; | |||
| } | |||
| auto& types = Value::registered_types(); | |||
| mgb_assert(types.size() > m_storage->m_typecode); | |||
| return types[m_storage->m_typecode].name(); | |||
| return m_storage->type().name(); | |||
| } | |||
| bool ValueRef::watching() const { | |||
| @@ -137,7 +135,7 @@ ValueRef ValueWeakRef::lock() { | |||
| return {strong_storage}; | |||
| } | |||
| Value::Value(size_t typecode) : m_typecode{typecode} { | |||
| Value::Value() { | |||
| m_id = nr_values++; | |||
| } | |||
| @@ -147,17 +145,6 @@ Value::~Value() { | |||
| } | |||
| } | |||
| size_t Value::register_type(std::type_index type) { | |||
| auto& types = const_cast<std::vector<std::type_index>&>(registered_types()); | |||
| types.push_back(type); | |||
| return types.size() - 1; | |||
| } | |||
| const std::vector<std::type_index>& Value::registered_types() { | |||
| static std::vector<std::type_index> sm_registered_types; | |||
| return sm_registered_types; | |||
| } | |||
| void Value::register_value(ValueRef value) { | |||
| registered_values[value.id()] = ValueWeakRef(value); | |||
| } | |||
| @@ -188,7 +175,7 @@ std::vector<ValueRef> Value::end_record_values() { | |||
| } | |||
| void Value::try_rethrow() { | |||
| if (m_typecode == ErrorValue::TYPE_CODE) { | |||
| if (type() == PrimitiveType<ErrorValue>::instance) { | |||
| auto message = static_cast<ErrorValue*>(this)->message(); | |||
| mgb_throw(MegBrainError, "invalid value: %s", message.c_str()); | |||
| } | |||
| @@ -198,13 +185,9 @@ inline void ValueRefList::init(size_t nr_elems) { | |||
| m_size = nr_elems; | |||
| if (m_size > 0) { | |||
| if (m_size == 1) { | |||
| m_data = inline_storage(); | |||
| m_data = new (inline_storage()) ValueRef(); | |||
| } else { | |||
| auto& context = Transformation::get_context(); | |||
| m_data = context.allocator.allocate(m_size); | |||
| } | |||
| for (size_t i = 0; i < m_size; ++i) { | |||
| new (m_data + i) ValueRef(); | |||
| m_data = new ValueRef[m_size]; | |||
| } | |||
| } else { | |||
| m_data = nullptr; | |||
| @@ -215,9 +198,6 @@ ValueRefList::ValueRefList(size_t nr_elems) { | |||
| init(nr_elems); | |||
| } | |||
| /*ValueRefList::ValueRefList(std::initializer_list<ValueRef> values) | |||
| : ValueRefList(values.begin(), values.end()) {}*/ | |||
| ValueRefList::ValueRefList(const ValueRefList& rhs) | |||
| : ValueRefList(rhs.cbegin(), rhs.cend()) {} | |||
| @@ -271,14 +251,12 @@ ValueRefList::~ValueRefList() { | |||
| } | |||
| void ValueRefList::clear() { | |||
| for (size_t i = 0; i < m_size; ++i) { | |||
| m_data[i].~ValueRef(); | |||
| } | |||
| if (m_data) { | |||
| if (m_size != 1) { | |||
| Transformation::get_context().allocator.deallocate(m_data, m_size); | |||
| delete[] m_data; | |||
| } else { | |||
| mgb_assert(m_data == inline_storage()); | |||
| m_data->~ValueRef(); | |||
| } | |||
| } | |||
| m_data = nullptr; | |||
| @@ -25,79 +25,68 @@ class GradKey; | |||
| using GenericFunction = std::function<ValueRefList(Span<ValueRef>)>; | |||
| class ShapeValue final | |||
| : public MixinValueImpl<ShapeValue, ValueKind::Primitive, ValueShape> { | |||
| class ShapeValue final : public PrimitiveValue<ShapeValue, ValueShape> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class CompNodeValue final | |||
| : public MixinValueImpl<CompNodeValue, ValueKind::Primitive, CompNode> { | |||
| class CompNodeValue final : public PrimitiveValue<CompNodeValue, CompNode> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| // TODO: override factory method | |||
| class BoolValue final : public ValueImpl<BoolValue, ValueKind::Primitive> { | |||
| class Boolean { | |||
| private: | |||
| std::optional<bool> m_value; | |||
| bool m_value; | |||
| public: | |||
| BoolValue(bool value) : m_value{value} {} | |||
| operator bool() const { return *m_value; } | |||
| Boolean() = default; | |||
| Boolean(bool value) : m_value(value) {} | |||
| std::string to_string() const override; | |||
| operator bool() const { return m_value; } | |||
| }; | |||
| void clear() override { m_value.reset(); } | |||
| // TODO: override factory method | |||
| class BoolValue final : public PrimitiveValue<BoolValue, Boolean> { | |||
| public: | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class HostStorage final | |||
| : public MixinValueImpl<HostStorage, ValueKind::Primitive, HostTensorStorage> { | |||
| class HostStorage final : public PrimitiveValue<HostStorage, HostTensorStorage> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class DeviceStorage final | |||
| : public MixinValueImpl< | |||
| DeviceStorage, ValueKind::Primitive, DeviceTensorStorage> { | |||
| class DeviceStorage final : public PrimitiveValue<DeviceStorage, DeviceTensorStorage> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| /** | |||
| * \brief like HostTensorND mixin, but allow scalar value | |||
| * | |||
| */ | |||
| class HostValue final : public ValueImpl<HostValue, ValueKind::Primitive> { | |||
| class HostTensor { | |||
| private: | |||
| DType m_dtype; | |||
| ValueShape m_shape; | |||
| HostTensorStorage m_storage; | |||
| public: | |||
| HostValue(DType dtype, ValueShape shape, HostTensorStorage storage) | |||
| HostTensor() = default; | |||
| HostTensor(DType dtype, ValueShape shape, HostTensorStorage storage) | |||
| : m_dtype(dtype), m_shape(shape), m_storage(storage) {} | |||
| HostValue(HostTensorND value) | |||
| : HostValue( | |||
| HostTensor(HostTensorND value) | |||
| : HostTensor( | |||
| value.dtype(), ValueShape::from(value.shape()), value.storage()) { | |||
| } | |||
| std::string to_string() const override; | |||
| void clear() override { | |||
| m_dtype = {}; | |||
| m_shape = {}; | |||
| m_storage = {}; | |||
| } | |||
| DType dtype() const { return m_dtype; } | |||
| const ValueShape& shape() const { return m_shape; } | |||
| CompNode device() const { return m_storage.comp_node(); } | |||
| @@ -112,31 +101,31 @@ public: | |||
| }; | |||
| /** | |||
| * \brief like DeviceTensorND mixin, but allow scalar value | |||
| * \brief like HostTensorND mixin, but allow scalar value | |||
| * | |||
| */ | |||
| class DeviceValue final : public ValueImpl<DeviceValue, ValueKind::Primitive> { | |||
| class HostValue final : public PrimitiveValue<HostValue, HostTensor> { | |||
| public: | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class DeviceTensor { | |||
| private: | |||
| DType m_dtype; | |||
| ValueShape m_shape; | |||
| DeviceTensorStorage m_storage; | |||
| public: | |||
| DeviceValue(DType dtype, ValueShape shape, DeviceTensorStorage storage) | |||
| DeviceTensor() = default; | |||
| DeviceTensor(DType dtype, ValueShape shape, DeviceTensorStorage storage) | |||
| : m_dtype(dtype), m_shape(shape), m_storage(std::move(storage)) {} | |||
| DeviceValue(const DeviceTensorND& value) | |||
| : DeviceValue( | |||
| DeviceTensor(const DeviceTensorND& value) | |||
| : DeviceTensor( | |||
| value.dtype(), ValueShape::from(value.shape()), value.storage()) { | |||
| } | |||
| std::string to_string() const override; | |||
| void clear() override { | |||
| m_dtype = {}; | |||
| m_shape = {}; | |||
| m_storage = {}; | |||
| } | |||
| DType dtype() const { return m_dtype; } | |||
| const ValueShape& shape() const { return m_shape; } | |||
| CompNode device() const { return m_storage.comp_node(); } | |||
| @@ -145,26 +134,34 @@ public: | |||
| DeviceTensorND as_nd(bool allow_scalar = false) const; | |||
| }; | |||
| class FunctionValue final | |||
| : public MixinValueImpl<FunctionValue, ValueKind::Primitive, GenericFunction> { | |||
| /** | |||
| * \brief like DeviceTensorND mixin, but allow scalar value | |||
| * | |||
| */ | |||
| class DeviceValue final : public PrimitiveValue<DeviceValue, DeviceTensor> { | |||
| public: | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class FunctionValue final : public PrimitiveValue<FunctionValue, GenericFunction> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class DTypeValue final | |||
| : public MixinValueImpl<DTypeValue, ValueKind::Primitive, DType> { | |||
| class DTypeValue final : public PrimitiveValue<DTypeValue, DType> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| class StringValue final | |||
| : public MixinValueImpl<StringValue, ValueKind::Primitive, std::string> { | |||
| class StringValue final : public PrimitiveValue<StringValue, std::string> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| @@ -180,10 +177,9 @@ public: | |||
| std::string message() const { return m_message; } | |||
| }; | |||
| class ErrorValue final | |||
| : public MixinValueImpl<ErrorValue, ValueKind::Primitive, Error> { | |||
| class ErrorValue final : public PrimitiveValue<ErrorValue, Error> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override; | |||
| }; | |||
| @@ -57,7 +57,7 @@ struct Subgraph { | |||
| SmallVector<expr_t> exprs; | |||
| template <typename T, typename F, typename C> | |||
| SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||
| SmallVector<T> apply(Span<T> input_vars, F&& f, C&& c) const { | |||
| std::unordered_map<size_t, T> idx2var; | |||
| mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); | |||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||
| @@ -71,8 +71,7 @@ struct Subgraph { | |||
| for (auto idx : expr.inputs) { | |||
| expr_inputs.push_back(idx2var[idx]); | |||
| } | |||
| SmallVector<T> expr_outputs = | |||
| f(expr.op, std::move(expr_inputs), expr.outputs.size()); | |||
| SmallVector<T> expr_outputs = f(expr.op, expr_inputs, expr.outputs.size()); | |||
| mgb_assert( | |||
| expr_outputs.size() == expr.outputs.size(), "output size mismatch"); | |||
| for (size_t i = 0; i < expr_outputs.size(); ++i) { | |||
| @@ -102,9 +101,9 @@ struct EncodedSubgraph { | |||
| SmallVector<bool> input_mask; | |||
| SmallVector<bool> output_mask; | |||
| template <typename TContainer> | |||
| TContainer encode_inputs(TContainer inputs) const { | |||
| TContainer encoded_inputs; | |||
| template <typename T> | |||
| SmallVector<T> encode_inputs(Span<T> inputs) const { | |||
| SmallVector<T> encoded_inputs; | |||
| size_t index = 0; | |||
| for (auto&& input : inputs) { | |||
| mgb_assert(index < input_mask.size(), "index out of range"); | |||
| @@ -116,9 +115,9 @@ struct EncodedSubgraph { | |||
| return encoded_inputs; | |||
| } | |||
| template <typename TContainer> | |||
| TContainer encode_outputs(TContainer outputs) const { | |||
| TContainer encoded_outputs; | |||
| template <typename T> | |||
| SmallVector<T> encode_outputs(Span<T> outputs) const { | |||
| SmallVector<T> encoded_outputs; | |||
| size_t index = 0; | |||
| for (auto&& output : outputs) { | |||
| mgb_assert(index < output_mask.size(), "index out of range"); | |||
| @@ -130,9 +129,9 @@ struct EncodedSubgraph { | |||
| return encoded_outputs; | |||
| } | |||
| template <typename TContainer> | |||
| TContainer decode_outputs(TContainer outputs) const { | |||
| TContainer decoded_outputs; | |||
| template <typename T> | |||
| SmallVector<T> decode_outputs(Span<T> outputs) const { | |||
| SmallVector<T> decoded_outputs; | |||
| size_t index = 0; | |||
| for (size_t i = 0; i < output_mask.size(); i++) { | |||
| mgb_assert(index < output_mask.size(), "index out of range"); | |||
| @@ -150,8 +149,8 @@ struct EncodedSubgraph { | |||
| EncodedSubgraph result; | |||
| result.input_mask = graph.gen_input_mask(); | |||
| result.output_mask = graph.gen_output_mask(); | |||
| graph.inputs = result.encode_inputs(graph.inputs); | |||
| graph.outputs = result.encode_outputs(graph.outputs); | |||
| graph.inputs = result.encode_inputs<Subgraph::var_t>(graph.inputs); | |||
| graph.outputs = result.encode_outputs<Subgraph::var_t>(graph.outputs); | |||
| result.graph = graph; | |||
| return result; | |||
| } | |||
| @@ -179,11 +178,11 @@ struct EncodedSubgraph { | |||
| } | |||
| template <typename T, typename F, typename C> | |||
| SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||
| auto encoded_inputs = encode_inputs(input_vars); | |||
| SmallVector<T> apply(Span<T> input_vars, F&& f, C&& c) const { | |||
| auto encoded_inputs = encode_inputs<T>(input_vars); | |||
| auto encoded_outputs = | |||
| graph.apply(encoded_inputs, std::forward<F>(f), std::forward<C>(c)); | |||
| return decode_outputs(encoded_outputs); | |||
| graph.apply<T>(encoded_inputs, std::forward<F>(f), std::forward<C>(c)); | |||
| return decode_outputs<T>(encoded_outputs); | |||
| } | |||
| std::string repr() const; | |||
| @@ -280,4 +279,4 @@ public: | |||
| }; | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| } // namespace mgb | |||
| @@ -18,7 +18,7 @@ | |||
| namespace mgb::imperative { | |||
| struct InterpreterInfo { | |||
| class InterpreterValue final : public ObjectValue<InterpreterValue> { | |||
| public: | |||
| using Handle = interpreter::Interpreter::Handle; | |||
| using Channel = interpreter::Interpreter::Channel; | |||
| @@ -46,8 +46,7 @@ private: | |||
| mutable ShapeValue::ref_t m_shape; | |||
| public: | |||
| InterpreterInfo() = default; | |||
| InterpreterInfo(LocalPtr<RAIIHandle> handle, std::string name = {}) | |||
| InterpreterValue(LocalPtr<RAIIHandle> handle, std::string name = {}) | |||
| : m_handle(handle), m_name(name) {} | |||
| const LocalPtr<RAIIHandle>& handle() const { return m_handle; } | |||
| @@ -57,18 +56,14 @@ public: | |||
| ShapeValue::ref_t shape() const; | |||
| std::string name() const { return m_name; } | |||
| }; | |||
| class InterpreterValue final | |||
| : public MixinValueImpl<InterpreterValue, ValueKind::Object, InterpreterInfo> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const override { | |||
| return ssprintf( | |||
| "Handle{ptr=%p, name=%s}", handle().get(), | |||
| imperative::quoted(name()).c_str()); | |||
| } | |||
| void clear() override { m_handle = {}; } | |||
| }; | |||
| /** | |||
| @@ -82,11 +77,12 @@ class InterpreterTransformation final : public Transformation { | |||
| public: | |||
| using Interpreter = interpreter::Interpreter; | |||
| using Handle = Interpreter::Handle; | |||
| using SharedHandle = LocalPtr<InterpreterInfo::RAIIHandle>; | |||
| using SharedHandle = LocalPtr<InterpreterValue::RAIIHandle>; | |||
| using Channel = Interpreter::Channel; | |||
| private: | |||
| std::shared_ptr<Channel> m_channel; | |||
| ObjectType<InterpreterValue> m_value_type{"InterpreterValue"}; | |||
| public: | |||
| explicit InterpreterTransformation(std::shared_ptr<Channel> channel) | |||
| @@ -105,7 +101,7 @@ public: | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| mgb_assert(!value.is<InterpreterValue>()); | |||
| mgb_assert(!value.is(m_value_type)); | |||
| return value; | |||
| } | |||
| @@ -34,7 +34,8 @@ struct BackwardGraphWithClosure { | |||
| std::shared_ptr<OptimizedBackwardGraphResult> backward_graph, | |||
| std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs); | |||
| void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); | |||
| void operator()( | |||
| Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver); | |||
| bool input_has_grad(size_t i) { return backward_graph->input_has_grad[i]; } | |||
| @@ -51,7 +52,7 @@ struct CustomBackward; | |||
| using GradRuleFn = std::function<ValueRefList(Span<ValueRef> inputs, CustomBackward&)>; | |||
| struct CustomBackward { | |||
| using BackwardFn = std::function<ValueRefList(Span<ValueRef>)>; | |||
| using BackwardFn = std::function<SmallVector<ValueRef>(Span<ValueRef>)>; | |||
| using BackwardRule = std::function<std::optional<ValueRefList>( | |||
| const OpDef&, Span<ValueRef>, Span<bool>, CustomBackward&)>; | |||
| BackwardFn m_backward; | |||
| @@ -62,7 +63,8 @@ struct CustomBackward { | |||
| SmallVector<OutputAttr> m_output_attrs; | |||
| public: | |||
| void operator()(ValueRefList grads, std::function<void(size_t, ValueRef)> receiver); | |||
| void operator()( | |||
| Span<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver); | |||
| bool input_has_grad(size_t i) { return m_input_has_grad[i]; } | |||
| bool output_requires_grad(size_t i) { return m_output_attrs[i].requires_grad; } | |||
| @@ -175,7 +177,7 @@ inline GradSlot* GradSlotPtr::operator->() const { | |||
| return &m_fn->m_slots[m_index]; | |||
| } | |||
| class GradValue final : public ValueImpl<GradValue, ValueKind::Object> { | |||
| class GradValue final : public ObjectValue<GradValue> { | |||
| private: | |||
| ValueRef m_value; | |||
| std::shared_ptr<GradKey> m_key; | |||
| @@ -187,14 +189,9 @@ public: | |||
| std::string to_string() const override; | |||
| bool has_key(const std::shared_ptr<GradKey>& key) const { return m_key == key; } | |||
| const GradSlotPtr& slot() const { return m_slot; } | |||
| const GradSlotPtr& slot_for(std::shared_ptr<GradKey> key) const { | |||
| mgb_assert(m_key == key); | |||
| return m_slot; | |||
| } | |||
| std::shared_ptr<GradKey> key() const { return m_key; } | |||
| // std::shared_ptr<GradKey> key() const { return m_key; } | |||
| void clear() override { | |||
| m_slot = {}; | |||
| @@ -216,9 +213,12 @@ private: | |||
| std::vector<std::pair<LocalWeakPtr<GradFn>, std::shared_ptr<OpDef>>> m_tape; | |||
| std::vector<std::pair<LocalPtr<GradFn>, std::shared_ptr<OpDef>>> m_frozen_tape; | |||
| bool m_frozen = false; | |||
| const Type<GradValue>& m_value_type; | |||
| public: | |||
| GradKey() { m_tape.reserve(4 * 1024); } | |||
| GradKey(const Type<GradValue>& value_type) : m_value_type(value_type) { | |||
| m_tape.reserve(4 * 1024); | |||
| } | |||
| void backward(); | |||
| GradValue::ref_t attach(ValueRef tensor, std::function<void(ValueRef)> callback); | |||
| @@ -230,10 +230,9 @@ public: | |||
| }; | |||
| class GradKeyValue final | |||
| : public MixinValueImpl< | |||
| GradKeyValue, ValueKind::Primitive, std::shared_ptr<GradKey>> { | |||
| : public PrimitiveValue<GradKeyValue, std::shared_ptr<GradKey>> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| using PrimitiveValue::PrimitiveValue; | |||
| std::string to_string() const override { | |||
| return ssprintf("GradKey{%s}", (*this)->name().c_str()); | |||
| @@ -242,26 +241,20 @@ public: | |||
| class GradTransformation final : public Transformation { | |||
| private: | |||
| ObjectType<GradValue> m_value_type{"GradValue"}; | |||
| std::shared_ptr<GradKey> m_key; | |||
| std::vector<GradValue::weak_ref_t> m_weak_values; | |||
| size_t m_suppressed = 0; | |||
| public: | |||
| GradTransformation(std::shared_ptr<GradKey> key) : m_key(key) {} | |||
| GradTransformation() { m_key = std::make_shared<GradKey>(m_value_type); } | |||
| auto record_grad(GradValue::ref_t tensor) { | |||
| m_weak_values.push_back(tensor); | |||
| return tensor; | |||
| } | |||
| bool is_grad_value(const ValueRef& value) { | |||
| if (auto* grad_value = value.as<GradValue>()) { | |||
| if (grad_value->has_key(m_key)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool is_grad_value(const ValueRef& value) { return value.is(m_value_type); } | |||
| /** | |||
| * \brief test whether value is related to this GradTransformation | |||
| @@ -273,13 +266,7 @@ public: | |||
| * \return GradValue::ref_t | |||
| */ | |||
| const GradValue::ref_t& as_grad_value(const ValueRef& value) { | |||
| auto&& grad_value = value.as_ref<GradValue>(); | |||
| if (grad_value) { | |||
| if (grad_value->has_key(m_key)) { | |||
| return grad_value; | |||
| } | |||
| } | |||
| return GradValue::ref_t::nil; | |||
| return value.as_ref(m_value_type); | |||
| } | |||
| bool has_key(std::shared_ptr<GradKey> key) { | |||
| @@ -299,6 +286,8 @@ public: | |||
| return value; | |||
| } | |||
| const std::shared_ptr<GradKey>& key() const { return m_key; } | |||
| std::string name() const override { return "GradTransformation"; } | |||
| GenericFunction make_backward_closure(Span<ValueRef> ys); | |||
| @@ -22,32 +22,27 @@ | |||
| namespace mgb::imperative { | |||
| class LazyEvalInfo { | |||
| class LazyEvalValue final : public ObjectValue<LazyEvalValue> { | |||
| private: | |||
| VarNode* m_node = nullptr; | |||
| ValueRef m_bound_data; | |||
| std::string m_name; | |||
| public: | |||
| LazyEvalInfo() = default; | |||
| LazyEvalInfo(VarNode* node, ValueRef bound_data, std::string name) | |||
| LazyEvalValue(VarNode* node, ValueRef bound_data, std::string name) | |||
| : m_node(node), m_bound_data(bound_data), m_name(name) {} | |||
| VarNode* node() const { return m_node; } | |||
| ValueRef bound_data() const { return m_bound_data; } | |||
| std::string name() const { return m_name; } | |||
| }; | |||
| class LazyEvalValue final | |||
| : public MixinValueImpl<LazyEvalValue, ValueKind::Object, LazyEvalInfo> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const override { | |||
| return ssprintf( | |||
| "LazyEvalValue{node=%p, name=%s}", node(), node()->name().c_str()); | |||
| } | |||
| void clear() override {} | |||
| }; | |||
| /** | |||
| @@ -67,6 +62,7 @@ private: | |||
| std::vector<LazyEvalValue::weak_ref_t> m_weak_vars; | |||
| SymbolVar m_io_link = nullptr; | |||
| std::exception_ptr m_graph_exc; | |||
| ObjectType<LazyEvalValue> m_value_type{"LazyEvalValue"}; | |||
| public: | |||
| LazyEvalTransformation(bool no_exec) : m_no_exec(no_exec) { | |||
| @@ -75,7 +71,7 @@ public: | |||
| LazyEvalValue::ref_t record_var( | |||
| VarNode* node, ValueRef bound_data = {}, std::string name = {}) { | |||
| auto lazy_eval_val = LazyEvalValue::make(node, bound_data, name); | |||
| auto lazy_eval_val = m_value_type.make(node, bound_data, name); | |||
| m_weak_vars.push_back(lazy_eval_val); | |||
| return lazy_eval_val; | |||
| } | |||
| @@ -86,7 +82,7 @@ public: | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| mgb_assert(!value.is<LazyEvalValue>()); | |||
| mgb_assert(!value.is(m_value_type)); | |||
| return value; | |||
| } | |||
| @@ -17,7 +17,7 @@ | |||
| namespace mgb::imperative { | |||
| class ScalarValue final : public ValueImpl<ScalarValue, ValueKind::Object> { | |||
| class ScalarValue final : public ObjectValue<ScalarValue> { | |||
| private: | |||
| ValueRef m_value; | |||
| @@ -47,17 +47,21 @@ public: | |||
| class ScalarTransformation final : public Transformation { | |||
| private: | |||
| ShapeValue::ref_t m_empty_shape; // [] | |||
| ObjectType<ScalarValue> m_value_type{"ScalarValue"}; | |||
| public: | |||
| ValueRefList apply_get_attr(const GetAttr& get_attr, Span<ValueRef> inputs); | |||
| ValueRefList apply_transformation( | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| mgb_assert(!value.is<ScalarValue>()); | |||
| mgb_assert(!value.is(m_value_type)); | |||
| return value; | |||
| } | |||
| std::string name() const override { return "ScalarTransformation"; } | |||
| const Type<ScalarValue>& value_type() const { return m_value_type; } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -22,7 +22,7 @@ | |||
| namespace mgb::imperative { | |||
| class SymbolValue final : public ValueImpl<SymbolValue, ValueKind::Object> { | |||
| class SymbolValue final : public ObjectValue<SymbolValue> { | |||
| private: | |||
| VarNode* m_node = nullptr; | |||
| @@ -47,6 +47,7 @@ public: | |||
| class SymbolTransformation final : public Transformation { | |||
| private: | |||
| ComputingGraph* m_graph = nullptr; | |||
| ObjectType<SymbolValue> m_value_type{"SymbolValue"}; | |||
| public: | |||
| SymbolTransformation(ComputingGraph* graph) : m_graph(graph) {} | |||
| @@ -55,12 +56,12 @@ public: | |||
| if (auto* apply_op = op.as<ApplyOp>()) { | |||
| SmallVector<VarNode*> input_nodes; | |||
| for (auto&& input : inputs) { | |||
| input_nodes.push_back(input.cast<SymbolValue>().node()); | |||
| input_nodes.push_back(input.cast(m_value_type).node()); | |||
| } | |||
| auto output_nodes = OpDef::apply_on_var_node(apply_op->op(), input_nodes); | |||
| ValueRefList outputs(output_nodes.size()); | |||
| for (size_t i = 0; i < output_nodes.size(); ++i) { | |||
| outputs[i] = SymbolValue::make(output_nodes[i]); | |||
| outputs[i] = m_value_type.make(output_nodes[i]); | |||
| } | |||
| return outputs; | |||
| } else if (auto* create_tensor = op.as<CreateTensor>()) { | |||
| @@ -69,9 +70,9 @@ public: | |||
| args.kind == CreateTensor::Const, | |||
| "only const value is allowed here"); | |||
| auto* node = opr::ImmutableTensor::make(*m_graph, *args.host, {}).node(); | |||
| return {SymbolValue::make(node)}; | |||
| return {m_value_type.make(node)}; | |||
| } else if (auto* get_attr = op.as<GetAttr>()) { | |||
| auto* node = inputs.as_array<1>()[0].cast<SymbolValue>().node(); | |||
| auto* node = inputs.item().cast(m_value_type).node(); | |||
| switch (get_attr->attr()) { | |||
| case GetAttr::DType: | |||
| return {DTypeValue::make(node->dtype())}; | |||
| @@ -121,11 +122,13 @@ public: | |||
| } | |||
| ValueRef unwrap(ValueRef value) override { | |||
| mgb_assert(!value.is<SymbolValue>(), "SymbolValue doesn't support unwrap"); | |||
| mgb_assert(!value.is(m_value_type), "SymbolValue doesn't support unwrap"); | |||
| return value; | |||
| } | |||
| std::string name() const override { return "SymbolTransformation"; } | |||
| const Type<SymbolValue>& value_type() const { return m_value_type; } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -100,22 +100,15 @@ public: | |||
| } | |||
| }; | |||
| class TracingInfo { | |||
| class TracingValue final : public ObjectValue<TracingValue> { | |||
| private: | |||
| ValueRef m_value = {}; | |||
| size_t m_id = 0; | |||
| public: | |||
| TracingInfo() = default; | |||
| TracingInfo(ValueRef value, size_t id) : m_value(value), m_id(id) {} | |||
| TracingValue(ValueRef value, size_t id) : m_value(value), m_id(id) {} | |||
| ValueRef value() const { return m_value; } | |||
| size_t id() const { return m_id; } | |||
| }; | |||
| class TracingValue final | |||
| : public MixinValueImpl<TracingValue, ValueKind::Object, TracingInfo> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const override { | |||
| return ssprintf( | |||
| @@ -126,6 +119,8 @@ public: | |||
| void on_watch() override { value().watch(); } | |||
| void on_unwatch() override { value().unwatch(); } | |||
| void clear() override { m_value = {}; } | |||
| }; | |||
| /** | |||
| @@ -146,6 +141,7 @@ private: | |||
| std::vector<TracingValue::weak_ref_t> m_weak_vars; | |||
| bool m_capture_as_const = false; | |||
| bool m_record_input_shapes = false; | |||
| ObjectType<TracingValue> m_value_type{"TracingValue"}; | |||
| public: | |||
| TracingTransformation(bool capture_as_const, bool record_input_shapes) | |||
| @@ -162,7 +158,7 @@ public: | |||
| */ | |||
| TypedValueRef<TracingValue> record_var(ValueRef value, bool capture, VarKind kind) { | |||
| size_t id = m_vars.size(); | |||
| auto wrapped_value = TracingValue::make(value, id); | |||
| auto wrapped_value = m_value_type.make(value, id); | |||
| m_vars.push_back({id, value.dtype(), value.device()}); | |||
| auto& var = m_vars.back(); | |||
| if (capture) { | |||
| @@ -179,7 +175,7 @@ public: | |||
| return wrapped_value; | |||
| } | |||
| ValueRef unwrap_var(ValueRef value) { | |||
| if (auto* tracing_value = value.as<TracingValue>()) { | |||
| if (auto* tracing_value = value.as(m_value_type)) { | |||
| return tracing_value->value(); | |||
| } | |||
| return value; | |||
| @@ -189,7 +185,7 @@ public: | |||
| const Operator& op, Span<ValueRef> inputs) override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| if (auto* tracing_value = value.as<TracingValue>()) { | |||
| if (auto* tracing_value = value.as(m_value_type)) { | |||
| return tracing_value->value(); | |||
| } | |||
| return value; | |||
| @@ -234,7 +230,7 @@ public: | |||
| std::function<void(std::exception_ptr)> exc_setter; | |||
| }; | |||
| class TracedInfo { | |||
| class TracedValue final : public ObjectValue<TracedValue> { | |||
| private: | |||
| size_t m_id = 0; | |||
| VarInfo* m_var = nullptr; | |||
| @@ -244,8 +240,7 @@ public: | |||
| mutable CompNodeValue::ref_t m_comp_node; | |||
| public: | |||
| TracedInfo() = default; | |||
| TracedInfo(size_t id, VarInfo* var, VarAccessor* accessor) | |||
| TracedValue(size_t id, VarInfo* var, VarAccessor* accessor) | |||
| : m_id(id), m_var(var), m_accessor(accessor) {} | |||
| size_t id() const { return m_id; } | |||
| ShapeValue::ref_t shape() const; | |||
| @@ -256,16 +251,12 @@ public: | |||
| void set_exception(std::exception_ptr exc) const { | |||
| m_accessor->exc_setter(exc); | |||
| } | |||
| }; | |||
| class TracedValue final | |||
| : public MixinValueImpl<TracedValue, ValueKind::Object, TracedInfo> { | |||
| public: | |||
| using MixinValueImpl::MixinValueImpl; | |||
| std::string to_string() const override { | |||
| return ssprintf("TracedValue{\"id\"=%zu}", id()); | |||
| } | |||
| void clear() override {} | |||
| }; | |||
| private: | |||
| @@ -280,9 +271,12 @@ private: | |||
| std::function<bool(ValueRef, ValueRef)> m_value_comparator; | |||
| bool m_input_shape_static; | |||
| std::mutex m_mutex; | |||
| std::condition_variable m_cv; | |||
| std::exception_ptr m_graph_exc; | |||
| int m_graph_status = 0; // 0 = stop, 1 = running, 2 = finalizing | |||
| std::vector<std::shared_ptr<BoxBase>> m_boxes; | |||
| ComputingGraph::OutputSpec m_output_spec; | |||
| ObjectType<TracedValue> m_value_type{"TracedValue"}; | |||
| public: | |||
| CompiledTransformation(TraceResult result, bool input_shape_static) | |||
| @@ -292,6 +286,27 @@ public: | |||
| m_graph = ComputingGraph::make(); | |||
| options().no_force_inplace = true; | |||
| options().async_exec_level = 0b100; | |||
| m_graph_executor = std::thread([&] { /* worker loop: runs m_executable each time m_graph_status leaves 0 */ | |||
| while (true) { | |||
| std::unique_lock lock{m_mutex}; | |||
| m_cv.wait(lock, [&] { return m_graph_status != 0; }); /* sleep while the status is 0 (stop) */ | |||
| lock.unlock(); | |||
| if (m_graph_status == 2) { /* NOTE(review): read after unlock(), i.e. without holding m_mutex -- confirm no writer can run concurrently */ | |||
| break; /* 2 = finalizing: leave the loop so the thread can be joined */ | |||
| } | |||
| try { | |||
| m_executable->execute(); | |||
| m_executable->wait(); | |||
| } catch (...) { | |||
| auto exc = std::current_exception(); /* hand the failure to set_exception instead of letting it escape the thread */ | |||
| set_exception(exc); | |||
| } | |||
| lock.lock(); | |||
| m_graph_status = 0; /* back to stop, then wake everything waiting on m_cv */ | |||
| lock.unlock(); | |||
| m_cv.notify_all(); | |||
| } | |||
| }); | |||
| } | |||
| ComputingGraph& graph() { return *m_graph; } | |||
| @@ -350,7 +365,7 @@ public: | |||
| void on_unregister() noexcept override; | |||
| ValueRef unwrap(ValueRef value) override { | |||
| mgb_assert(!value.is<TracedValue>()); | |||
| mgb_assert(!value.is(m_value_type)); | |||
| return value; | |||
| } | |||
| @@ -368,6 +383,15 @@ public: | |||
| m_boxes.push_back(box); | |||
| return box; | |||
| } | |||
| ~CompiledTransformation() { /* asks the graph executor thread to finish and waits for it */ | |||
| { | |||
| MGB_LOCK_GUARD(m_mutex); | |||
| m_graph_status = 2; /* 2 = finalizing; the executor loop breaks when it sees this state */ | |||
| } | |||
| m_cv.notify_all(); /* wake the executor in case it is waiting on m_cv */ | |||
| m_graph_executor.join(); | |||
| } | |||
| }; | |||
| } // namespace mgb::imperative | |||
| @@ -11,7 +11,9 @@ | |||
| #pragma once | |||
| #include <optional> | |||
| #include <typeindex> | |||
| #include <vector> | |||
| #include "megbrain/utils/mempool.h" | |||
| #include "megbrain/utils/metahelper.h" | |||
| @@ -34,7 +34,7 @@ public: | |||
| Span(const T* begin, const T* end) : m_begin{begin}, m_end{end} {} | |||
| Span(const T* begin, size_t size) : Span(begin, begin + size) {} | |||
| template <typename TContainer> | |||
| Span(TContainer& container) : Span(container.data(), container.size()) {} | |||
| Span(const TContainer& container) : Span(container.data(), container.size()) {} | |||
| const T* begin() const { return m_begin; } | |||
| const T* end() const { return m_end; } | |||
| const T* data() const { return m_begin; } | |||
| @@ -2,7 +2,10 @@ | |||
| #include <chrono> | |||
| #include <iostream> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| namespace mgb { | |||
| @@ -18,7 +21,7 @@ public: | |||
| private: | |||
| clock_t::duration m_duration = clock_t::duration{0}; | |||
| size_t m_timing = 0; | |||
| const char* m_name = nullptr; | |||
| std::string m_name; | |||
| uint64_t m_count = 0; | |||
| size_t m_enabled = 1; | |||
| bool m_default_enabled = true; | |||
| @@ -42,7 +45,8 @@ private: | |||
| } | |||
| if (timer.m_enabled) { | |||
| if (!--timer.m_timing) { | |||
| timer.m_duration += (clock_t::now() - start); | |||
| auto duration = (clock_t::now() - start); | |||
| timer.m_duration += duration; | |||
| } | |||
| timer.m_count++; | |||
| } | |||
| @@ -67,13 +71,10 @@ private: | |||
| } | |||
| }; | |||
| using TimeScope = TimeScopeRecursive; | |||
| public: | |||
| Timer(const char* name, bool default_enabled); | |||
| Timer(std::string name, bool default_enabled = true); | |||
| const char* name() { return m_name; } | |||
| auto time_scope() { return TimeScope(*this); } | |||
| std::string name() { return m_name; } | |||
| auto time_scope_recursive() { return TimeScopeRecursive(*this); }; | |||
| auto enable_scope() { return EnableScope(*this); } | |||
| void reset() { | |||
| @@ -88,7 +89,14 @@ public: | |||
| } // namespace stats | |||
| struct Stats { | |||
| static inline std::vector<stats::Timer*> sm_timers; | |||
| struct TimerNode { /* node of the timer tree; Timer's constructor adds one level per '.'-separated name component */ | |||
| std::map<std::string, std::unique_ptr<TimerNode>> children; /* owned sub-nodes, keyed by name component */ | |||
| stats::Timer* timer = nullptr; /* not owned; stays nullptr on nodes that only group children */ | |||
| TimerNode() {} | |||
| }; | |||
| static inline TimerNode sm_root; | |||
| // register your timers here | |||
| // for example: | |||
| @@ -97,33 +105,84 @@ struct Stats { | |||
| // | |||
| // then use MGE_TIMER_SCOPE(mytimer) to collect durations in your code | |||
| static void print() { | |||
| std::vector<const char*> unused_timers; | |||
| for (auto* timer : sm_timers) { | |||
| if (timer->count() == 0) { | |||
| unused_timers.push_back(timer->name()); | |||
| } else { | |||
| printf("%s costs %ld ns, happens %ld times\n", timer->name(), | |||
| timer->get().count(), timer->count()); | |||
| /** | |||
| * \brief print one timer node and, recursively, its children | |||
| * | |||
| * A node that carries a timer prints that timer's duration and hit count. A | |||
| * node without a timer is a pure grouping node: it prints its name, then its | |||
| * children, then the totals accumulated over those children. | |||
| * | |||
| * \param name label printed for this node | |||
| * \param node node to print | |||
| * \param indent number of leading spaces; children are printed at indent + 4 | |||
| * \return {duration, hits} of this node: the timer's own numbers when a timer | |||
| * is attached, otherwise the sum over the children | |||
| */ | |||
| static std::pair<long, long> print_node( | |||
| std::string name, TimerNode& node, size_t indent = 0) { | |||
| auto print_indent = [&] { | |||
| for (size_t i = 0; i < indent; ++i) { | |||
| printf(" "); | |||
| } | |||
| }; | |||
| long ns = 0, count = 0; | |||
| if (auto* timer = node.timer) { | |||
| print_indent(); | |||
| printf("%s costs %'ld ns, hits %'ld times\n", name.c_str(), | |||
| (long)timer->get().count(), (long)timer->count()); | |||
| ns = timer->get().count(); | |||
| count = timer->count(); | |||
| } | |||
| if (!node.children.empty()) { | |||
| // only a node without its own timer aggregates its children | |||
| bool collect_children = node.timer == nullptr; | |||
| if (collect_children) { | |||
| print_indent(); | |||
| printf("%s:\n", name.c_str()); | |||
| } | |||
| // accumulate into the function-level ns / count: redeclaring them in | |||
| // this scope would shadow the returned pair, so a grouping node would | |||
| // report {0, 0} to its parent and nested totals would be too small | |||
| for (auto&& child : node.children) { | |||
| auto [child_ns, child_count] = | |||
| print_node(child.first, *child.second, indent + 4); | |||
| if (collect_children) { | |||
| ns += child_ns; | |||
| count += child_count; | |||
| } | |||
| } | |||
| if (collect_children) { | |||
| print_indent(); | |||
| printf("total costs %'ld ns, hits %'ld times\n", ns, count); | |||
| } | |||
| } | |||
| return {ns, count}; | |||
| } | |||
| if (!unused_timers.empty()) { | |||
| printf("%zu timers unused\n", unused_timers.size()); | |||
| /** | |||
| * \brief print every top-level node under sm_root, each with its subtree | |||
| */ | |||
| static void print() { | |||
| for (auto& [name, node] : sm_root.children) { | |||
| print_node(name, *node); | |||
| } | |||
| } | |||
| static void reset() { | |||
| for (auto* timer : sm_timers) { | |||
| timer->reset(); | |||
| } | |||
| auto reset_node = [](TimerNode& node, auto&& reset_node) -> void { | |||
| if (auto* timer = node.timer) { | |||
| timer->reset(); | |||
| } | |||
| for (auto&& child : node.children) { | |||
| reset_node(*child.second, reset_node); | |||
| } | |||
| }; | |||
| reset_node(sm_root, reset_node); | |||
| } | |||
| }; | |||
| inline stats::Timer::Timer(const char* name, bool default_enabled) | |||
| inline stats::Timer::Timer(std::string name, bool default_enabled) /* registers this timer in the tree under Stats::sm_root, one level per '.'-separated component of name */ | |||
| : m_name(name), m_default_enabled(default_enabled) { | |||
| Stats::sm_timers.push_back(this); | |||
| std::vector<std::string> terms; /* NOTE(review): not used anywhere below */ | |||
| Stats::TimerNode* node = &Stats::sm_root; | |||
| while (true) { | |||
| auto pos = name.find("."); | |||
| if (pos == std::string::npos) { /* last component: this node carries the timer */ | |||
| auto& child = node->children[name]; | |||
| child = std::make_unique<Stats::TimerNode>(); /* NOTE(review): replaces any node already stored under this key together with its children, e.g. when "a" is registered after "a.b" -- confirm intended */ | |||
| node = child.get(); | |||
| node->timer = this; | |||
| break; | |||
| } else { /* intermediate component: reuse the node if present, else create it, then descend */ | |||
| auto& child = node->children[name.substr(0, pos)]; | |||
| if (!child) { | |||
| child = std::make_unique<Stats::TimerNode>(); | |||
| } | |||
| node = child.get(); | |||
| name = name.substr(pos + 1); | |||
| } | |||
| } | |||
| } | |||
| #if MGE_ENABLE_STATS | |||
| @@ -50,18 +50,70 @@ class Operator; | |||
| class ValueRefList; | |||
| /** | |||
| * \brief base class of all value types | |||
| * | |||
| * Carries a name only. Equality is identity: two IType objects compare equal | |||
| * only when they are the same object, whatever their names are. | |||
| */ | |||
| class IType : public NonCopyableObj { | |||
| private: | |||
| std::string m_name; | |||
| // TODO: count values, or make a linked list | |||
| public: | |||
| IType(std::string name) : m_name(std::move(name)) {} | |||
| const std::string& name() const { return m_name; } | |||
| bool operator==(const IType& rhs) const { return this == &rhs; } | |||
| bool operator!=(const IType& rhs) const { return this != &rhs; } | |||
| }; | |||
| /** | |||
| * \brief type of values. | |||
| * | |||
| * The constructor is protected, so a Type<T> is only created through a derived | |||
| * class. | |||
| * | |||
| * \tparam T ctype of value | |||
| */ | |||
| template <typename T> | |||
| class Type : public IType { | |||
| protected: | |||
| Type(std::string name) : IType(std::move(name)) {} | |||
| // TODO: each type owns an allocator | |||
| public: | |||
| /** | |||
| * \brief helper function for constructing a value | |||
| * | |||
| * \tparam TArgs types of arguments | |||
| * \param args arguments | |||
| * \return TypedValueRef<T> reference of value | |||
| */ | |||
| template <typename... TArgs> | |||
| TypedValueRef<T> make(TArgs&&... args) const; | |||
| }; | |||
| /** | |||
| * \brief type of primitive values. | |||
| * | |||
| * The constructor is private and the class exposes one static object, | |||
| * `instance`. code() returns the value taken from T::TYPE_CODE. | |||
| * | |||
| * \tparam T ctype of value | |||
| */ | |||
| template <typename T> | |||
| class Type { | |||
| class PrimitiveType : public Type<T> { | |||
| private: | |||
| const size_t m_code = T::TYPE_CODE; | |||
| PrimitiveType(); | |||
| public: | |||
| inline size_t code() const { return m_code; } | |||
| static inline PrimitiveType instance; | |||
| }; | |||
| enum class ValueKind { | |||
| Primitive, | |||
| Object, | |||
| /** | |||
| * \brief type of object values. | |||
| * | |||
| * Has a public constructor, so each owner creates its own instance with a | |||
| * descriptive name. | |||
| * | |||
| * \tparam T ctype of value | |||
| */ | |||
| template <typename T> | |||
| class ObjectType : public Type<T> { | |||
| public: | |||
| // name is taken by value, so move it into the base instead of copying again | |||
| ObjectType(std::string name) : Type<T>(std::move(name)) {} | |||
| }; | |||
| /** | |||
| @@ -71,9 +123,8 @@ enum class ValueKind { | |||
| * and only the tail node is valid. ValueRef stores a value node, and it may be | |||
| * an invalid internal node. When you dereference it, it will check its successor, | |||
| * automatically find the tail node and return. This list would be modified to reduce | |||
| * list length by change value's successor, but a ValueRef always has steady m_storage | |||
| * when not explicitly modified. | |||
| * So we use m_storage to identify a ValueRef ( hash / equility / id ). | |||
| * list length by changing a value's successor, but a steady id is kept in ValueRef | |||
| * so we can use it to identify a ValueRef ( hash / equality / id ). | |||
| */ | |||
| class ValueRef { | |||
| public: | |||
| @@ -93,9 +144,7 @@ private: | |||
| */ | |||
| storage_t& storage() const; | |||
| const Value* as(size_t typecode) const; | |||
| bool is(size_t typecode) const; | |||
| const Value* as(const IType& type) const; | |||
| public: | |||
| ValueRef() = default; | |||
| @@ -103,45 +152,76 @@ public: | |||
| /** | |||
| * \brief whether value is instance of target type or not | |||
| * | |||
| * \tparam TValue target type | |||
| * \return true if type of value is TValue | |||
| * \return false if empty or type of value is not TValue | |||
| * \param type target type | |||
| * \return true if type of value is instance of type | |||
| * \return false if empty or type of value is not instance of type | |||
| */ | |||
| template <typename TValue> | |||
| inline bool is(Type<TValue> type = {}) const; | |||
| bool is(const IType& type) const; | |||
| /** | |||
| * \brief try cast value as target type | |||
| * | |||
| * \tparam TValue target type | |||
| * \param type target type | |||
| * \return TValue* raw pointer if success, otherwise nullptr | |||
| */ | |||
| template <typename TValue> | |||
| inline const TValue* as(Type<TValue> type = {}) const; | |||
| inline const TValue* as(const Type<TValue>& type) const; | |||
| /** | |||
| * \brief cast value to target type | |||
| * | |||
| * \tparam TValue target type | |||
| * \param type target type | |||
| * \return TValue& reference of value | |||
| */ | |||
| template <typename TValue> | |||
| inline const TValue& cast(Type<TValue> type = {}) const; | |||
| inline const TValue& cast(const Type<TValue>& type) const; | |||
| /** | |||
| * \brief like as(), but returns TypedValueRef instead | |||
| * | |||
| * \tparam TValue target type | |||
| * \param type target type | |||
| * \return TypedValueRef<TValue> reference if success, otherwise empty reference | |||
| */ | |||
| template <typename TValue> | |||
| inline const TypedValueRef<TValue>& as_ref(Type<TValue> type = {}) const; | |||
| inline const TypedValueRef<TValue>& as_ref(const Type<TValue>& type) const; | |||
| /** | |||
| * \brief like cast(), but allows an empty value and returns TypedValueRef instead | |||
| * | |||
| * \param type target type | |||
| * \return TypedValueRef<TValue> reference if success, empty reference if value is empty | |||
| */ | |||
| template <typename TValue> | |||
| inline const TypedValueRef<TValue>& cast_ref(const Type<TValue>& type) const; | |||
| template <typename TValue> | |||
| inline std::enable_if_t<TValue::is_primitive, bool> is() const { | |||
| return is(PrimitiveType<TValue>::instance); | |||
| } | |||
| template <typename TValue> | |||
| inline std::enable_if_t<TValue::is_primitive, const TValue*> as() const { | |||
| return as(PrimitiveType<TValue>::instance); | |||
| } | |||
| template <typename TValue> | |||
| inline std::enable_if_t<TValue::is_primitive, const TValue&> cast() const { | |||
| return cast(PrimitiveType<TValue>::instance); | |||
| } | |||
| template <typename TValue> | |||
| inline const TypedValueRef<TValue>& cast_ref(Type<TValue> type = {}) const; | |||
| inline std::enable_if_t<TValue::is_primitive, const TypedValueRef<TValue>&> as_ref() | |||
| const { | |||
| return as_ref(PrimitiveType<TValue>::instance); | |||
| } | |||
| template <typename TValue> | |||
| void on_cast_failure() const; | |||
| inline std::enable_if_t<TValue::is_primitive, const TypedValueRef<TValue>&> | |||
| cast_ref() const { | |||
| return cast_ref(PrimitiveType<TValue>::instance); | |||
| } | |||
| void on_cast_failure(const IType& type) const; | |||
| operator bool() const { return bool(m_storage); } | |||
| @@ -172,8 +252,6 @@ public: | |||
| friend class ValueWeakRef; | |||
| template <typename> | |||
| friend class TypedValueRef; | |||
| template <typename, ValueKind> | |||
| friend class ValueImpl; | |||
| friend ValueRefList apply(const Operator& op, Span<ValueRef> inputs); | |||
| }; | |||
| @@ -195,7 +273,8 @@ protected: | |||
| public: | |||
| ValueWeakRef() = default; | |||
| ValueWeakRef(ValueRef value) : m_id(value.id()), m_storage(value.m_storage) {} | |||
| ValueWeakRef(const ValueRef& value) | |||
| : m_id(value.id()), m_storage(value.m_storage) {} | |||
| /** | |||
| * \brief try promote to ValueRef | |||
| @@ -218,19 +297,15 @@ public: | |||
| class Value : public NonCopyableObj { | |||
| private: | |||
| uint64_t m_id = std::numeric_limits<uint64_t>::max(); | |||
| size_t m_typecode = 0; | |||
| const IType* m_type = nullptr; | |||
| ValueRef m_successor; | |||
| size_t m_watching = 0; | |||
| protected: | |||
| Value(size_t typecode); | |||
| Value(); | |||
| public: | |||
| size_t typecode() const { return m_typecode; } | |||
| const std::type_index type() const { return registered_types()[m_typecode]; } | |||
| static size_t register_type(std::type_index type); | |||
| static const std::vector<std::type_index>& registered_types(); | |||
| const IType& type() const { return *m_type; } | |||
| static void register_value(ValueRef value); | |||
| static ValueRef get_value_by_id(uint64_t id); | |||
| @@ -251,11 +326,12 @@ public: | |||
| friend class ValueRef; | |||
| friend class ValueWeakRef; | |||
| template <typename, ValueKind> | |||
| friend class ValueImpl; | |||
| template <typename T> | |||
| friend class TypedValueRef; | |||
| template <typename T> | |||
| friend class Type; | |||
| ~Value(); | |||
| private: | |||
| @@ -267,30 +343,17 @@ private: | |||
| * | |||
| * \tparam T type of value | |||
| */ | |||
| template <typename T, ValueKind Kind> | |||
| class ValueImpl : public Value { | |||
| template <typename T> | |||
| class ObjectValue : public Value { | |||
| protected: | |||
| ValueImpl() : Value(TYPE_CODE) {} | |||
| ObjectValue() {} | |||
| public: | |||
| using ref_t = TypedValueRef<T>; | |||
| using weak_ref_t = TypedValueWeakRef<T>; | |||
| static inline const size_t TYPE_CODE = [] { return register_type(typeid(T)); }(); | |||
| static constexpr ValueKind KIND = Kind; | |||
| /** | |||
| * \brief helper function for construct a value | |||
| * | |||
| * \tparam TArgs types of arguments | |||
| * \param args arguments | |||
| * \return TypedValueRef<T> reference of value | |||
| */ | |||
| template <typename... TArgs> | |||
| static MGB_NOINLINE TypedValueRef<T> make(TArgs&&... args) { | |||
| static_assert(std::is_final_v<T>); | |||
| return ValueRef::make(LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...)); | |||
| } | |||
| static constexpr bool is_primitive = false; | |||
| static constexpr bool is_object = true; | |||
| }; | |||
| /** | |||
| @@ -299,74 +362,89 @@ public: | |||
| * \tparam T type of value | |||
| * \tparam TMixin type of mixin class | |||
| */ | |||
| template <typename T, ValueKind Kind, typename TMixin> | |||
| class MixinValueImpl : public ValueImpl<T, Kind>, public TMixin { | |||
| template <typename T, typename TMixin> | |||
| class PrimitiveValue : public Value, public TMixin { | |||
| public: | |||
| using ref_t = TypedValueRef<T>; | |||
| using weak_ref_t = TypedValueWeakRef<T>; | |||
| using TMixin::TMixin; | |||
| MixinValueImpl(TMixin mixin) : TMixin(std::move(mixin)) {} | |||
| PrimitiveValue(TMixin&& mixin) : TMixin(std::move(mixin)) {} | |||
| PrimitiveValue(const TMixin& mixin) : TMixin(mixin) {} | |||
| public: | |||
| void clear() override final { ((TMixin&)*this) = {}; } | |||
| bool eq(const TMixin& value) const { return ((const TMixin&)*this) == value; } | |||
| /** | |||
| * \brief helper function for construct a value | |||
| * | |||
| * \tparam TArgs types of arguments | |||
| * \param args arguments | |||
| * \return TypedValueRef<T> reference of value | |||
| */ | |||
| template <typename... TArgs> | |||
| static TypedValueRef<T> make(TArgs&&... args) { | |||
| return PrimitiveType<T>::instance.make(std::forward<TArgs&&>(args)...); | |||
| } | |||
| static constexpr bool is_primitive = true; | |||
| static constexpr bool is_object = false; | |||
| }; | |||
| template <typename T> | |||
| PrimitiveType<T>::PrimitiveType() : Type<T>(typeid(T).name()) { | |||
| static_assert(std::is_base_of_v<Value, T>); | |||
| static_assert(!std::is_base_of_v<ObjectValue<T>, T>); | |||
| } | |||
| inline ValueRef::ValueRef(storage_t storage) { | |||
| // mgb_assert(storage); | |||
| m_storage = storage; | |||
| m_id = m_storage->m_id; | |||
| } | |||
| template <typename TValue> | |||
| inline const TValue* ValueRef::as(Type<TValue> type) const { | |||
| // auto _ = Stats::time_value_as.time_scope(); | |||
| inline const TValue* ValueRef::as(const Type<TValue>& type) const { | |||
| static_assert(std::is_base_of_v<Value, TValue>); | |||
| return static_cast<const TValue*>(as(type.code())); | |||
| return static_cast<const TValue*>(as((const IType&)type)); | |||
| } | |||
| template <typename TValue> | |||
| inline const TValue& ValueRef::cast(Type<TValue> type) const { | |||
| // auto _ = Stats::time_value_cast.time_scope(); | |||
| inline const TValue& ValueRef::cast(const Type<TValue>& type) const { | |||
| auto* ptr = as<TValue>(type); | |||
| if (mgb_unlikely(!ptr)) { | |||
| on_cast_failure<TValue>(); | |||
| on_cast_failure(type); | |||
| } | |||
| return static_cast<const TValue&>(*ptr); | |||
| } | |||
| template <typename TValue> | |||
| inline bool ValueRef::is(Type<TValue> type) const { | |||
| // auto _ = Stats::time_value_is.time_scope(); | |||
| return is(type.code()); | |||
| } | |||
| template <typename TValue> | |||
| inline const TypedValueRef<TValue>& ValueRef::as_ref(Type<TValue> type) const { | |||
| if (!is<TValue>(type)) { | |||
| inline const TypedValueRef<TValue>& ValueRef::as_ref(const Type<TValue>& type) const { | |||
| if (!is(type)) { | |||
| return TypedValueRef<TValue>::nil; | |||
| } | |||
| return *reinterpret_cast<const TypedValueRef<TValue>*>(this); | |||
| } | |||
| template <typename TValue> | |||
| inline const TypedValueRef<TValue>& ValueRef::cast_ref(Type<TValue> type) const { | |||
| inline const TypedValueRef<TValue>& ValueRef::cast_ref(const Type<TValue>& type) const { | |||
| if (!m_storage) { | |||
| return TypedValueRef<TValue>::nil; | |||
| } | |||
| if (mgb_unlikely(!is<TValue>(type))) { | |||
| on_cast_failure<TValue>(); | |||
| if (mgb_unlikely(!is(type))) { | |||
| on_cast_failure(type); | |||
| } | |||
| return *reinterpret_cast<const TypedValueRef<TValue>*>(this); | |||
| } | |||
| template <typename TValue> | |||
| void ValueRef::on_cast_failure() const { | |||
| inline void ValueRef::on_cast_failure(const IType& type) const { | |||
| // if this is ErrorValue, rethrow directly | |||
| storage()->try_rethrow(); | |||
| mgb_assert( | |||
| storage()->m_typecode != TValue::TYPE_CODE, "expect type %s, got %s", | |||
| typeid(TValue).name(), to_string().c_str()); | |||
| storage()->type() != type, "expect type %s, got %s", type.name().c_str(), | |||
| to_string().c_str()); | |||
| } | |||
| /** | |||
| @@ -382,26 +460,10 @@ private: | |||
| public: | |||
| TypedValueRef() = default; | |||
| const T& operator*() const { | |||
| if constexpr (T::KIND == ValueKind::Object) { | |||
| return this->template cast<T>(); | |||
| } else if constexpr (T::KIND == ValueKind::Primitive) { | |||
| if (!m_storage) { | |||
| on_cast_failure<T>(); | |||
| } | |||
| return static_cast<const T&>(*m_storage); | |||
| } else { | |||
| static_assert(!std::is_same_v<T, T>); | |||
| } | |||
| } | |||
| const T* operator->() const { | |||
| if constexpr (T::KIND == ValueKind::Object) { | |||
| return this->template as<T>(); | |||
| } else if constexpr (T::KIND == ValueKind::Primitive) { | |||
| return static_cast<const T*>(m_storage.get()); | |||
| } else { | |||
| static_assert(!std::is_same_v<T, T>); | |||
| } | |||
| mgb_assert(m_storage, "empty storage"); | |||
| return static_cast<const T&>(*m_storage); | |||
| } | |||
| const T* operator->() const { return static_cast<const T*>(m_storage.get()); } | |||
| /** | |||
| * \brief reset underlying value to another value | |||
| @@ -409,7 +471,7 @@ public: | |||
| * \param successor new value | |||
| */ | |||
| inline void reset(ValueRef successor) { | |||
| static_assert(T::KIND == ValueKind::Object); | |||
| static_assert(std::is_base_of_v<ObjectValue<T>, T>); | |||
| mgb_assert(m_storage); | |||
| mgb_assert(!m_storage->m_successor); | |||
| if (m_storage->m_watching) { | |||
| @@ -422,25 +484,19 @@ public: | |||
| static inline const TypedValueRef nil; | |||
| friend class ValueRef; | |||
| template <typename, ValueKind> | |||
| friend class ValueImpl; | |||
| friend class Type<T>; | |||
| friend class TypedValueWeakRef<T>; | |||
| }; | |||
| template <typename T> | |||
| class TypedValueWeakRef : public ValueWeakRef { | |||
| private: | |||
| TypedValueWeakRef(const ValueRef& value) : ValueWeakRef(value) {} | |||
| TypedValueWeakRef(const ValueWeakRef& value) : ValueWeakRef(value) {} | |||
| public: | |||
| TypedValueWeakRef(ValueRef value) : ValueWeakRef(value) {} | |||
| TypedValueWeakRef(ValueWeakRef value) : ValueWeakRef(value) {} | |||
| TypedValueRef<T> lock() { | |||
| auto value = ValueWeakRef::lock(); | |||
| if (value) { | |||
| return value.template as_ref<T>(); | |||
| } else { | |||
| return {}; | |||
| } | |||
| } | |||
| TypedValueWeakRef(const TypedValueRef<T>& value) : ValueWeakRef(value) {} | |||
| TypedValueRef<T> lock() { return (TypedValueRef<T>)ValueWeakRef::lock(); } | |||
| }; | |||
| // TODO: add proxy value type, which is meant to be reset in the end | |||
| @@ -509,10 +565,14 @@ inline ValueRefList::ValueRefList(ValueRef item) : m_data(inline_storage()), m_s | |||
| m_data[0] = std::move(item); | |||
| } | |||
| /*class ValueRefList : public SmallVector<ValueRef, 1> { | |||
| public: | |||
| using SmallVector::SmallVector; | |||
| };*/ | |||
| template <typename T> | |||
| template <typename... TArgs> | |||
| TypedValueRef<T> Type<T>::make(TArgs&&... args) const { | |||
| static_assert(std::is_final_v<T>); | |||
| auto storage = LocalPtr<Value>::make<T>(std::forward<TArgs&&>(args)...); | |||
| storage->m_type = this; | |||
| return ValueRef::make(std::move(storage)); | |||
| } | |||
| } // namespace imperative | |||
| } // namespace mgb | |||
| @@ -123,7 +123,7 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
| } | |||
| } | |||
| inputs.clear(); | |||
| auto input_grads = result.graph.apply( | |||
| auto input_grads = result.graph.apply<TensorPtr>( | |||
| backward_graph_inputs, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| mgb_assert(input_grads.size() == input_has_grad.size()); | |||
| @@ -177,7 +177,7 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
| } | |||
| } | |||
| inputs.clear(); | |||
| auto input_grads = result.graph.apply( | |||
| auto input_grads = result.graph.apply<TensorPtr>( | |||
| backward_graph_inputs, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| mgb_assert(input_grads.size() == input_has_grad.size()); | |||
| @@ -244,11 +244,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads = expand_grads( | |||
| bg.output_mask, | |||
| bg.graph.apply( | |||
| bg.graph.apply<TensorPtr>( | |||
| backward_graph_inputs, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; })); | |||
| auto precomp = obg.precomp.apply( | |||
| auto precomp = obg.precomp.apply<TensorPtr>( | |||
| SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; }); | |||
| ASSERT_EQ(precomp.size(), 2); | |||
| @@ -261,7 +261,7 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
| obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
| auto grads2 = expand_grads( | |||
| obg.input_has_grad, | |||
| obg.backward.apply( | |||
| obg.backward.apply<TensorPtr>( | |||
| backward_inputs, apply_shared_on_physical_tensor, | |||
| [&](auto&& x) { return x; })); | |||