| @@ -1066,7 +1066,31 @@ py::object _adaptive_pool2d_cpp( | |||||
| py::handle inp_hdl, py::handle shape_val_hdl, py::handle pool_mode_hdl) { | py::handle inp_hdl, py::handle shape_val_hdl, py::handle pool_mode_hdl) { | ||||
| py::object shape_hdl = py::reinterpret_borrow<py::object>(shape_val_hdl); | py::object shape_hdl = py::reinterpret_borrow<py::object>(shape_val_hdl); | ||||
| py::list shps(0); | py::list shps(0); | ||||
| if (!PyTuple_Check(shape_val_hdl.ptr())) { | |||||
| auto mode_string = pool_mode_hdl.cast<std::string>(); | |||||
| ::megdnn::param::AdaptivePooling::Mode pool_mode = | |||||
| ::megdnn::param::AdaptivePooling::Mode::MAX; | |||||
| if (mode_string.compare(std::string("AVERAGE")) == 0) { | |||||
| pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE; | |||||
| } | |||||
| std::shared_ptr<OpDef> op; | |||||
| std::vector<PyObject*> p; | |||||
| auto pool_format = ::megdnn::param::AdaptivePooling::Format::NCHW; | |||||
| auto inp_format = getattr(inp_hdl, "format").cast<std::string>(); | |||||
| if (inp_format == "nhwc") { | |||||
| pool_format = ::megdnn::param::AdaptivePooling::Format::NHWC; | |||||
| } | |||||
| if (TensorWrapper::try_cast(shape_val_hdl.ptr())) { | |||||
| std::vector<int32_t> shp; | |||||
| op = AdaptivePooling::make(pool_mode, pool_format, shp); | |||||
| py::object Op = py::cast(op); | |||||
| p.resize(3); | |||||
| p[0] = Op.ptr(); | |||||
| p[1] = inp_hdl.ptr(); | |||||
| p[2] = shape_val_hdl.ptr(); | |||||
| py::tuple ret = | |||||
| py::reinterpret_steal<py::object>(py_apply(NULL, p.data(), p.size())); | |||||
| return ret[0]; | |||||
| } else if (!PyTuple_Check(shape_val_hdl.ptr())) { | |||||
| shps.append(PyLong_AsLong(shape_val_hdl.ptr())); | shps.append(PyLong_AsLong(shape_val_hdl.ptr())); | ||||
| shps.append(PyLong_AsLong(shape_val_hdl.ptr())); | shps.append(PyLong_AsLong(shape_val_hdl.ptr())); | ||||
| @@ -1078,19 +1102,11 @@ py::object _adaptive_pool2d_cpp( | |||||
| } catch (py::error_already_set& err) { | } catch (py::error_already_set& err) { | ||||
| shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl); | shape_tuple = py::reinterpret_borrow<py::object>(shape_hdl); | ||||
| } | } | ||||
| auto mode_string = pool_mode_hdl.cast<std::string>(); | |||||
| ::megdnn::param::AdaptivePooling::Mode pool_mode = | |||||
| ::megdnn::param::AdaptivePooling::Mode::MAX; | |||||
| if (mode_string.compare(std::string("AVERAGE")) == 0) { | |||||
| pool_mode = ::megdnn::param::AdaptivePooling::Mode::AVERAGE; | |||||
| } | |||||
| auto [shape, fastpath] = tuple2vector(shape_tuple); | auto [shape, fastpath] = tuple2vector(shape_tuple); | ||||
| fastpath &= enable_fastpath(inp_hdl); | fastpath &= enable_fastpath(inp_hdl); | ||||
| std::shared_ptr<OpDef> op; | |||||
| std::vector<PyObject*> p; | |||||
| py::object shape_tensor; | py::object shape_tensor; | ||||
| op = AdaptivePooling::make( | |||||
| pool_mode, ::megdnn::param::AdaptivePooling::Format::NCHW, shape); | |||||
| op = AdaptivePooling::make(pool_mode, pool_format, shape); | |||||
| if (fastpath) { | if (fastpath) { | ||||
| p.resize(2); | p.resize(2); | ||||
| } else { | } else { | ||||
| @@ -39,6 +39,7 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| const dt_int32* oshp2d = nullptr; | const dt_int32* oshp2d = nullptr; | ||||
| dst_layout.ndim = 4u; | dst_layout.ndim = 4u; | ||||
| bool tshp1n = false; | |||||
| if (nr_inp == 1) { | if (nr_inp == 1) { | ||||
| oshp2d = pool.shape.data(); | oshp2d = pool.shape.data(); | ||||
| } else { | } else { | ||||
| @@ -51,17 +52,18 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( | |||||
| "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually", | "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually", | ||||
| tshp.layout.ndim); | tshp.layout.ndim); | ||||
| oshp2d = tshp.value.ptr<dt_int32>(); | oshp2d = tshp.value.ptr<dt_int32>(); | ||||
| tshp1n = tshp.layout.total_nr_elems() == 1; | |||||
| } | } | ||||
| auto param_format = pool.param().format; | auto param_format = pool.param().format; | ||||
| if (param_format == opr::AdaptivePooling::Param::Format::NCHW) { | if (param_format == opr::AdaptivePooling::Param::Format::NCHW) { | ||||
| dst_layout[0] = src.layout[0]; | dst_layout[0] = src.layout[0]; | ||||
| dst_layout[1] = src.layout[1]; | dst_layout[1] = src.layout[1]; | ||||
| dst_layout[2] = oshp2d[0]; | dst_layout[2] = oshp2d[0]; | ||||
| dst_layout[3] = oshp2d[1]; | |||||
| dst_layout[3] = tshp1n ? oshp2d[0] : oshp2d[1]; | |||||
| } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { | } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { | ||||
| dst_layout[0] = src.layout[0]; | dst_layout[0] = src.layout[0]; | ||||
| dst_layout[1] = oshp2d[0]; | dst_layout[1] = oshp2d[0]; | ||||
| dst_layout[2] = oshp2d[1]; | |||||
| dst_layout[2] = tshp1n ? oshp2d[0] : oshp2d[1]; | |||||
| dst_layout[3] = src.layout[3]; | dst_layout[3] = src.layout[3]; | ||||
| } else { | } else { | ||||
| mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); | mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); | ||||
| @@ -83,8 +85,10 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| if (!validated) { | if (!validated) { | ||||
| dst_layout.ndim = src_layout.ndim; | dst_layout.ndim = src_layout.ndim; | ||||
| const dt_int32* oshp2d = nullptr; | const dt_int32* oshp2d = nullptr; | ||||
| bool tshp1n = false; | |||||
| if (inputs.size() == 2) { | if (inputs.size() == 2) { | ||||
| auto&& tshp_nd = inputs[1]; | auto&& tshp_nd = inputs[1]; | ||||
| tshp1n = inputs[1]->layout().total_nr_elems() == 1; | |||||
| oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>(); | oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>(); | ||||
| } else { | } else { | ||||
| oshp2d = pool.shape.data(); | oshp2d = pool.shape.data(); | ||||
| @@ -93,11 +97,11 @@ SmallVector<TensorPtr> apply_on_physical_tensor( | |||||
| dst_layout[0] = src_layout[0]; | dst_layout[0] = src_layout[0]; | ||||
| dst_layout[1] = src_layout[1]; | dst_layout[1] = src_layout[1]; | ||||
| dst_layout[2] = oshp2d[0]; | dst_layout[2] = oshp2d[0]; | ||||
| dst_layout[3] = oshp2d[1]; | |||||
| dst_layout[3] = tshp1n ? oshp2d[0] : oshp2d[1]; | |||||
| } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { | } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) { | ||||
| dst_layout[0] = src_layout[0]; | dst_layout[0] = src_layout[0]; | ||||
| dst_layout[1] = oshp2d[0]; | dst_layout[1] = oshp2d[0]; | ||||
| dst_layout[2] = oshp2d[1]; | |||||
| dst_layout[2] = tshp1n ? oshp2d[0] : oshp2d[1]; | |||||
| dst_layout[3] = src_layout[3]; | dst_layout[3] = src_layout[3]; | ||||
| } else { | } else { | ||||
| mgb_throw( | mgb_throw( | ||||
| @@ -39,22 +39,23 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape( | |||||
| cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); | cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0)); | ||||
| auto src = shpinfo.shape_inp_shp.at(0); | auto src = shpinfo.shape_inp_shp.at(0); | ||||
| mgb_assert( | mgb_assert( | ||||
| src.ndim == 4 && oshp2d.ndim == 2, | |||||
| src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1), | |||||
| "shape mismatch for AdaptivePooling: src=%s, out2d=%s", | "shape mismatch for AdaptivePooling: src=%s, out2d=%s", | ||||
| src.to_string().c_str(), oshp2d.to_string().c_str()); | src.to_string().c_str(), oshp2d.to_string().c_str()); | ||||
| auto param_format = param().format; | auto param_format = param().format; | ||||
| bool tshp1n = oshp2d.ndim == 1; | |||||
| if (param_format == Param::Format::NCHW) { | if (param_format == Param::Format::NCHW) { | ||||
| dest.ndim = 4; | dest.ndim = 4; | ||||
| dest.shape[0] = src.shape[0]; | dest.shape[0] = src.shape[0]; | ||||
| dest.shape[1] = src.shape[1]; | dest.shape[1] = src.shape[1]; | ||||
| dest.shape[2] = oshp2d.shape[0]; | dest.shape[2] = oshp2d.shape[0]; | ||||
| dest.shape[3] = oshp2d.shape[1]; | |||||
| dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; | |||||
| } else if (param_format == Param::Format::NHWC) { | } else if (param_format == Param::Format::NHWC) { | ||||
| dest.ndim = 4; | dest.ndim = 4; | ||||
| dest.shape[0] = src.shape[0]; | dest.shape[0] = src.shape[0]; | ||||
| dest.shape[1] = oshp2d.shape[0]; | dest.shape[1] = oshp2d.shape[0]; | ||||
| dest.shape[2] = oshp2d.shape[1]; | |||||
| dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1]; | |||||
| dest.shape[3] = src.shape[3]; | dest.shape[3] = src.shape[3]; | ||||
| } else { | } else { | ||||
| mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); | mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format"); | ||||