GitOrigin-RevId: 3754dc5a65
master
| @@ -12,6 +12,7 @@ import megengine.optimizer as optim | |||
| from megengine import tensor | |||
| from megengine.autodiff import GradManager | |||
| from megengine.jit import trace | |||
| from megengine.optimizer import SGD | |||
| @contextlib.contextmanager | |||
| @@ -138,3 +139,36 @@ def test_dump_bn_train_mode(): | |||
| bn_train(data) | |||
| with pytest.raises(RuntimeError): | |||
| bn_train.dump("test.mge") | |||
| class ViTmode(M.Module): | |||
| def __init__(self, patch_size=16, in_chans=3, embed_dim=384): | |||
| super().__init__() | |||
| self.proj = M.Conv2d( | |||
| in_chans, embed_dim, kernel_size=patch_size, stride=patch_size | |||
| ) | |||
| self.head = M.Linear(embed_dim, 1000) | |||
| def forward(self, x): | |||
| x = self.proj(x) | |||
| x = F.flatten(x, 2).transpose(0, 2, 1) | |||
| x = self.head(x) | |||
| return x | |||
| def test_ViTmode_trace_train(): | |||
| model = ViTmode(embed_dim=384) | |||
| data = mge.random.normal(size=(1, 3, 224, 224)) | |||
| optim = SGD(model.parameters(), lr=0.01) | |||
| gm = GradManager() | |||
| gm.attach(model.parameters()) | |||
| @trace(symbolic=True, capture_as_const=True) | |||
| def train(): | |||
| for i in range(2): | |||
| with gm: | |||
| loss = model(data) | |||
| gm.backward(loss) | |||
| optim.step().clear_grad() | |||
| train() | |||
| @@ -22,62 +22,80 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||
| mgb_assert(inputs.size() == 2); | |||
| auto inp1 = SymbolVar{inputs[0]}, inp2 = SymbolVar{inputs[1]}; | |||
| auto dim1 = matmul.dimA, dim2 = matmul.dimB; | |||
| mgb_assert( | |||
| dim1 >= 2 && dim2 >= 2, | |||
| "the dim of one input the matmul operator dim is less than 2."); | |||
| auto cn = inputs[0]->comp_node(); | |||
| using IndexDesc = opr::Subtensor::IndexDesc; | |||
| OperatorNodeConfig config{matmul.make_name(), cn}; | |||
| DTypeScalar vi{-1}; | |||
| auto graph = inputs[0]->owner_graph(); | |||
| SymbolVar shp1_head, shp1_tail, shp2_head, shp2_tail; | |||
| if (dim1 > 2) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| auto shp1 = inp1.symshape(); | |||
| IndexDesc head_desc(1); | |||
| head_desc[0].end = idx; | |||
| shp1_head = opr::Subtensor::make(shp1, head_desc); | |||
| auto batch = opr::Reduce::make(shp1_head, {Reduce::Mode::PRODUCT, 0}); | |||
| IndexDesc tail_desc(1); | |||
| tail_desc[0].begin = idx; | |||
| shp1_tail = opr::Subtensor::make(shp1, tail_desc); | |||
| auto tshp = opr::Concat::make({batch, shp1_tail}, 0, cn); | |||
| inp1 = inp1.reshape(tshp); | |||
| } | |||
| if (dim2 > 2) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| auto shp2 = inp2.symshape(); | |||
| IndexDesc head_desc(1); | |||
| head_desc[0].end = idx; | |||
| shp2_head = opr::Subtensor::make(shp2, head_desc); | |||
| auto batch = opr::Reduce::make(shp2_head, {Reduce::Mode::PRODUCT, 0}); | |||
| IndexDesc tail_desc(1); | |||
| tail_desc[0].begin = idx; | |||
| auto shp2_tail = opr::Subtensor::make(shp2, tail_desc); | |||
| auto tshp = opr::Concat::make({batch, shp2_tail}, 0, cn); | |||
| inp2 = inp2.reshape(tshp); | |||
| if (dim1 == 2 && dim2 == 2) { | |||
| return opr::MatrixMul::make( | |||
| inp1, inp2, matmul.param(), matmul.policy(), config); | |||
| } | |||
| auto result = | |||
| opr::MatrixMul::make(inp1, inp2, matmul.param(), matmul.policy(), config); | |||
| if (dim1 > 2) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| auto result_shape = result.symshape(); | |||
| IndexDesc tail_desc(1); | |||
| tail_desc[0].begin = idx; | |||
| auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); | |||
| auto tshp = opr::Concat::make({shp1_head, shp_tail}, 0, cn); | |||
| result = result.reshape(tshp); | |||
| } | |||
| if (dim2 > 2) { | |||
| //! use batched matrix mul | |||
| SymbolVar shp_head, batch; | |||
| DTypeScalar vi{-2}; | |||
| auto compress_shape = [&](SymbolVar inp) { | |||
| if (inp.shape().ndim > 3) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| auto shp = inp.symshape(); | |||
| IndexDesc head_desc(1); | |||
| head_desc[0].end = idx; | |||
| shp_head = opr::Subtensor::make(shp, head_desc); | |||
| batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0}); | |||
| IndexDesc tail_desc(1); | |||
| tail_desc[0].begin = idx; | |||
| auto shp_tail = opr::Subtensor::make(shp, tail_desc); | |||
| auto tshp = opr::Concat::make({batch, shp_tail}, 0, cn); | |||
| return inp.reshape(tshp); | |||
| } else if (inp.shape().ndim == 3) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| auto shp = inp.symshape(); | |||
| IndexDesc head_desc(1); | |||
| head_desc[0].end = idx; | |||
| shp_head = opr::Subtensor::make(shp, head_desc); | |||
| batch = opr::Reduce::make(shp_head, {Reduce::Mode::PRODUCT, 0}); | |||
| return inp; | |||
| } else { | |||
| return inp; | |||
| } | |||
| }; | |||
| inp1 = compress_shape(inp1); | |||
| inp2 = compress_shape(inp2); | |||
| auto expand_shape = [&](SymbolVar inp) { | |||
| if (inp.shape().ndim < 3) { | |||
| auto shp = inp.symshape(); | |||
| using Desc = opr::AxisAddRemove::AxisDesc; | |||
| std::vector<Desc> add_axis_param; | |||
| add_axis_param.push_back(Desc::make_add(0)); | |||
| auto out = opr::AxisAddRemove::make(inp, add_axis_param); | |||
| auto target_shape = opr::Concat::make({batch, shp}, 0, cn); | |||
| return opr::Broadcast::make(out, target_shape); | |||
| } else { | |||
| return inp; | |||
| } | |||
| }; | |||
| inp1 = expand_shape(inp1); | |||
| inp2 = expand_shape(inp2); | |||
| auto result = opr::BatchedMatrixMul::make( | |||
| inp1, inp2, matmul.param(), matmul.policy(), config); | |||
| size_t max_dim = std::max(dim1, dim2); | |||
| if (max_dim > 3) { | |||
| auto idx = opr::ImmutableTensor::make(*graph, vi, config); | |||
| auto result_shape = result.symshape(); | |||
| auto res_shp = result.symshape(); | |||
| IndexDesc tail_desc(1); | |||
| tail_desc[0].begin = idx; | |||
| auto shp_tail = opr::Subtensor::make(result_shape, tail_desc); | |||
| auto tshp = opr::Concat::make({shp2_head, shp_tail}, 0, cn); | |||
| auto tail_shape = opr::Subtensor::make(res_shp, tail_desc); | |||
| auto tshp = opr::Concat::make({shp_head, tail_shape}, 0, cn); | |||
| result = result.reshape(tshp); | |||
| } | |||
| return result; | |||
| } | |||
| @@ -150,6 +150,31 @@ void OprChecker::run(std::vector<InputSpec> inp_keys, std::set<size_t> bypass) { | |||
| } | |||
| } | |||
| VarNodeArray OprChecker::run_apply_on_var_node(std::vector<InputSpec> inp_keys) { | |||
| HostTensorGenerator<> gen; | |||
| size_t nr_inps = inp_keys.size(); | |||
| SmallVector<HostTensorND> host_inp(nr_inps); | |||
| VarNodeArray sym_inp(nr_inps); | |||
| auto graph = ComputingGraph::make(); | |||
| graph->options().graph_opt_level = 0; | |||
| for (size_t i = 0; i < nr_inps; ++i) { | |||
| // TODO: remove std::visit for support osx 10.12 | |||
| host_inp[i] = std::visit( | |||
| [&gen](auto&& arg) -> HostTensorND { | |||
| using T = std::decay_t<decltype(arg)>; | |||
| if constexpr (std::is_same_v<TensorShape, T>) { | |||
| return *gen(arg); | |||
| } else { | |||
| static_assert(std::is_same_v<HostTensorND, T>); | |||
| return arg; | |||
| } | |||
| }, | |||
| inp_keys[i]); | |||
| sym_inp[i] = opr::SharedDeviceTensor::make(*graph, host_inp[i]).node(); | |||
| } | |||
| return OpDef::apply_on_var_node(*m_op, sym_inp); | |||
| } | |||
| TEST(TestHelper, PyModule) { | |||
| py::module m = PyEnv::get(); | |||
| py::print(m); | |||
| @@ -14,6 +14,9 @@ public: | |||
| OprChecker(std::shared_ptr<OpDef> opdef); | |||
| void run(std::vector<InputSpec> inp_shapes, std::set<size_t> bypass = {}); | |||
| //! test the interface of apply_on_var_node | |||
| VarNodeArray run_apply_on_var_node(std::vector<InputSpec> inp_shapes); | |||
| private: | |||
| std::shared_ptr<OpDef> m_op; | |||
| }; | |||
| @@ -1,9 +1,11 @@ | |||
| #include "./helper.h" | |||
| #include "megbrain/comp_node_env.h" | |||
| #include "megbrain/imperative/blob_manager.h" | |||
| #include "megbrain/imperative/ops/autogen.h" | |||
| #include "megbrain/imperative/ops/opr_attr.h" | |||
| #include "megbrain/opr/basic_arith.h" | |||
| #include "megbrain/opr/basic_arith_wrapper.h" | |||
| #include "megbrain/opr/blas.h" | |||
| #include "megbrain/opr/dnn/batch_norm.h" | |||
| #include "megbrain/opr/dnn/convolution.h" | |||
| #include "megbrain/opr/tensor_manip.h" | |||
| @@ -164,4 +166,58 @@ TEST(TestImperative, Defragment) { | |||
| } | |||
| #endif // MGB_CUDA && MGB_ENABLE_EXCEPTION | |||
| TEST(TestImperative, MatrixMulApplyOnVarNode) { | |||
| using Param = opr::MatrixMul::Param; | |||
| Param param; | |||
| std::vector<std::pair<TensorShape, TensorShape>> shapes; | |||
| std::vector<TensorShape> target_shapes; | |||
| std::vector<Param> params; | |||
| //! testcase 0 | |||
| params.push_back(param); | |||
| shapes.push_back({TensorShape{10, 5}, TensorShape{5, 10}}); | |||
| target_shapes.push_back(TensorShape{10, 10}); | |||
| //! testcase 1 | |||
| params.push_back(param); | |||
| shapes.push_back({TensorShape{3, 10, 5}, TensorShape{5, 10}}); | |||
| target_shapes.push_back(TensorShape{3, 10, 10}); | |||
| //! testcase 2 | |||
| param.transposeA = true; | |||
| param.transposeB = false; | |||
| params.push_back(param); | |||
| shapes.push_back({TensorShape{3, 7, 6}, TensorShape{7, 10}}); | |||
| target_shapes.push_back(TensorShape{3, 6, 10}); | |||
| //! testcase 3 | |||
| param.transposeA = true; | |||
| param.transposeB = false; | |||
| params.push_back(param); | |||
| shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{7, 10}}); | |||
| target_shapes.push_back(TensorShape{2, 3, 6, 10}); | |||
| //! testcase 4 | |||
| param.transposeA = false; | |||
| param.transposeB = true; | |||
| params.push_back(param); | |||
| shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{2, 3, 8, 6}}); | |||
| target_shapes.push_back(TensorShape{2, 3, 7, 8}); | |||
| //! testcase 5 | |||
| param.transposeA = false; | |||
| param.transposeB = true; | |||
| params.push_back(param); | |||
| shapes.push_back({TensorShape{2, 3, 7, 6}, TensorShape{8, 6}}); | |||
| target_shapes.push_back(TensorShape{2, 3, 7, 8}); | |||
| for (size_t i = 0; i < params.size(); i++) { | |||
| auto& shape = shapes[i]; | |||
| auto op = MatrixMul::make( | |||
| params[i], ::megdnn::param::ExecutionPolicy{}, shape.first.ndim, | |||
| shape.second.ndim); | |||
| auto result = OprChecker(op).run_apply_on_var_node({shape.first, shape.second}); | |||
| ASSERT_GT(result.size(), 0); | |||
| ASSERT_EQ(target_shapes[i].ndim, result[0]->shape().ndim); | |||
| for (size_t id = 0; id < target_shapes[i].ndim; id++) { | |||
| ASSERT_EQ(target_shapes[i][id], result[0]->shape()[id]); | |||
| } | |||
| } | |||
| } | |||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||