GitOrigin-RevId: ac757ca307
@@ -53,6 +53,18 @@ std::vector<TestArg> get_args() {
                           TensorShape{1, 3, 1, 1}, dtype::Float16());
     }
+    // case 3: 1 x 1 x 1 x C
+    for (size_t i = 4; i < 257; i *= 4) {
+        param::BN param;
+        param.fwd_mode = param::BN::FwdMode::TRAINING;
+        param.param_dim = param::BN::ParamDim::DIM_111C;
+        args.emplace_back(param, TensorShape{3, i, i, 3},
+                          TensorShape{1, 1, 1, 3}, dtype::Float32());
+        args.emplace_back(param, TensorShape{3, i, i, 3},
+                          TensorShape{1, 1, 1, 3}, dtype::Float16());
+    }
     return args;
 }
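
The new case-3 loop steps i through 4, 16, 64 and 256, so each pass adds a Float32 and a Float16 argument with an NHWC-style source shape of {3, i, i, 3} and a per-channel parameter shape of {1, 1, 1, 3}. A tiny standalone sketch (plain C++, no megdnn types) that just enumerates those pairs:

    #include <cstddef>
    #include <cstdio>

    // Prints the (src, param) shape pairs generated by the DIM_111C loop above.
    int main() {
        for (std::size_t i = 4; i < 257; i *= 4) {  // i = 4, 16, 64, 256
            std::printf("src = {3, %zu, %zu, 3}, param = {1, 1, 1, 3}\n", i, i);
        }
        return 0;
    }
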
@@ -60,4 +72,4 @@ std::vector<TestArg> get_args() {
 } // namespace test
 } // namespace megdnn
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
@@ -419,7 +419,7 @@ void CheckerHelper::copy_tensors_from_device(const TensorValueArray& dest,
 void CheckerHelper::check_tensors(const TensorValueArray& expected,
                                   const TensorValueArray& computed) {
     for (size_t i = 0; i < expected.size(); ++i) {
-        if (expected[i].layout.ndim == 0)
+        if (expected[i].layout.ndim == 0 || m_bypass.find(i) != m_bypass.end())
             continue;
         if (m_allow_invalid_check) {
             MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID(
@@ -69,6 +69,7 @@ protected:
     std::unordered_map<size_t, RNG*> m_rng;
     std::unordered_map<size_t, DType> m_dtype;
     std::unordered_map<size_t, TensorFormat> m_fmt;
+    std::set<size_t> m_bypass;
     float_t m_epsilon = 1e-3, m_max_avg_error = 1e-3,
             m_max_avg_biased_error = 1e-3;
     float_t m_perf_check_threshold = -1;
@@ -184,6 +185,10 @@ public:
         m_rng[idx] = rng;
         return *this;
     }
+    Checker& set_bypass(size_t idx) {
+        m_bypass.insert(idx);
+        return *this;
+    }
     //! max error of a single element
     Checker& set_epsilon(dt_float32 epsilon) {
         m_epsilon = epsilon;
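
set_bypass() and the check_tensors() change above are two halves of one feature: a tensor index registered here is skipped during result comparison, which is how the BN tests below avoid comparing the opaque cuDNN reserve buffer against the naive reference. A minimal standalone sketch of that mechanism, with illustrative names rather than megdnn's API:

    #include <cassert>
    #include <set>
    #include <vector>

    // Illustrative stand-in for the checker's bypass mechanism: indices listed in
    // `bypass` are excluded from the element-wise comparison.
    static bool tensors_match(const std::vector<std::vector<float>>& expected,
                              const std::vector<std::vector<float>>& computed,
                              const std::set<size_t>& bypass) {
        for (size_t i = 0; i < expected.size(); ++i) {
            if (bypass.count(i))
                continue;  // same idea as `m_bypass.find(i) != m_bypass.end()`
            if (expected[i] != computed[i])
                return false;
        }
        return true;
    }

    int main() {
        std::set<size_t> bypass{1};  // pretend tensor 1 is an opaque reserve buffer
        assert(tensors_match({{1.f}, {2.f}}, {{1.f}, {99.f}}, bypass));
        return 0;
    }
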
@@ -82,6 +82,15 @@ struct DeduceLayoutProxy<Opr, 8, true> {
     }
 };
+template <typename Opr>
+struct DeduceLayoutProxy<Opr, 9, true> {
+    static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) {
+        megdnn_assert(layouts.size() == 9);
+        opr->deduce_layout(layouts[0], layouts[1], layouts[2], layouts[3],
+                           layouts[4], layouts[5], layouts[6], layouts[7],
+                           layouts[8]);
+    }
+};
 } // namespace test
 } // namespace megdnn
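
The proxy keeps the existing pattern of one hand-written specialization per arity. Purely as a design note (this is not how the file is organized), the same call could be generated from a single variadic helper; the Opr and Layout types below are placeholders so the sketch compiles on its own:

    #include <array>
    #include <cstddef>
    #include <utility>

    struct Layout {};
    struct Opr {
        template <typename... Ls>
        void deduce_layout(Ls&... /*layouts*/) {}  // placeholder for the operator's overload
    };

    // Expands layouts[0], ..., layouts[N - 1] into a single deduce_layout(...) call.
    template <typename Op, std::size_t N, std::size_t... Is>
    void deduce_layout_impl(Op* opr, std::array<Layout, N>& layouts,
                            std::index_sequence<Is...>) {
        opr->deduce_layout(layouts[Is]...);
    }

    template <std::size_t N, typename Op>
    void deduce_layout(Op* opr, std::array<Layout, N>& layouts) {
        deduce_layout_impl(opr, layouts, std::make_index_sequence<N>{});
    }

    int main() {
        Opr opr;
        std::array<Layout, 9> layouts{};
        deduce_layout<9>(&opr, layouts);  // the arity-9 case, without a dedicated struct
        return 0;
    }
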
@@ -22,6 +22,23 @@ namespace test {
 template <typename Opr, size_t Arity, bool has_workspace>
 struct ExecProxy;
+template <typename Opr>
+struct ExecProxy<Opr, 9, true> {
+    WorkspaceWrapper W;
+    void exec(Opr* opr, const TensorNDArray& tensors) {
+        if (!W.valid()) {
+            W = WorkspaceWrapper(opr->handle(), 0);
+        }
+        W.update(opr->get_workspace_in_bytes(
+                tensors[0].layout, tensors[1].layout, tensors[2].layout,
+                tensors[3].layout, tensors[4].layout, tensors[5].layout,
+                tensors[6].layout, tensors[7].layout, tensors[8].layout));
+        opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4],
+                  tensors[5], tensors[6], tensors[7], tensors[8],
+                  W.workspace());
+    }
+};
 template <typename Opr>
 struct ExecProxy<Opr, 8, true> {
     WorkspaceWrapper W;
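
For the BN tests below, the nine tensor slots handled by this proxy correspond to the dtype indices the CUDA test sets (0 through 8). The slot names in this sketch are inferred from that test and from the trailing comments in its exec shape lists; slots 1 and 2 are assumed to be the scale and bias parameters, and nothing here is quoted from the BNForward header:

    // Inferred slot layout for the 9-tensor BNForward call used by the tests.
    enum BNForwardSlot {
        SRC,                 // 0: dtype = arg.dtype
        BN_SCALE,            // 1 (assumed)
        BN_BIAS,             // 2 (assumed)
        MEAN,                // 3
        VARIANCE,            // 4
        BATCH_MEAN,          // 5
        BATCH_INV_VARIANCE,  // 6
        RESERVE,             // 7: dtype::Byte(), bypassed by the checker
        DST                  // 8: dtype = arg.dtype
    };
    static_assert(DST + 1 == 9, "ExecProxy<Opr, 9, true> forwards nine tensors");

    int main() {
        return (RESERVE == 7 && DST == 8) ? 0 : 1;  // sanity check of the inferred indices
    }
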
@@ -211,6 +211,10 @@ void IIDRNG::gen(const TensorND& tensor) {
         }
         return;
     }
+    if (tensor.layout.dtype.enumv() == DTypeEnum::Byte) {
+        memset(tensor.raw_ptr, 0, tensor.layout.access_bytes());
+        return;
+    }
     megdnn_assert(0, "IIDRNG does not know how to generate value for DType %s",
                   tensor.layout.dtype.name());
 }
@@ -6,10 +6,13 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "test/cuda/fixture.h"
+#include "src/cuda/batch_normalization/opr_impl.h"
+#include "src/cuda/utils.h"
 #include "megdnn/opr_param_defs.h"
 #include "megdnn/oprs.h"
 #include "test/common/bn.h"
@@ -21,15 +24,26 @@
 namespace megdnn {
 namespace test {
-TEST_F(CUDA, BN_FORWARD) {
+TEST_F(CUDA, BN_FORWARD_BACKWARD) {
     using namespace batch_normalization;
+    using cuda::cudnn_handle;
+    using cuda::batch_normalization::BNTensorDescHolder;
+    using cuda::batch_normalization::get_reserve_size;
     std::vector<TestArg> args = get_args();
     Checker<BNForward> checker(handle_cuda());
+    Checker<BNBackward> checker_bwd(handle_cuda());
     for (auto&& arg : args) {
-        for (int i = 0; i < 8; ++i) {
+        auto tensor_desc = BNTensorDescHolder({arg.src, arg.dtype}, arg.param.param_dim,
+                                              arg.param.fwd_mode);
+        auto reserve = get_reserve_size(cudnn_handle(handle_cuda()), tensor_desc);
+        // Forward
+        for (int i = 0; i < 9; ++i) {
             checker.set_dtype(i, dtype::Float32());
         }
         checker.set_dtype(0, arg.dtype);
+        checker.set_dtype(7, dtype::Byte());
+        checker.set_dtype(8, arg.dtype);
+        checker.set_bypass(7);
         checker.set_epsilon(1e-3).set_param(arg.param);
         for (bool need_statistic : {false, true})
             checker.exec({
@@ -40,27 +54,26 @@ TEST_F(CUDA, BN_FORWARD) {
                                    : TensorShape({0}),  // mean
                     need_statistic ? arg.param_shape
                                    : TensorShape({0}),  // variance
-                    arg.param_shape,  // batch_mean
-                    arg.param_shape,  // batch_inv_variance
-                    {}                // dst
+                    arg.param_shape,  // batch_mean
+                    arg.param_shape,  // batch_inv_variance
+                    {reserve},        // reserve
+                    arg.src           // dst
             });
-    }
-}
-TEST_F(CUDA, BN_BACKWARD) {
-    using namespace batch_normalization;
-    std::vector<TestArg> args = get_args();
-    Checker<BNBackward> checker(handle_cuda());
-    for (auto&& arg : args) {
-        for (int i = 0; i < 8; ++i) {
-            checker.set_dtype(i, dtype::Float32());
+        // Backward
+        for (int i = 0; i < 9; ++i) {
+            checker_bwd.set_dtype(i, dtype::Float32());
         }
-        checker.set_dtype(0, arg.dtype)    // x
-                .set_dtype(1, arg.dtype)   // dy
-                .set_dtype(7, arg.dtype);  // dx
-        checker.set_epsilon(1e-3).set_param(arg.param).exec(
+        checker_bwd
+                .set_dtype(0, arg.dtype)      // x
+                .set_dtype(1, arg.dtype)      // dy
+                .set_dtype(5, dtype::Byte())  // reserve
+                .set_dtype(8, arg.dtype)      // dx
+                .set_bypass(5);
+        checker_bwd.set_epsilon(1e-3).set_param(arg.param).exec(
                 {arg.src, arg.src, arg.param_shape, arg.param_shape,
-                 arg.param_shape, arg.param_shape, arg.param_shape, arg.src});
+                 arg.param_shape, {reserve}, arg.param_shape, arg.param_shape,
+                 arg.src});
     }
 }
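
The reserve buffer's size comes from get_reserve_size(), whose implementation is not part of this diff. Assuming it wraps cuDNN's training-ex reserve-space query, the sketch below shows what such a query looks like in plain cuDNN; the returned byte count is what the test passes as the {reserve} shape with dtype::Byte(), and its opaque contents are why that slot is bypassed rather than compared:

    #include <cudnn.h>
    #include <cstddef>
    #include <cstdio>

    // Hedged sketch: query the reserve-space size for the BN training-ex path.
    // megdnn's get_reserve_size() presumably does something equivalent internally.
    // Expects an already-created cuDNN handle and x tensor descriptor.
    size_t query_bn_reserve_size(cudnnHandle_t handle, cudnnBatchNormMode_t mode,
                                 cudnnTensorDescriptor_t x_desc) {
        size_t size = 0;
        cudnnStatus_t st = cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
                handle, mode, CUDNN_BATCHNORM_OPS_BN, /*activationDesc=*/nullptr,
                x_desc, &size);
        if (st != CUDNN_STATUS_SUCCESS) {
            std::fprintf(stderr, "cudnn error: %s\n", cudnnGetErrorString(st));
            return 0;
        }
        return size;  // used as TensorShape{size} with dtype::Byte() in the test
    }
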
@@ -31,6 +31,7 @@ TEST_F(ROCM, BN_FORWARD) {
             checker.set_dtype(i, dtype::Float32());
         }
         checker.set_dtype(0, arg.dtype);
+        checker.set_dtype(8, arg.dtype);
         checker.set_epsilon(1e-3).set_param(arg.param);
         for (bool need_statistic : {false, true})
             checker.exec({
@@ -43,7 +44,8 @@ TEST_F(ROCM, BN_FORWARD) {
                                    : TensorShape({0}),  // variance
                     arg.param_shape,  // batch_mean
                     arg.param_shape,  // batch_inv_variance
-                    {}                // dst
+                    {0},              // reserve
+                    arg.src           // dst
             });
     }
 }
@@ -53,15 +55,16 @@ TEST_F(ROCM, BN_BACKWARD) {
     std::vector<TestArg> args = get_args();
     Checker<BNBackward> checker(handle_rocm());
     for (auto&& arg : args) {
-        for (int i = 0; i < 8; ++i) {
+        for (int i = 0; i < 9; ++i) {
             checker.set_dtype(i, dtype::Float32());
         }
         checker.set_dtype(0, arg.dtype)    // x
                 .set_dtype(1, arg.dtype)   // dy
-                .set_dtype(7, arg.dtype);  // dx
+                .set_dtype(8, arg.dtype);  // dx
         checker.set_epsilon(1e-3).set_param(arg.param).exec(
                 {arg.src, arg.src, arg.param_shape, arg.param_shape,
-                 arg.param_shape, arg.param_shape, arg.param_shape, arg.src});
+                 arg.param_shape, {0}, arg.param_shape, arg.param_shape,
+                 arg.src});
     }
 }