GitOrigin-RevId: dfb401a945
tags/v1.6.0
@@ -18,6 +18,7 @@ import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
+import megengine.jit as jit
 from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
@@ -859,6 +860,35 @@ def test_condtake():
     np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
 
 
+# @pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_condtake(is_symbolic=None):
+    shapes = [
+        (3, 3, 3),
+        (0,),
+        (3, 0, 3),
+    ]
+
+    def fn(mask, data):
+        return F.cond_take(mask, data)
+
+    if is_symbolic is not None:
+        fn = jit.trace(symbolic=is_symbolic)(fn)
+
+    for shp in shapes:
+        x_np = np.random.randn(*shp).astype("float32")
+        mask_np = x_np > 0
+        x = tensor(x_np)
+        mask = tensor(mask_np)
+        ref_out = x_np[mask_np]
+        ref_idx = mask_np.flatten().nonzero()[0]
+        for i in range(3):
+            out, idx = fn(mask, x)
+            np.testing.assert_equal(out.numpy(), ref_out)
+            np.testing.assert_equal(idx.numpy(), ref_idx)
+            if is_symbolic is None:
+                break
+
+
 def test_condtake_is_same():
     op1 = builtin.CondTake()
     op2 = builtin.CondTake()
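For reference, a minimal standalone sketch of the user-facing behaviour the new test exercises, assuming a MegEngine build that already includes this patch: F.cond_take(mask, x) returns the selected values together with their flattened int32 indices, and with this change an all-empty input simply yields two empty outputs instead of failing.

# Standalone sketch, not part of the patch; assumes a build with this change applied.
import numpy as np
import megengine.functional as F
from megengine import tensor

# Non-empty input: cond_take returns the selected values and their flattened indices.
x = tensor(np.array([[1.0, -2.0], [3.0, -4.0]], dtype="float32"))
val, idx = F.cond_take(x > 0, x)
np.testing.assert_equal(val.numpy(), np.array([1.0, 3.0], dtype="float32"))
np.testing.assert_equal(idx.numpy(), np.array([0, 2], dtype="int32"))

# Empty input, the case enabled by this patch: both outputs come back empty,
# the value output keeping the input dtype and the index output being int32.
empty = tensor(np.zeros((3, 0, 3), dtype="float32"))
val, idx = F.cond_take(empty > 0, empty)
assert val.numpy().shape == (0,) and val.numpy().dtype == np.float32
assert idx.numpy().shape == (0,) and idx.numpy().dtype == np.int32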
@@ -45,25 +45,30 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     auto&& inp = inputs[0];
     auto&& msk = inputs[1];
-    SmallVector<TensorPtr> out;
     mgb_assert(inp->layout().eq_shape(msk->layout()),
             "input shape does not match mask shape");
     mgb_assert(msk->get_value().dtype().enumv() == DTypeEnum::Bool,
             "mask dtype must be bool");
-    DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
-    dnn_op.op->param().val = 1;
-    TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
-            dtype::Byte());
-    auto dnn_workspace = dnn_op.create_workspace(m_layout);
     MegDNNDynOutMallocImpl<2> policy{inp->comp_node()};
-    dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
-            msk->dev_tensor().as_megdnn(),
-            dnn_workspace,
-            &policy);
+    SmallVector<TensorPtr> out;
+    if (inp->layout().is_empty()) {
+        // empty tensor
+        policy.alloc_output(0, inp->layout().dtype, {0}, nullptr);
+        policy.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
+        dnn_op.op->param().val = 1;
+        TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
+                dtype::Byte());
+        auto dnn_workspace = dnn_op.create_workspace(m_layout);
+        dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
+                msk->dev_tensor().as_megdnn(),
+                dnn_workspace,
+                &policy);
+    }
     out.push_back(policy.at(0));
     out.push_back(policy.at(1));
     return out;
@@ -264,6 +264,15 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }
 
+CondTake::NodeProp* CondTake::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);
@@ -305,11 +314,21 @@ void CondTake::add_input_layout_constraint() {
 }
 
 void CondTake::scn_do_execute() {
+    auto&& data = input(0)->dev_tensor();
+    auto&& mask = input(1)->dev_tensor();
     intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
-    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
-                       input(1)->dev_tensor().as_megdnn(),
-                       intl::get_megdnn_workspace_from_var(output().back()),
-                       &dyn_malloc);
+    if (data.layout().is_empty()) {
+        mgb_assert(data.layout().eq_shape(mask.layout()),
+                "CondTake shape differs: data=%s mask=%s",
+                data.layout().TensorShape::to_string().c_str(),
+                mask.layout().TensorShape::to_string().c_str());
+        dyn_malloc.alloc_output(0, data.layout().dtype, {0}, nullptr);
+        dyn_malloc.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        megdnn_opr()->exec(data.as_megdnn(), mask.as_megdnn(),
+                intl::get_megdnn_workspace_from_var(output().back()),
+                &dyn_malloc);
+    }
 }
 
 /* ================= TopK ================= */
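The scn_do_execute() change above covers the graph-side execution of CondTake, so the same empty-input behaviour should also be visible when cond_take runs inside a symbolically traced function. A small sketch under that assumption (again requiring a build with this patch):

# Sketch only; symbolic tracing builds a megbrain graph, which exercises the
# graph-side CondTake path patched above.
import numpy as np
import megengine.functional as F
import megengine.jit as jit
from megengine import tensor

@jit.trace(symbolic=True)
def take_positive(mask, data):
    return F.cond_take(mask, data)

empty = tensor(np.zeros((3, 0, 3), dtype="float32"))
val, idx = take_positive(empty > 0, empty)
assert val.numpy().shape == (0,)
assert idx.numpy().shape == (0,)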
@@ -151,6 +151,7 @@ MGB_DEFINE_OPR_CLASS(CondTake, intl::CondTakeBase) // {
     void init_output_static_infer_desc() override;
     void scn_do_execute() override;
     void add_input_layout_constraint() override;
+    NodeProp* do_make_node_prop() const override;
 
     public:
         CondTake(VarNode *data, VarNode *mask,
@@ -256,20 +256,25 @@ TEST(TestOprMisc, CondTake) {
     run(mki({100}));
 }
 
-TEST(TestOprMisc, CondTakeEmptyOut) {
+TEST(TestOprMisc, CondTakeEmptyIO) {
     using Param = opr::CondTake::Param;
     HostTensorGenerator<> gen;
-    auto host_x = gen({1});
-    host_x->ptr<float>()[0] = 1;
-    auto graph = ComputingGraph::make();
-    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
-    auto out = opr::CondTake::make(x, x, {Param::Mode::LT});
-    HostTensorND host_out0, host_out1;
-    auto func = graph->compile({make_callback_copy(out[0], host_out0),
-                                make_callback_copy(out[1], host_out1)});
-    func->execute();
-    ASSERT_EQ(TensorShape{0}, host_out0.shape());
-    ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    auto check = [&](const TensorShape& shp) {
+        auto host_x = gen(shp);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+        auto y = x + 1;
+        auto out = opr::CondTake::make(x, y, {Param::Mode::EQ});
+        HostTensorND host_out0, host_out1;
+        auto func = graph->compile({make_callback_copy(out[0], host_out0),
+                                    make_callback_copy(out[1], host_out1)});
+        func->execute();
+        ASSERT_EQ(TensorShape{0}, host_out0.shape());
+        ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    };
+    check({1});
+    check({0});
+    check({1, 0});
 }
 
 TEST(TestOprMisc, TopKValueOnly) {