@@ -29,10 +29,13 @@ def convert_tensor_format(x: Tensor, inplace: bool = True):
     # TODO: use initialization from tensor after fixing format setting
     if x.format != "nhwc":
         if inplace:
             # reset will destroy backward grad
             data = x.numpy().transpose(*pattern)
             x[...] = Tensor(data, format="nhwc")
         else:
-            x = Tensor(x.numpy().transpose(*pattern), format="nhwc")
+            # use mge interface to maintain grad
+            x = F.transpose(x, pattern)
+            x.format = "nhwc"
     return x
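With the non-inplace branch rewritten to go through `F.transpose` plus the new `format` setter, the conversion stays on the autodiff graph instead of detouring through numpy (which would detach the result). A minimal sketch of the intended behavior, assuming a 4-d NCHW input and this patched `megengine.amp.convert_tensor_format`:

```python
import numpy as np

import megengine as mge
from megengine.autodiff import GradManager

x = mge.tensor(np.random.rand(2, 3, 4, 4).astype("float32"))
gm = GradManager()
gm.attach([x])
with gm.record():
    # F.transpose is a regular mge op, so the permutation is recorded for
    # backward; the subsequent `x.format = "nhwc"` only re-tags metadata.
    y = mge.amp.convert_tensor_format(x, inplace=False)
    gm.backward(y.sum())
assert x.grad is not None  # grad survives the format conversion
```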
@@ -245,6 +245,8 @@ def conv2d(
     sparse_type = "dense" if groups == 1 else "group"
     compute_mode = _config._get_actual_op_param(compute_mode, _config.__compute_mode)
     with _config._override(auto_format_convert=False):
+        print(compute_mode, inp.shape, inp.format, weight.shape, weight.format)
         op = builtin.Convolution(
             stride_h=stride_h,
             stride_w=stride_w,
@@ -1,5 +1,6 @@
 import numpy as np

+import megengine as mge
 import megengine.functional as F
 from megengine import Parameter
@@ -34,6 +35,7 @@ class GroupNorm(Module):

     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels

         x = x.reshape(N, self.num_groups, -1)

@@ -44,7 +46,9 @@ class GroupNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x

     def _module_info_string(self) -> str:
@@ -81,6 +85,7 @@ class InstanceNorm(Module):

     def forward(self, x):
         N, C, H, W = x.shape
+        format = x.format
         assert C == self.num_channels
         x = x.reshape(N, C, -1)
         mean = x.mean(axis=2, keepdims=True)

@@ -90,7 +95,9 @@ class InstanceNorm(Module):
         x = x.reshape(N, C, H, W)
         if self.affine:
             x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-
+        # FIXME(czh): remove this after making it a builtin op.
+        if format == "nhwc":
+            x = mge.amp.convert_tensor_format(x, inplace=False)
         return x

     def _module_info_string(self) -> str:
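Both norm modules compute in NCHW-shaped views, so the NHWC tag of an input is recorded up front and restored by `convert_tensor_format` on the way out; the FIXME notes this workaround should go away once the op becomes builtin. A hedged check of the intended round trip, assuming the patched modules:

```python
import numpy as np

import megengine as mge
import megengine.module as M

gn = M.GroupNorm(num_groups=2, num_channels=4)
x = mge.tensor(np.random.rand(2, 4, 8, 8).astype("float32"), format="nhwc")
y = gn(x)
assert y.format == "nhwc"  # restored at the tail of forward

inorm = M.InstanceNorm(num_channels=4)
assert inorm(x).format == "nhwc"  # same workaround in InstanceNorm.forward
```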
@@ -122,7 +122,11 @@ class Tensor(_Tensor, ArrayMethodMixin):

     @property
     def format(self) -> str:
-        return super().format
+        return super().format()
+
+    @format.setter
+    def format(self, format):
+        super()._set_format(format)

     @property
     def qparams(self):
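On the Python side, `format` becomes a read/write property: the getter now calls the bound method (the C++ binding below moves from `def_getset` to `def`), and the setter forwards to the new `_set_format`. Usage mirrors the added unit test, assuming this patched build:

```python
import numpy as np

from megengine import tensor

data = np.random.rand(1, 2, 3, 4).astype("float32")
b = tensor(data, format="nchw")
assert b.format == "nchw"
b.format = "nhwc"  # property setter -> Tensor._set_format -> SetFormat op
assert b.format == "nhwc"
```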
@@ -584,6 +584,12 @@ void TensorWrapper::set_module_trace_info(PyObject* obj) {
     module_trace_info_map[m_tensor->data()] = py::reinterpret_borrow<py::object>(obj);
 }

+void TensorWrapper::_set_format(PyObject* dest) {
+    auto py_dest = py::reinterpret_borrow<py::object>(dest);
+    auto format = py_dest.cast<std::string>();
+    m_tensor->set_format(format);
+}
+
 void TensorWrapper::_set_name(PyObject* dest) {
     auto py_dest = py::reinterpret_borrow<py::object>(dest);
     auto name = py_dest.cast<std::string>();
@@ -812,7 +818,7 @@ void init_tensor(py::module m) {
             .def_getset<&TensorWrapper::shape>("shape")
             .def_getset<&TensorWrapper::dtype>("dtype")
             .def_getset<&TensorWrapper::device>("device")
-            .def_getset<&TensorWrapper::format>("format")
+            .def<&TensorWrapper::format>("format")
             .def<&TensorWrapper::reset>("_reset")
             .def<&TensorWrapper::isscalar>("_isscalar")
             .def<&TensorWrapper::detach>("detach")

@@ -820,6 +826,7 @@ void init_tensor(py::module m) {
             .def<&TensorWrapper::_dev_tensor>("_dev_tensor")
            .def<&TensorWrapper::_drop>("_drop")
             .def<&TensorWrapper::_detail>("_detail")
+            .def<&TensorWrapper::_set_format>("_set_format")
             .def<&TensorWrapper::_set_name>("_set_name")
             .def<&TensorWrapper::_watch>("_watch")
             .def<&TensorWrapper::_var>("var")
@@ -59,6 +59,11 @@ public:
         return *shape;
     }
     inline Format format() { return *data().format(); }
+    inline void set_format(std::string format) {
+        if (!format.empty()) {
+            m_data = imperative::apply(SetFormat(format), m_data)[0];
+        }
+    }
     inline HostValue::ref_t numpy() { return data().numpy(); }
     inline void reset(ValueRef value) {
         m_data = value;

@@ -130,6 +135,7 @@ public:
     PyObject* copied();
     PyObject* module_trace_info();
     void set_module_trace_info(PyObject*);
+    void _set_format(PyObject*);
     void _set_name(PyObject*);
     PyObject* _detail();
     PyObject* _var();
@@ -31,6 +31,9 @@ def test_basic():
     b[...] = tensor(data, format="nchw")
     assert b.format == "nchw"
+    # set tensor's format
+    b.format = "nhwc"
+    assert b.format == "nhwc"


 def _compare_nchw_nhwc(data, func, is_symbolic=None):
     x1 = tensor(data)
@@ -105,9 +105,16 @@ std::string IsScalar::to_string() const {
     return "IsScalar";
 }
+
+std::string GetFormat::to_string() const {
+    return "GetFormat{}";
+}
+
+std::string SetFormat::to_string() const {
+    return ssprintf("SetFormat{format=%s}", m_format.to_string().c_str());
+}

 std::string GetVarVal::to_string() const {
     return "GetVarVal";
 }

 }  // namespace imperative
 }  // namespace mgb
@@ -57,15 +57,15 @@ inline ValueRefList FormatTransformation::unwrap_inputs(
 }

 inline ValueRef FormatTransformation::wrap_output(
-        const ValueRef& output, FT type) const {
-    return m_value_type.make(output, type);
+        const ValueRef& output, Format format) const {
+    return m_value_type.make(output, format);
 }

 inline ValueRefList FormatTransformation::wrap_outputs(
-        const ValueRefList& outputs, FT type) const {
+        const ValueRefList& outputs, Format format) const {
     ValueRefList wrapped_outputs(outputs.size());
     for (size_t i = 0; i < outputs.size(); ++i) {
-        wrapped_outputs[i] = wrap_output(outputs[i], type);
+        wrapped_outputs[i] = wrap_output(outputs[i], format);
     }
     return wrapped_outputs;
 }
@@ -241,7 +241,7 @@ ValueRefList subtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     auto nhwc_items = convert_nchw2nhwc_idx_items(op.items);
     auto outputs = imperative::apply(

@@ -264,7 +264,7 @@ ValueRefList setsubtensor_rule(
     if (!(auto_convert && src.format() == FT::NHWC)) {
         return {t.wrap_output(
                 imperative::apply(op, t.unwrap_inputs(inputs))[0],
-                src.format().type())};
+                src.format())};
     }
     // value has been broadcasted to src's fake NCHW shape.
     auto& value = inputs[1].cast(t.value_type());
@@ -330,7 +330,7 @@ ValueRefList identity_rule_helper(
     // mgb_assert(inputs.size() == 1);
     auto& src = inputs[0].cast(t.value_type());
     return t.wrap_outputs(
-            imperative::apply(op, t.unwrap_inputs(inputs)), src.format().type());
+            imperative::apply(op, t.unwrap_inputs(inputs)), src.format());
 }

 ValueRefList batchnorm_rule(
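The `src.format().type()` to `src.format()` changes above are mechanical: the dispatch rules now thread the whole `Format` value through `wrap_output(s)` instead of collapsing it to the bare enum. The user-visible effect (a hedged check, assuming `auto_format_convert` is active in this build) is that ops such as slicing keep the source tag, per `subtensor_rule`:

```python
import numpy as np

from megengine import tensor

x = tensor(np.random.rand(2, 4, 8, 8).astype("float32"), format="nhwc")
y = x[:, :2]  # subtensor_rule rewrites the NCHW-style indices for NHWC data
assert y.format == "nhwc"  # output is wrapped with src.format()
```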
@@ -467,7 +467,7 @@ ValueRefList FormatTransformation::apply_transformation(
         }
     } else if (auto* create_tensor = op.as<CreateTensor>()) {
         auto format = create_tensor->format();
-        return {wrap_output(imperative::apply(op, inputs)[0], format.type())};
+        return {wrap_output(imperative::apply(op, inputs)[0], format)};
     } else if (auto* get_attr = op.as<GetAttr>()) {
         auto&& input = inputs.item();
         if (!input.is(m_value_type)) {

@@ -500,12 +500,16 @@ ValueRefList FormatTransformation::apply_transformation(
                     op.to_string().c_str(), inputs[0].to_string().c_str());
             return {FormatValue::make(FT::DEFAULT)};
         }
+    } else if (auto* _op = op.as<SetFormat>()) {
+        auto&& inp_ref = inputs[0].as_ref(m_value_type);
+        mgb_assert(inp_ref, "Cannot set format for non-format Tensor.");
+        return {m_value_type.make(inp_ref->value(), _op->format())};
     } else if (op.is<Operator::IdentityLike>()) {
         auto&& inp_ref = inputs[0].as_ref(m_value_type);
         if (inp_ref) {
             auto&& format = inp_ref->format();
             return wrap_outputs(
-                    imperative::apply(op, unwrap_inputs(inputs)), format.type());
+                    imperative::apply(op, unwrap_inputs(inputs)), format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for IdentityLike op: %s, %s",
@@ -521,13 +525,13 @@ ValueRefList FormatTransformation::apply_transformation(
             GenericFunction new_callback =
                     [this, callback, format](Span<ValueRef> inputs_) -> ValueRefList {
                 auto wrapped_inputs = SmallVector<ValueRef>{
-                        this->value_type().make(inputs_.item(), format.type())};
+                        this->value_type().make(inputs_.item(), format)};
                 auto ret = callback(wrapped_inputs);
                 return ret;
             };
             auto&& outputs = imperative::apply(
                     op, inp_ref->value(), FunctionValue::make(new_callback));
-            return wrap_outputs(outputs, format.type());
+            return wrap_outputs(outputs, format);
         } else {
             mgb_log_warn(
                     "Not FormattedTensorValue input for AttachGrad op: %s, %s",

@@ -549,7 +553,7 @@ ValueRefList FormatTransformation::apply_transformation(
         for (size_t i = 0; i < nr_outputs; ++i) {
             if (auto output_ref = outputs_[i].as_ref(m_value_type)) {
                 wrapped_outputs[i] =
-                        m_value_type.make(outputs[i], output_ref->format().type());
+                        m_value_type.make(outputs[i], output_ref->format());
             } else {
                 mgb_log_warn(
                         "Not FormattedTensorValue outputs for SetGrad op: %s, %s",
@@ -164,7 +164,19 @@ public:

 class GetFormat final : public OperatorImpl<GetFormat, Operator::GetAttrLike> {
 public:
-    std::string to_string() const override { return "GetFormat{}"; }
+    std::string to_string() const override;
+};
+
+class SetFormat final : public OperatorImpl<SetFormat, Operator::IdentityLike> {
+private:
+    Format m_format;
+
+public:
+    SetFormat(std::string format) : m_format(format) {}
+
+    Format format() const { return m_format; }
+
+    std::string to_string() const override;
 };

 class GetVarVal final : public OperatorImpl<GetVarVal, Operator::GetAttrLike> {
@@ -26,6 +26,8 @@ public:
     const Format& format() const { return m_format; }

+    void set_format(Format format) { m_format = format; }
+
     void clear() override {
         m_value = {};
         m_format = {};

@@ -65,10 +67,10 @@ public:
     inline ValueRef unwrap_input(const ValueRef& input) const;
     inline ValueRefList unwrap_inputs(const Span<ValueRef>& inputs) const;
     inline ValueRef wrap_output(
-            const ValueRef& output, Format::Type type = Format::Type::DEFAULT) const;
+            const ValueRef& output, Format format = Format::Type::DEFAULT) const;
     inline ValueRefList wrap_outputs(
             const ValueRefList& outputs,
-            Format::Type type = Format::Type::DEFAULT) const;
+            Format format = Format::Type::DEFAULT) const;
     TypedValueRef<FormattedTensorValue> as(
             const FormattedTensorValue&, const Format::Type& target) const;