GitOrigin-RevId: 6edd577a70
@@ -50,8 +50,6 @@ class autocast:
        self._origin_enabled = None
        self._origin_high = None
        self._origin_low = None
        self._origin_compute_mode = None
        self._origin_configs = None
    def __enter__(self):
@@ -75,7 +73,7 @@ class autocast:
        amp._set_amp_high_prec_dtype(self._origin_high)
        amp._set_amp_low_prec_dtype(self._origin_low)
        _config._reset_execution_config(*self._origin_compute_mode)
        _config._reset_execution_config(*self._origin_configs)
    def __call__(self, func):
        @functools.wraps(func)
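For context, a minimal sketch of how this autocast API is typically used, both as a context manager and as a decorator (illustrative only; based on the megengine.amp usage visible in the tests further down in this diff):

    import megengine as mge
    import megengine.functional as F

    x = mge.tensor([[1.0, 2.0]])
    w = mge.tensor([[3.0], [4.0]])

    # as a context manager: ops inside run under the amp dtype/compute configs
    with mge.amp.autocast():
        y = F.matmul(x, w)

    # as a decorator: __call__ wraps the function body in the same context
    @mge.amp.autocast()
    def forward(inp, weight):
        return F.matmul(inp, weight)

    y2 = forward(x, w)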
@@ -15,11 +15,14 @@ from ..core import _config
def _is_nchw_format(param: Tensor):
    # TODO: use better condition
    return (len(param.shape) == 4 or len(param.shape) == 5) and param.format != "nhwc"
    return (param.ndim == 4 or param.ndim == 5) and param.format != "nhwc"
def convert_tensor_format(x: Tensor, inplace: bool = True):
    """Convert NCHW Tensor to NHWC Tensor."""
    if not _is_nchw_format(x):
        return x
    if x.ndim == 4:
        pattern = (0, 2, 3, 1)
    elif x.ndim == 5:
@@ -29,8 +32,9 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
    # TODO: use initialization from tensor after fixing format setting
    if x.format != "nhwc":
        if inplace:
            # reset will destroy backward grad
            # hostvalue should still be valid, so no d2h cost.
            data = x.numpy().transpose(*pattern)
            # reset will destroy the existing backward grad
            x[...] = Tensor(data, format="nhwc")
        else:
            # use mge interface to maintain grad
@@ -45,7 +49,5 @@ def convert_module_format(module: Module, inplace: bool = True):
        module = deepcopy(module)
    for name, param in module.named_tensors():
        if _is_nchw_format(param):
            # hostvalue should still be valid, so no d2h cost.
            convert_tensor_format(param, inplace=True)
            convert_tensor_format(param, inplace=True)
    return module
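A minimal usage sketch of these converters (illustrative; assumes they are exported as mge.amp.convert_tensor_format / mge.amp.convert_module_format, which is how the tests below call them):

    import numpy as np
    import megengine as mge
    import megengine.module as M

    m = M.Conv2d(3, 8, 3)
    x = mge.tensor(np.ones((1, 3, 32, 32), dtype="float32"))

    # convert the module parameters and the input tensor from NCHW to NHWC
    m_nhwc = mge.amp.convert_module_format(m, inplace=False)
    x_nhwc = mge.amp.convert_tensor_format(x, inplace=False)
    assert x_nhwc.format == "nhwc"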
@@ -64,9 +64,7 @@ class Grad:
                continue
            grad.suppress()
        print("before backward")
        self._impl.backward(ys, dys)
        print("after backward")
        for grad in group:
            if grad is self:
@@ -245,8 +245,6 @@ def conv2d(
    sparse_type = "dense" if groups == 1 else "group"
    compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
    with _config._override(auto_format_convert=False):
        print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
        op = builtin.Convolution(
            stride_h=stride_h,
            stride_w=stride_w,
@@ -320,7 +320,7 @@ py::object _Const(py::handle value, py::handle dtype, py::handle device) {
        }
    }
    py::object device_obj = device2obj(device, true);
    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none());
    py::tuple tup = py::make_tuple(val, dtype, device_obj, true, false, py::none(), py::none());
    return TensorWrapper::make(py_tensor_type, tup.ptr(), nullptr);
}
@@ -35,6 +35,7 @@ def test_basic():
    b.format = "nhwc"
    assert b.format == "nhwc"
def _compare_nchw_nhwc(data, func, is_symbolic=None):
    x1 = tensor(data)
    x2 = tensor(data.transpose(0, 2, 3, 1), format="nhwc")
@@ -335,21 +336,42 @@ def _compare_backward(inps, model, is_symbolic=None):
    gm = GradManager().attach(model.parameters())
    with gm:
        rst = func(*inps)
        gm.backward(rst)
    expected_grads = [param.grad for param in model.parameters()]
        with mge.amp.autocast():
            rst = func(*inps)
            gm.backward(rst)
    expected_grads = [param.grad.numpy() for param in gm.attached_tensors()]
    for param in gm.attached_tensors():
        param.grad = None
    inps = [mge.amp.convert_tensor_format(inp) for inp in inps]
    model = mge.amp.convert_module_format(model)
    gm = GradManager().attach(model.parameters())
    with gm:
        rst = func(*inps)
        gm.backward(rst)
    actual_grads = [param.grad for param in model.parameters()]
        with mge.amp.autocast():
            rst = func(*inps)
            gm.backward(rst)
    actual_grads = [param.grad.numpy() for param in gm.attached_tensors()]
    for expected, actual in zip(expected_grads, actual_grads):
        # print(param.grad)
        np.testing.assert_equal(expected.numpy(), actual.numpy())
        assert expected is not None
        assert actual is not None
        np.testing.assert_almost_equal(expected, actual, decimal=5)
@pytest.mark.parametrize("is_symbolic", [None])
def test_backward_basic(is_symbolic):
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.w = mge.Parameter([[2.0], [4.0], [6.0]])
            self.b = mge.Parameter(-1.0)
        def forward(self, inp):
            return F.matmul(inp, self.w) + self.b
    inp = mge.tensor([1.0, 3.0, 5.0]).reshape(1, 3)
    _compare_backward([inp], Net(), is_symbolic)
@pytest.mark.parametrize("is_symbolic", [None])
@@ -379,14 +401,15 @@ def test_backward_groupconv2d_bn(is_symbolic):
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.conv = M.Conv2d(2, 2, 1, groups=2)
            self.bn = M.BatchNorm2d(2)
            self.conv0 = M.Conv2d(32, 256, 3, groups=32, stride=2)
            self.conv1 = M.Conv2d(256, 2048, 3, groups=32, stride=2)
            # self.bn = M.BatchNorm2d(2048)
        def forward(self, inp):
            # test manually converting to NHWC, as is usually done in detection heads
            return self.bn(self.conv(inp))
            return self.conv1(self.conv0(inp))
    inp = mge.tensor(np.arange(0, 24).reshape((1, 2, 3, 4)))
    inp = mge.tensor(np.ones(shape=(32, 32, 56, 56)).astype("float32"))
    _compare_backward([inp], Net(), is_symbolic)
    # def func(x, w, b, bn_w, bn_b):
    #     x = F.conv2d(x, w, b, groups=2)
@@ -260,6 +260,7 @@ void ChannelImpl::dispatch_default_cpu(
    CompNode output_cn;
    {
        MGB_LOCK_GUARD(m_mutex);
        //mgb_log_warn(">>> MGB_LOCK_GUARD dispatch_default_cpu");
        for (auto&& info : input_infos) {
            auto input_cn = info->desc.comp_node;
            if (!output_cn.valid()) {
@@ -277,6 +278,7 @@ void ChannelImpl::dispatch_default_cpu(
                input_tensornds.emplace_back(info->h_value.proxy_to_default_cpu());
            }
        }
        //mgb_log_warn("<<< MGB_LOCK_GUARD dispatch_default_cpu");
    }
    SmallVector<DeviceTensorND> output_tensornds;
@@ -530,7 +532,9 @@ void ChannelImpl::sync() {
void ChannelImpl::sync_impl() {
    m_worker.wait_all_task_finish();
    MGB_LOCK_GUARD(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD sync_impl");
    check_worker_exc_unsafe();
    //mgb_log_warn("<<< MGB_LOCK_GUARD sync_impl");
}
void ChannelImpl::close() {
@@ -689,6 +693,7 @@ ChannelImpl::~ChannelImpl() {
void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
    auto& state = get_worker_state();
    MGB_LOCK_GUARD(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD produce_tensor");
    m_dtr.update_used_time(dest);
    MGB_RECORD_EVENT(
            TensorProduceEvent, dest->id, ptr->layout(), ptr->comp_node(),
@@ -715,16 +720,19 @@ void ChannelImpl::produce_tensor(TensorInfo* dest, TensorPtr ptr) {
        m_dtr.insert_candidate(dest);
    }
    notify_tensor_unsafe(dest);
    //mgb_log_warn("<<< MGB_LOCK_GUARD produce_tensor");
}
void ChannelImpl::release_tensor(TensorInfo* dest) {
    MGB_RECORD_EVENT(TensorReleaseEvent, dest->id);
    MGB_LOCK_GUARD(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD release_tensor");
    dest->ptr.reset();
    auto& state = get_worker_state();
    if (dest->size_exceeds_thd(state.options.dtr_evictee_minimum_size)) {
        m_dtr.erase_candidate(dest);
    }
    //mgb_log_warn("<<< MGB_LOCK_GUARD release_tensor");
}
void ChannelImpl::regenerate(TensorInfo* dest) {
@@ -1000,6 +1008,7 @@ bool ChannelImpl::check_available() {
TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    std::unique_lock<decltype(m_mutex)> lock(m_mutex);
    //mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor");
    mgb_assert(!m_waitee, "duplicate waitee");
    m_waitee = info;
    m_waitee_id = Profiler::next_id();
@@ -1010,6 +1019,7 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
    if (require_host && !host_available()) {
        // avoid dead lock
        lock.unlock();
        //mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor unlock");
        if (Profiler::is_profiling()) {
            m_worker.add_task(
                    {Profiler::next_id(), GetValue{info},
@@ -1021,18 +1031,21 @@ TensorPtr ChannelImpl::wait_tensor(TensorInfo* info, TensorProp prop) {
            });
        }
        lock.lock();
        //mgb_log_warn(">>> MGB_LOCK_GUARD wait_tensor lock");
        wait_host = true;
    }
    m_cv.wait(lock, [&]() {
        check_worker_exc_unsafe();
        return require_host ? host_available() : static_cast<bool>(info->ptr);
    });
    //mgb_log_warn("after cv wait");
    MGB_RECORD_EVENT(TensorWaitPropFinishEvent, info->id, m_waitee_id, prop);
    m_waitee = nullptr;
    if (wait_host) {
        auto err = info->ptr->comp_node().check_async_error();
        mgb_assert(!err, "%s", err->what());
    }
    //mgb_log_warn("<<< MGB_LOCK_GUARD wait_tensor");
    return info->ptr;
}
@@ -1040,6 +1053,7 @@ void ChannelImpl::notify_tensor_unsafe(TensorInfo* info) {
    if (info == m_waitee) {
        MGB_RECORD_EVENT(TensorNotifyPropEvent, info->id);
        m_cv.notify_all();
        //mgb_log_warn("cv notify_all");
    }
}
@@ -1102,6 +1116,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
    using namespace ranges::views;
    auto& state = get_worker_state();
    auto& options = state.options;
    //mgb_log_warn("process_one_task %s", to_string<Command>(icmd).c_str());
    // TODO: remove std::visit for support osx 10.12
    auto cmd_visitor = [&](const auto& cmd) {
        using T = std::decay_t<decltype(cmd)>;
@@ -1123,9 +1138,11 @@ void ChannelImpl::process_one_task(Command& icmd) {
            for (auto& i : cmd.inputs) {
                if (mgb_unlikely(i->invalid)) {
                    MGB_LOCK_GUARD(m_mutex);
                    //mgb_log_warn(">>> MGB_LOCK_GUARD ApplyOp");
                    for (auto& i : cmd.outputs) {
                        i->invalid = true;
                    }
                    //mgb_log_warn("<<< MGB_LOCK_GUARD ApplyOp");
                    return;
                }
            }
@@ -1210,8 +1227,10 @@ void ChannelImpl::process_one_task(Command& icmd) {
            }
            cmd.dest->ptr->fetch_value();
            MGB_LOCK_GUARD(m_mutex);
            //mgb_log_warn(">>> MGB_LOCK_GUARD GetValue");
            notify_tensor_unsafe(cmd.dest);
            imperative_log_profile_end("GetValue");
            //mgb_log_warn("<<< MGB_LOCK_GUARD GetValue");
        } else if constexpr (std::is_same_v<T, Drop>) {
            if (cmd.dest->invalid)
                return;
@@ -1271,6 +1290,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
                cmd_visitor(cmd);
            } catch (...) {
                MGB_LOCK_GUARD(m_mutex);
                //mgb_log_warn(">>> MGB_LOCK_GUARD catch exception");
                if constexpr (std::is_same_v<T, ApplyOp>) {
                    for (auto oup : cmd.outputs) {
                        oup->invalid = true;
@@ -1283,6 +1303,7 @@ void ChannelImpl::process_one_task(Command& icmd) {
                if (m_waitee) {
                    notify_tensor_unsafe(m_waitee);
                }
                //mgb_log_warn("<<< MGB_LOCK_GUARD catch exception");
            }
        },
        icmd.data);
@@ -33,9 +33,8 @@ TypedValueRef<FormattedTensorValue> FormatTransformation::to(
                tensor.format().to_string().c_str(),
                Format(target).to_string().c_str());
    }
    auto output = imperative::apply(
            *Dimshuffle::make(pattern, scope),
            SmallVector<ValueRef>{tensor.value()})[0];
    auto output =
            imperative::apply(*Dimshuffle::make(pattern, scope), {tensor.value()})[0];
    return m_value_type.make(output, target);
}
@@ -90,6 +89,27 @@ ValueShape convert_nhwc2nchw_shape(const ValueShape& shape) {
    }
}
std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape) {
    auto out = std::vector<int32_t>(shape);
    if (shape.size() == 4) {
        out[1] = shape[2];
        out[2] = shape[3];
        out[3] = shape[1];
        return out;
    } else if (shape.size() == 5) {
        // GIOHW -> GIHWO
        out[2] = shape[3];
        out[3] = shape[4];
        out[4] = shape[2];
        return out;
    } else {
        mgb_throw(
                MegBrainError,
                "Unsupported shape ndim %u in convert NCHW shape to NHWC.",
                shape.size());
    }
}
using FormatRule = std::function<ValueRefList(
        const OpDef&, Span<ValueRef>&, const bool&, const FormatTransformation&)>;
static std::unordered_map<Typeinfo*, FormatRule> format_rules;
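The index permutation implemented by convert_nchw2nhwc_vector, sketched in Python for clarity (same mapping as the C++ above: 4-d NCHW -> NHWC, 5-d GIOHW -> GIHWO; editorial illustration, not MegEngine code):

    def convert_nchw2nhwc_vector(shape):
        out = list(shape)
        if len(shape) == 4:
            # NCHW -> NHWC
            out[1], out[2], out[3] = shape[2], shape[3], shape[1]
        elif len(shape) == 5:
            # GIOHW -> GIHWO
            out[2], out[3], out[4] = shape[3], shape[4], shape[2]
        else:
            raise ValueError("unsupported ndim %d" % len(shape))
        return out

    assert convert_nchw2nhwc_vector([1, 16, 32, 32]) == [1, 32, 32, 16]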
@@ -156,22 +176,38 @@ ValueRef convert_nchw2nhwc_tensornd(const HostTensorND& shape) {
ValueRefList reshape_rule(
        const Reshape& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() == 2);
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    if (auto_convert && src.format() == FT::NHWC) {
        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
        if (shape.layout().total_nr_elems() == 4) {
            // output is still NHWC format
            auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
            return t.wrap_outputs(outputs, FT::NHWC);
        } else {
            // will not maintain src's format
            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
            return t.wrap_outputs(outputs);
        if (inputs.size() == 1) {
            if (op.shape.size() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
                auto outputs = imperative::apply(
                        *Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(op, {nchw_src});
                return t.wrap_outputs(outputs);
            }
        } else if (inputs.size() == 2) {
            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
            if (shape.layout().total_nr_elems() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
                auto outputs = imperative::apply(
                        op,
                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(
                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
                return t.wrap_outputs(outputs);
            }
        }
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
@@ -180,22 +216,38 @@ ValueRefList reshape_rule(
ValueRefList broadcast_rule(
        const Broadcast& op, Span<ValueRef>& inputs, const bool& auto_convert,
        const FormatTransformation& t) {
    mgb_assert(inputs.size() == 2);
    mgb_assert(inputs.size() >= 1);
    auto& src = inputs[0].cast(t.value_type());
    if (auto_convert && src.format() == FT::NHWC) {
        auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
        if (shape.layout().total_nr_elems() == 4) {
            // output is still NHWC format
            auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
            return t.wrap_outputs(outputs, FT::NHWC);
        } else {
            // will not maintain src's format
            auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
            auto outputs = imperative::apply(
                    op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
            return t.wrap_outputs(outputs);
        if (inputs.size() == 1) {
            if (op.shape.size() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
                auto outputs = imperative::apply(
                        *Broadcast::make(nhwc_shape), {t.unwrap_input(inputs[0])});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(op, {nchw_src});
                return t.wrap_outputs(outputs);
            }
        } else if (inputs.size() == 2) {
            auto shape = t.unwrap_input(inputs[1]).numpy()->as_nd();
            if (shape.layout().total_nr_elems() == 4) {
                // output is still NHWC format
                auto nhwc_shape = convert_nchw2nhwc_tensornd(shape);
                auto outputs = imperative::apply(
                        op,
                        SmallVector<ValueRef>{t.unwrap_input(inputs[0]), nhwc_shape});
                return t.wrap_outputs(outputs, FT::NHWC);
            } else {
                // will not maintain src's format
                auto nchw_src = t.to(src, FT::NCHW, op.scope())->value();
                auto outputs = imperative::apply(
                        op, SmallVector<ValueRef>{nchw_src, t.unwrap_input(inputs[1])});
                return t.wrap_outputs(outputs);
            }
        }
    }
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)));
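Both reshape_rule and broadcast_rule above encode the same policy: when an NHWC source is reshaped or broadcast to a 4-d target shape, the logical NCHW target shape is permuted so the physical output can stay NHWC; any other rank falls back to NCHW and drops the format. A rough Python illustration of that shape policy (editorial sketch, not the actual transformation code):

    def target_shape_for_nhwc(nchw_target):
        # 4-d target: permute so the physical output stays NHWC
        if len(nchw_target) == 4:
            n, c, h, w = nchw_target
            return (n, h, w, c), "nhwc"
        # other ranks: convert the source back to NCHW and lose the format
        return tuple(nchw_target), "default"

    assert target_shape_for_nhwc((1, 16, 32, 32)) == ((1, 32, 32, 16), "nhwc")
    assert target_shape_for_nhwc((2, 8)) == ((2, 8), "default")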
@@ -240,8 +292,7 @@ ValueRefList subtensor_rule(
    // only support NHWC2NCHW convert, otherwise maintain src's format
    if (!(auto_convert && src.format() == FT::NHWC)) {
        return {t.wrap_output(
                imperative::apply(op, t.unwrap_inputs(inputs))[0],
                src.format())};
                imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
    }
    auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
    auto outputs = imperative::apply(
@@ -263,8 +314,7 @@ ValueRefList setsubtensor_rule(
    // only support NHWC2NCHW convert, otherwise maintain src's format
    if (!(auto_convert && src.format() == FT::NHWC)) {
        return {t.wrap_output(
                imperative::apply(op, t.unwrap_inputs(inputs))[0],
                src.format())};
                imperative::apply(op, t.unwrap_inputs(inputs))[0], src.format())};
    }
    // value has been broadcasted to src's fake NCHW shape.
    auto& value = inputs[1].cast(t.value_type());
@@ -329,8 +379,7 @@ ValueRefList identity_rule_helper(
        const OpDef& op, const Span<ValueRef>& inputs, const FormatTransformation& t) {
    // mgb_assert(inputs.size() == 1);
    auto& src = inputs[0].cast(t.value_type());
    return t.wrap_outputs(
            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
    return t.wrap_outputs(imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
}
ValueRefList batchnorm_rule(
@@ -457,6 +506,7 @@ struct FormatRuleRegistry {
ValueRefList FormatTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    //mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
    if (auto* apply_op = op.as<ApplyOp>()) {
        // all inputs should be FormattedTensorValue
        auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
@@ -485,7 +535,7 @@ ValueRefList FormatTransformation::apply_transformation(
            }
            case GetAttr::Value: {
                auto nchw_src = unwrap_input(to(src, FT::NCHW, ""));
                return imperative::apply(op, SmallVector<ValueRef>{nchw_src});
                return imperative::apply(op, {nchw_src});
            }
            default:
                return imperative::apply(op, unwrap_inputs(inputs));
@@ -508,8 +558,7 @@ ValueRefList FormatTransformation::apply_transformation(
        auto&& inp_ref = inputs[0].as_ref(m_value_type);
        if (inp_ref) {
            auto&& format = inp_ref->format();
            return wrap_outputs(
                    imperative::apply(op, unwrap_inputs(inputs)), format);
            return wrap_outputs(imperative::apply(op, unwrap_inputs(inputs)), format);
        } else {
            mgb_log_warn(
                    "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -522,6 +571,7 @@ ValueRefList FormatTransformation::apply_transformation(
            auto format = inp_ref->format();
            GenericFunction callback =
                    (GenericFunction&)inputs[1].cast<FunctionValue>();
            // make param grads as FormattedTensor
            GenericFunction new_callback =
                    [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                auto wrapped_inputs = SmallVector<ValueRef>{
@@ -531,6 +581,7 @@ ValueRefList FormatTransformation::apply_transformation(
            };
            auto&& outputs = imperative::apply(
                    op, inp_ref->value(), FunctionValue::make(new_callback));
            // make params(GradValue) as FormattedTensor
            return wrap_outputs(outputs, format);
        } else {
            mgb_log_warn(
@@ -539,6 +590,7 @@ ValueRefList FormatTransformation::apply_transformation(
            return imperative::apply(op, inputs);
        }
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // make grads in Function backward as FormattedTensor
        size_t nr_inputs = set_grad->nr_inputs();
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
@@ -377,8 +377,6 @@ public:
    SetGrad(GenericFunction grad_fn, size_t nr_inputs)
            : m_grad_fn(grad_fn), m_nr_inputs(nr_inputs) {}
    std::shared_ptr<GradKey> key() const { return m_key; }
    GenericFunction grad_fn() const { return m_grad_fn; }
    size_t nr_inputs() const { return m_nr_inputs; }