GitOrigin-RevId: 0d042dbfce
tags/v1.5.0
| @@ -1317,6 +1317,27 @@ protected: | |||||
| TensorLayout& exec_workspace, | TensorLayout& exec_workspace, | ||||
| TensorLayout& exec_src, TensorLayout& exec_dst); | TensorLayout& exec_src, TensorLayout& exec_dst); | ||||
| }; | }; | ||||
| /*! | |||||
| * \brief check whether input contains inf value. | |||||
| * | |||||
| * Abstract operator interface: each backend implements | |||||
| * get_workspace_in_bytes() and exec(). | |||||
| * | |||||
| * deduce_layout() fills \p dst with the result layout (a contiguous, | |||||
| * one-element Int32 tensor in the common implementation). | |||||
| * | |||||
| * check_exec() is the shared argument validation that backends call at | |||||
| * the start of exec(): it asserts contiguous layouts, a 1-dimensional | |||||
| * Float32 \p src and a workspace at least as large as | |||||
| * get_workspace_in_bytes() reports. | |||||
| */ | |||||
| class CheckHasInf: public OperatorBase { | |||||
| DEF_OPR_PARAM(Empty); | |||||
| DEF_OPR_IMPL(CheckHasInf, OperatorBase, 1, 1); | |||||
| public: | |||||
| virtual size_t get_workspace_in_bytes(const TensorLayout &src, | |||||
| const TensorLayout &dst) = 0; | |||||
| void deduce_layout(const TensorLayout &src, TensorLayout &dst); | |||||
| virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
| _megdnn_workspace workspace) = 0; | |||||
| protected: | |||||
| void check_exec(const TensorLayout &src, const TensorLayout &dst, | |||||
| size_t workspace_in_bytes); | |||||
| }; | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| #include "megdnn/internal/opr_header_epilogue.h" | #include "megdnn/internal/opr_header_epilogue.h" | ||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * \file dnn/src/common/check_has_inf.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "src/common/utils.h" | |||||
| namespace megdnn { | |||||
| /*! | |||||
|  * \brief validate the arguments handed to exec(). | |||||
|  * | |||||
|  * \param src layout of the input; must be contiguous, 1-dimensional Float32 | |||||
|  * \param dst layout of the result; must be contiguous and hold exactly one | |||||
|  *      Int32 element, which is what deduce_layout() produces and what the | |||||
|  *      backends write | |||||
|  * \param workspace_in_bytes size of the workspace supplied by the caller; | |||||
|  *      must not be smaller than get_workspace_in_bytes(src, dst) | |||||
|  */ | |||||
| void CheckHasInf::check_exec(const TensorLayout& src, const TensorLayout& dst, | |||||
|                              size_t workspace_in_bytes) { | |||||
|     megdnn_assert_contiguous(src); | |||||
|     megdnn_assert_contiguous(dst); | |||||
|     megdnn_assert(src.ndim == 1); | |||||
|     megdnn_assert(src.dtype == dtype::Float32()); | |||||
|     // exec() stores a single dt_int32 flag into dst; reject anything that | |||||
|     // could not hold it instead of writing through a mismatching tensor | |||||
|     megdnn_assert(dst.total_nr_elems() == 1); | |||||
|     megdnn_assert(dst.dtype == dtype::Int32()); | |||||
|     auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst); | |||||
|     megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); | |||||
| } | |||||
| /*! | |||||
|  * \brief compute the result layout; it does not depend on the input. | |||||
|  * | |||||
|  * \param dst receives a contiguous 1-dimensional Int32 layout of shape {1} | |||||
|  */ | |||||
| void CheckHasInf::deduce_layout(const TensorLayout&, TensorLayout& dst) { | |||||
|     dst.dtype = dtype::Int32(); | |||||
|     dst.ndim = 1; | |||||
|     dst.shape[0] = 1; | |||||
|     // strides are derived from the shape set above, so this goes last | |||||
|     dst.init_contiguous_stride(); | |||||
| } | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -207,7 +207,8 @@ private: | |||||
| cb(FakeQuantForward) \ | cb(FakeQuantForward) \ | ||||
| cb(FakeQuantBackward) \ | cb(FakeQuantBackward) \ | ||||
| cb(TQTForward) \ | cb(TQTForward) \ | ||||
| cb(TQTBackward) | |||||
| cb(TQTBackward) \ | |||||
| cb(CheckHasInf) | |||||
| /*! | /*! | ||||
| * \brief specialize HandleImpl::create_operator for a single opr type; | * \brief specialize HandleImpl::create_operator for a single opr type; | ||||
| @@ -120,6 +120,7 @@ DEF(PowC, 2, false, true); | |||||
| DEF(UniformRNG, 1, true, true); | DEF(UniformRNG, 1, true, true); | ||||
| DEF(GaussianRNG, 1, true, true); | DEF(GaussianRNG, 1, true, true); | ||||
| DEF(ChecksumForward, 1, true, false); | DEF(ChecksumForward, 1, true, false); | ||||
| DEF(CheckHasInf, 2, true, true); | |||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -4,9 +4,9 @@ | |||||
| * | * | ||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | ||||
| * | * | ||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | */ | ||||
| #pragma once | #pragma once | ||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| @@ -151,6 +151,33 @@ struct MaxOp { | |||||
| : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | : INIT(wtype(DTypeTrait<wtype>::min())), src(src), dst(dst), B(B) {} | ||||
| }; | }; | ||||
| /*! | |||||
|  * \brief reduction functor that reports whether any element is infinite. | |||||
|  * | |||||
|  * read() maps one source element to the result of isinf(), apply() merges | |||||
|  * two partial results with bitwise OR and write() stores a result in dst. | |||||
|  * INIT (zero) is the neutral element of that OR. | |||||
|  */ | |||||
| template <typename src_ctype, typename dst_ctype, typename wtype_> | |||||
| struct CheckHasInfOp { | |||||
|     typedef wtype_ wtype; | |||||
|     const wtype INIT; | |||||
|     src_ctype* src; | |||||
|     dst_ctype* dst; | |||||
|     const size_t B; | |||||
|     MEGDNN_HOST MEGDNN_DEVICE CheckHasInfOp(src_ctype* src_, dst_ctype* dst_, | |||||
|                                             size_t B_) | |||||
|             : INIT(wtype(0)), src(src_), dst(dst_), B(B_) {} | |||||
|     //! combine two partial results | |||||
|     static MEGDNN_HOST MEGDNN_DEVICE wtype apply(wtype lhs, wtype rhs) { | |||||
|         lhs |= rhs; | |||||
|         return lhs; | |||||
|     } | |||||
|     //! isinf() of the source element at \p idx, converted to wtype | |||||
|     MEGDNN_HOST MEGDNN_DEVICE wtype read(uint32_t idx) { | |||||
| #if !defined(__CUDA_ARCH__) | |||||
|         return std::isinf(src[idx]); | |||||
| #else | |||||
|         return isinf(src[idx]); | |||||
| #endif | |||||
|     } | |||||
|     //! store \p val at position \p idx of the destination | |||||
|     MEGDNN_HOST MEGDNN_DEVICE void write(uint32_t idx, wtype val) { | |||||
|         dst[idx] = val; | |||||
|     } | |||||
| }; | |||||
| #if MEGDNN_CC_HOST | #if MEGDNN_CC_HOST | ||||
| void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, | void get_ABC(const TensorShape& shape, size_t& A, size_t& B, size_t& C, | ||||
| size_t axis); | size_t axis); | ||||
| @@ -0,0 +1,27 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/check_has_inf/kern.cu | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, | |||||
| * software distributed under the License is distributed on an | |||||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/common/reduce_helper.h" | |||||
| #include "megdnn/dtype.h" | |||||
| #include "src/cuda/reduce_helper.cuh" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| #define COMMA , /* lets the template-argument commas travel inside one macro argument */ | |||||
| INST_REDUCE(reduce::CheckHasInfOp<dt_float32 COMMA dt_int32 COMMA dt_int32>, false); /* float32 source, int32 destination and working type */ | |||||
| #undef COMMA | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: ft=cpp syntax=cpp.doxygen | |||||
| @@ -0,0 +1,45 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/check_has_inf/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/cuda/check_has_inf/opr_impl.h" | |||||
| #include "src/cuda/reduce_helper.cuh" | |||||
| #include "src/cuda/handle.h" | |||||
| #include "src/cuda/utils.h" | |||||
| #include "src/common/reduce_helper.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| using reduce::CheckHasInfOp; | |||||
| /*! | |||||
|  * \brief workspace needed by the reduction kernel for the given input. | |||||
|  * | |||||
|  * \param src input layout; only its element count is used | |||||
|  * \param dst result layout; not used | |||||
|  * \return size in bytes as reported by get_reduce_workspace_in_bytes() | |||||
|  */ | |||||
| size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& src, | |||||
|                                                const TensorLayout& dst) { | |||||
|     using Op = CheckHasInfOp<dt_float32, dt_int32, dt_int32>; | |||||
|     // NOTE(review): the (1, n, 1) arguments presumably describe a single row | |||||
|     // of n elements -- confirm against get_reduce_workspace_in_bytes | |||||
|     const auto nr_elems = src.total_nr_elems(); | |||||
|     return get_reduce_workspace_in_bytes<Op>(1, nr_elems, 1); | |||||
| } | |||||
| /*! | |||||
|  * \brief run the inf check on the CUDA stream of this handle. | |||||
|  * | |||||
|  * Arguments are validated by check_exec() first; the input is then reduced | |||||
|  * with CheckHasInfOp through run_reduce(), using \p workspace as scratch | |||||
|  * memory and \p dst as the output of the functor. | |||||
|  */ | |||||
| void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
|                            _megdnn_workspace workspace) { | |||||
|     using Op = CheckHasInfOp<dt_float32, dt_int32, dt_int32>; | |||||
|     check_exec(src.layout, dst.layout, workspace.size); | |||||
|     const auto nr_elems = src.layout.total_nr_elems(); | |||||
|     auto stream = cuda_stream(this->handle()); | |||||
|     run_reduce<Op, false>( | |||||
|             workspace.ptr<dt_int32>(), 1, nr_elems, 1, stream, | |||||
|             Op(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), nr_elems)); | |||||
| } | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * \file dnn/src/cuda/check_has_inf/opr_impl.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs/utils.h" | |||||
| #include "src/cuda/utils.h" | |||||
| namespace megdnn { | |||||
| namespace cuda { | |||||
| //! CUDA implementation of the CheckHasInf operator | |||||
| class CheckHasInfImpl final : public CheckHasInf { | |||||
| public: | |||||
|     using CheckHasInf::CheckHasInf; | |||||
|     bool is_thread_safe() const override { return true; } | |||||
|     //! bytes of scratch memory exec() needs for these layouts | |||||
|     size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
|                                   const TensorLayout& dst) override; | |||||
|     //! check \p src for inf values and store the flag in \p dst | |||||
|     void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
|               _megdnn_workspace workspace) override; | |||||
| }; | |||||
| } // namespace cuda | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -17,6 +17,7 @@ | |||||
| #include "src/cuda/argsort/opr_impl.h" | #include "src/cuda/argsort/opr_impl.h" | ||||
| #include "src/cuda/batch_normalization/opr_impl.h" | #include "src/cuda/batch_normalization/opr_impl.h" | ||||
| #include "src/cuda/batched_matrix_mul/opr_impl.h" | #include "src/cuda/batched_matrix_mul/opr_impl.h" | ||||
| #include "src/cuda/check_has_inf/opr_impl.h" | |||||
| #include "src/cuda/checksum/opr_impl.h" | #include "src/cuda/checksum/opr_impl.h" | ||||
| #include "src/cuda/concat/opr_impl.h" | #include "src/cuda/concat/opr_impl.h" | ||||
| #include "src/cuda/cond_take/opr_impl.h" | #include "src/cuda/cond_take/opr_impl.h" | ||||
| @@ -18,15 +18,15 @@ namespace cuda { | |||||
| using namespace reduce; | using namespace reduce; | ||||
| #define COMMOA , | |||||
| #define COMMA , | |||||
| #define INST(sctype, dctype, wtype) \ | #define INST(sctype, dctype, wtype) \ | ||||
| INST_REDUCE(SumOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(SumSqrOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(ProdOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(MinOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(MaxOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(MeanOp<sctype COMMOA dctype COMMOA wtype>, false); | |||||
| INST_REDUCE(SumOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(SumSqrOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(ProdOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(MinOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(MaxOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(MeanOp<sctype COMMA dctype COMMA wtype>, false); | |||||
| #define cb(_dt) \ | #define cb(_dt) \ | ||||
| INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) | INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) | ||||
| @@ -40,6 +40,7 @@ INST(int, float, float) | |||||
| #undef cb | #undef cb | ||||
| #undef INST | #undef INST | ||||
| #undef COMMA | |||||
| } // namespace cuda | } // namespace cuda | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/check_has_inf/opr_impl.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "src/naive/check_has_inf/opr_impl.h" | |||||
| #include "src/common/utils.h" | |||||
| #include "src/naive/handle.h" | |||||
| namespace { | |||||
| using namespace megdnn; | |||||
| /*! | |||||
|  * \brief reference implementation: tell whether a float buffer holds an inf. | |||||
|  * | |||||
|  * \param sptr pointer to \p size float32 values | |||||
|  * \param dptr receives the result in dptr[0]: 1 if at least one value is | |||||
|  *      +inf or -inf, 0 otherwise | |||||
|  * \param size number of values to inspect; an empty input yields 0 and | |||||
|  *      \p sptr is not dereferenced in that case | |||||
|  */ | |||||
| void reduce_fwd(const dt_float32* sptr, dt_int32* dptr, size_t size) { | |||||
|     dt_int32 has_inf = 0; | |||||
|     // a plain scan is enough for an OR-reduction; stop at the first hit | |||||
|     for (size_t i = 0; i < size && !has_inf; ++i) { | |||||
|         has_inf = static_cast<dt_int32>(std::isinf(sptr[i])); | |||||
|     } | |||||
|     dptr[0] = has_inf; | |||||
| } | |||||
| } // namespace | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| //! the naive implementation works in place and needs no scratch memory | |||||
| size_t CheckHasInfImpl::get_workspace_in_bytes(const TensorLayout& /* src */, | |||||
|                                                const TensorLayout& /* dst */) { | |||||
|     return 0; | |||||
| } | |||||
| /*! | |||||
|  * \brief validate the arguments, then dispatch reduce_fwd() over all | |||||
|  *      elements of \p src as a CPU kernel on this operator's handle; the | |||||
|  *      flag is written to \p dst. | |||||
|  */ | |||||
| void CheckHasInfImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
|                            _megdnn_workspace workspace) { | |||||
|     check_exec(src.layout, dst.layout, workspace.size); | |||||
|     auto naive_handle = static_cast<HandleImpl*>(this->handle()); | |||||
|     MEGDNN_DISPATCH_CPU_KERN( | |||||
|             naive_handle, | |||||
|             reduce_fwd(src.ptr<dt_float32>(), dst.ptr<dt_int32>(), | |||||
|                        src.layout.total_nr_elems())); | |||||
| } | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * \file dnn/src/naive/check_has_inf/opr_impl.h | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #pragma once | |||||
| #include "megdnn/oprs.h" | |||||
| namespace megdnn { | |||||
| namespace naive { | |||||
| //! naive (reference CPU) implementation of the CheckHasInf operator | |||||
| class CheckHasInfImpl final : public CheckHasInf { | |||||
| public: | |||||
|     using CheckHasInf::CheckHasInf; | |||||
|     //! bytes of scratch memory exec() needs for these layouts | |||||
|     size_t get_workspace_in_bytes(const TensorLayout& src, | |||||
|                                   const TensorLayout& dst) override; | |||||
|     //! check \p src for inf values and store the flag in \p dst | |||||
|     void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst, | |||||
|               _megdnn_workspace workspace) override; | |||||
|     bool is_thread_safe() const override { return true; } | |||||
| }; | |||||
| } // namespace naive | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen | |||||
| @@ -21,6 +21,7 @@ | |||||
| #include "src/naive/batch_conv_bias/opr_impl.h" | #include "src/naive/batch_conv_bias/opr_impl.h" | ||||
| #include "src/naive/batch_normalization/opr_impl.h" | #include "src/naive/batch_normalization/opr_impl.h" | ||||
| #include "src/naive/batched_matrix_mul/opr_impl.h" | #include "src/naive/batched_matrix_mul/opr_impl.h" | ||||
| #include "src/naive/check_has_inf/opr_impl.h" | |||||
| #include "src/naive/checksum/opr_impl.h" | #include "src/naive/checksum/opr_impl.h" | ||||
| #include "src/naive/concat/opr_impl.h" | #include "src/naive/concat/opr_impl.h" | ||||
| #include "src/naive/cond_take/opr_impl.h" | #include "src/naive/cond_take/opr_impl.h" | ||||
| @@ -18,15 +18,15 @@ namespace rocm { | |||||
| using namespace reduce; | using namespace reduce; | ||||
| #define COMMOA , | |||||
| #define COMMA , | |||||
| #define INST(sctype, dctype, wtype) \ | #define INST(sctype, dctype, wtype) \ | ||||
| INST_REDUCE(SumOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(SumSqrOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(ProdOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(MinOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(MaxOp<sctype COMMOA dctype COMMOA wtype>, false); \ | |||||
| INST_REDUCE(MeanOp<sctype COMMOA dctype COMMOA wtype>, false); | |||||
| INST_REDUCE(SumOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(SumSqrOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(ProdOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(MinOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(MaxOp<sctype COMMA dctype COMMA wtype>, false); \ | |||||
| INST_REDUCE(MeanOp<sctype COMMA dctype COMMA wtype>, false); | |||||
| #define cb(_dt) \ | #define cb(_dt) \ | ||||
| INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) | INST(DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype, DTypeTrait<_dt>::ctype) | ||||
| @@ -39,6 +39,7 @@ INST(float, dt_float16, float) | |||||
| INST(int, float, float) | INST(int, float, float) | ||||
| #undef cb | #undef cb | ||||
| #undef INST | #undef INST | ||||
| #undef COMMA | |||||
| } // namespace rocm | } // namespace rocm | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -23,7 +23,7 @@ namespace { | |||||
| ::testing::AssertionResult assert_tensor_eq_with_iter( | ::testing::AssertionResult assert_tensor_eq_with_iter( | ||||
| const char *expr0, const char *expr1, | const char *expr0, const char *expr1, | ||||
| Iter it0, Iter it1, const TensorLayout &layout, | Iter it0, Iter it1, const TensorLayout &layout, | ||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased) { | |||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||||
| auto nr_elem = layout.total_nr_elems(); | auto nr_elem = layout.total_nr_elems(); | ||||
| double error_sum = 0; | double error_sum = 0; | ||||
| @@ -33,8 +33,8 @@ namespace { | |||||
| float err = diff(iv0, iv1); | float err = diff(iv0, iv1); | ||||
| error_sum += std::abs(err); | error_sum += std::abs(err); | ||||
| error_sum_biased += err; | error_sum_biased += err; | ||||
| if (!good_float(iv0) || !good_float(iv1) || | |||||
| std::abs(err) > maxerr) { | |||||
| if (!allow_invalid && (!good_float(iv0) || !good_float(iv1) || | |||||
| std::abs(err) > maxerr)) { | |||||
| Index index(layout, i); | Index index(layout, i); | ||||
| return ::testing::AssertionFailure() | return ::testing::AssertionFailure() | ||||
| << "Unequal value\n" | << "Unequal value\n" | ||||
| @@ -82,14 +82,14 @@ namespace { | |||||
| ::testing::AssertionResult assert_tensor_eq_with_dtype( | ::testing::AssertionResult assert_tensor_eq_with_dtype( | ||||
| const char *expr0, const char *expr1, | const char *expr0, const char *expr1, | ||||
| const TensorND &v0, const TensorND &v1, | const TensorND &v0, const TensorND &v1, | ||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased) { | |||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||||
| if (!std::is_same<ctype, dt_qint4>::value && | if (!std::is_same<ctype, dt_qint4>::value && | ||||
| !std::is_same<ctype, dt_quint4>::value) { | !std::is_same<ctype, dt_quint4>::value) { | ||||
| if (v0.layout.is_physical_contiguous() && | if (v0.layout.is_physical_contiguous() && | ||||
| v1.layout.is_physical_contiguous()) { | v1.layout.is_physical_contiguous()) { | ||||
| return assert_tensor_eq_with_iter<ctype>( | return assert_tensor_eq_with_iter<ctype>( | ||||
| expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), | expr0, expr1, v0.ptr<ctype>(), v1.ptr<ctype>(), | ||||
| v0.layout, maxerr, maxerr_avg, maxerr_avg_biased); | |||||
| v0.layout, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid); | |||||
| } | } | ||||
| } | } | ||||
| @@ -98,7 +98,7 @@ namespace { | |||||
| return assert_tensor_eq_with_iter<ctype>(expr0, expr1, it0, it1, | return assert_tensor_eq_with_iter<ctype>(expr0, expr1, it0, it1, | ||||
| v0.layout, maxerr, maxerr_avg, | v0.layout, maxerr, maxerr_avg, | ||||
| maxerr_avg_biased); | |||||
| maxerr_avg_biased, allow_invalid); | |||||
| } | } | ||||
| template<class Impl> | template<class Impl> | ||||
| @@ -136,7 +136,7 @@ namespace { | |||||
| const char* /*expr_maxerr_avg*/, | const char* /*expr_maxerr_avg*/, | ||||
| const char* /*expr_maxerr_avg*/, | const char* /*expr_maxerr_avg*/, | ||||
| const TensorND &v0, const TensorND &v1, | const TensorND &v0, const TensorND &v1, | ||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased) { | |||||
| float maxerr, float maxerr_avg, float maxerr_avg_biased, bool allow_invalid) { | |||||
| if (!v0.layout.eq_shape(v1.layout)) { | if (!v0.layout.eq_shape(v1.layout)) { | ||||
| return ::testing::AssertionFailure() | return ::testing::AssertionFailure() | ||||
| @@ -160,7 +160,7 @@ namespace { | |||||
| #define cb(_dt) \ | #define cb(_dt) \ | ||||
| case DTypeTrait<_dt>::enumv: \ | case DTypeTrait<_dt>::enumv: \ | ||||
| return assert_tensor_eq_with_dtype<DTypeTrait<_dt>::ctype>( \ | return assert_tensor_eq_with_dtype<DTypeTrait<_dt>::ctype>( \ | ||||
| expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased); | |||||
| expr0, expr1, v0, v1, maxerr, maxerr_avg, maxerr_avg_biased, allow_invalid); | |||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | MEGDNN_FOREACH_QUANTIZED_DTYPE(cb) | ||||
| //! In order to avoid an unnecessary increase in binary size, we just | //! In order to avoid an unnecessary increase in binary size, we just | ||||
| @@ -174,6 +174,17 @@ namespace { | |||||
| } | } | ||||
| /*! | |||||
|  * \brief tensor comparison predicate with allow_invalid switched on. | |||||
|  * | |||||
|  * Forwards every argument to __assert_tensor_eq() and passes true for its | |||||
|  * trailing allow_invalid flag. Being a separate function with the plain | |||||
|  * ten-argument signature, it can be used with ASSERT_PRED_FORMAT5. | |||||
|  */ | |||||
| ::testing::AssertionResult test::__assert_tensor_eq_allow_invalid( | |||||
|         const char* expr0, const char* expr1, const char* expr_maxerr, | |||||
|         const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | |||||
|         const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | |||||
|         float maxerr_avg_biased) { | |||||
|     return __assert_tensor_eq(expr0, expr1, expr_maxerr, expr_maxerr_avg, | |||||
|                               expr_maxerr_avg_biased, v0, v1, maxerr, | |||||
|                               maxerr_avg, maxerr_avg_biased, true); | |||||
| } | |||||
| CheckerHelper::CheckerHelper(Handle *handle, bool check_dispatch): | CheckerHelper::CheckerHelper(Handle *handle, bool check_dispatch): | ||||
| m_handle_cur(handle), | m_handle_cur(handle), | ||||
| m_default_rng(new NormalRNG()) | m_default_rng(new NormalRNG()) | ||||
| @@ -411,9 +422,15 @@ void CheckerHelper::check_tensors(const TensorValueArray& expected, | |||||
| for (size_t i = 0; i < expected.size(); ++i) { | for (size_t i = 0; i < expected.size(); ++i) { | ||||
| if (expected[i].layout.ndim == 0) | if (expected[i].layout.ndim == 0) | ||||
| continue; | continue; | ||||
| MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon, | |||||
| m_max_avg_error, | |||||
| m_max_avg_biased_error); | |||||
| if (m_allow_invalid_check) { | |||||
| MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( | |||||
| expected[i], computed[i], m_epsilon, m_max_avg_error, | |||||
| m_max_avg_biased_error); | |||||
| } else { | |||||
| MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(expected[i], computed[i], m_epsilon, | |||||
| m_max_avg_error, | |||||
| m_max_avg_biased_error); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| @@ -79,6 +79,7 @@ protected: | |||||
| bool m_no_naive_and_check = false; | bool m_no_naive_and_check = false; | ||||
| bool m_stable_check = false; | bool m_stable_check = false; | ||||
| bool m_force_deduce_dst = true; | bool m_force_deduce_dst = true; | ||||
| bool m_allow_invalid_check = false; | |||||
| /** | /** | ||||
| * the offset from the start of malloc memory | * the offset from the start of malloc memory | ||||
| * | * | ||||
| @@ -248,6 +249,11 @@ public: | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| /*! | |||||
|  * \brief choose whether check_tensors() compares with the allow-invalid | |||||
|  *      assertion, which does not fail on non-finite values. | |||||
|  * \return *this, so that setters can be chained | |||||
|  */ | |||||
| Checker& set_allow_invalid_check(bool allow_invalid_check) { | |||||
|     this->m_allow_invalid_check = allow_invalid_check; | |||||
|     return *this; | |||||
| } | |||||
| //! load input tensors from file for next run | //! load input tensors from file for next run | ||||
| Checker& load_input_tensors(const char* fpath) { | Checker& load_input_tensors(const char* fpath) { | ||||
| m_input_tensors_fpath = fpath; | m_input_tensors_fpath = fpath; | ||||
| @@ -326,6 +332,12 @@ private: | |||||
| }; | }; | ||||
| ::testing::AssertionResult __assert_tensor_eq( | ::testing::AssertionResult __assert_tensor_eq( | ||||
| const char* expr0, const char* expr1, const char* expr_maxerr, | |||||
| const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | |||||
| const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | |||||
| float maxerr_avg_biased, bool allow_invalid = false); | |||||
| ::testing::AssertionResult __assert_tensor_eq_allow_invalid( | |||||
| const char* expr0, const char* expr1, const char* expr_maxerr, | const char* expr0, const char* expr1, const char* expr_maxerr, | ||||
| const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | const char* expr_maxerr_avg, const char* expr_maxerr_avg_biased, | ||||
| const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | const TensorND& v0, const TensorND& v1, float maxerr, float maxerr_avg, | ||||
| @@ -336,6 +348,11 @@ private: | |||||
| ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, \ | ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq, v0, v1, maxerr, \ | ||||
| maxerr_avg, maxerr_avg_biased) | maxerr_avg, maxerr_avg_biased) | ||||
| #define MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG_ALLOW_INVALID( \ | |||||
| v0, v1, maxerr, maxerr_avg, maxerr_avg_biased) \ | |||||
| ASSERT_PRED_FORMAT5(::megdnn::test::__assert_tensor_eq_allow_invalid, v0, \ | |||||
| v1, maxerr, maxerr_avg, maxerr_avg_biased) /* same arguments as MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG, compared through __assert_tensor_eq_allow_invalid */ | |||||
| #define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \ | #define MEGDNN_ASSERT_TENSOR_EQ_EPS(v0, v1, maxerr) \ | ||||
| MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr) | MEGDNN_ASSERT_TENSOR_EQ_EPS_AVG(v0, v1, maxerr, maxerr, maxerr) | ||||
| @@ -435,7 +452,7 @@ TensorND TensorValue(const TensorShape& shape, T dtype, | |||||
| template <typename T, typename U> | template <typename T, typename U> | ||||
| TensorND TensorValueLowbit4(const TensorShape& shape, T dtype, | TensorND TensorValueLowbit4(const TensorShape& shape, T dtype, | ||||
| std::vector<U> values) { | |||||
| std::vector<U> values) { | |||||
| TensorND tensor; | TensorND tensor; | ||||
| tensor.layout = {shape, dtype}; | tensor.layout = {shape, dtype}; | ||||
| tensor.raw_ptr = | tensor.raw_ptr = | ||||
| @@ -38,6 +38,22 @@ struct ExecProxy<Opr, 8, true> { | |||||
| } | } | ||||
| }; | }; | ||||
| /*! | |||||
|  * \brief exec proxy for operators that take seven tensors and a workspace. | |||||
|  * | |||||
|  * The workspace wrapper is created on first use, resized to what the | |||||
|  * operator requests for the given layouts, and then handed to exec(). | |||||
|  */ | |||||
| template <typename Opr> | |||||
| struct ExecProxy<Opr, 7, true> { | |||||
|     WorkspaceWrapper W; | |||||
|     void exec(Opr* opr, const TensorNDArray& tensors) { | |||||
|         if (!W.valid()) { | |||||
|             W = WorkspaceWrapper(opr->handle(), 0); | |||||
|         } | |||||
|         const auto workspace_size = opr->get_workspace_in_bytes( | |||||
|                 tensors[0].layout, tensors[1].layout, tensors[2].layout, | |||||
|                 tensors[3].layout, tensors[4].layout, tensors[5].layout, | |||||
|                 tensors[6].layout); | |||||
|         W.update(workspace_size); | |||||
|         opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], | |||||
|                   tensors[5], tensors[6], W.workspace()); | |||||
|     } | |||||
| }; | |||||
| template <typename Opr> | template <typename Opr> | ||||
| struct ExecProxy<Opr, 6, true> { | struct ExecProxy<Opr, 6, true> { | ||||
| WorkspaceWrapper W; | WorkspaceWrapper W; | ||||
| @@ -149,24 +165,6 @@ struct ExecProxy<Opr, 2, false> { | |||||
| } | } | ||||
| }; | }; | ||||
| template <typename Opr> | |||||
| struct ExecProxy<Opr, 7, true> { | |||||
| WorkspaceWrapper W; | |||||
| void exec(Opr* opr, const TensorNDArray& tensors) { | |||||
| if (!W.valid()) { | |||||
| W = WorkspaceWrapper(opr->handle(), 0); | |||||
| } | |||||
| W.update(opr->get_workspace_in_bytes( | |||||
| tensors[0].layout, tensors[1].layout, tensors[2].layout, | |||||
| tensors[3].layout, tensors[4].layout, tensors[5].layout, | |||||
| tensors[6].layout)); | |||||
| opr->exec(tensors[0], tensors[1], tensors[2], tensors[3], tensors[4], | |||||
| tensors[5], tensors[6], W.workspace()); | |||||
| } | |||||
| }; | |||||
| } // namespace test | } // namespace test | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| // vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen | ||||
| @@ -301,9 +301,8 @@ void UniformFloatNonZeroRNG::fill_fast_float32(dt_float32* dest, size_t size) { | |||||
| } | } | ||||
| } | } | ||||
| void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) { | |||||
| void UniformFloatWithValueRNG::fill_fast_float32(dt_float32 *dest, size_t size) { | |||||
| RNGxorshf gen{RandomState::generator()}; | RNGxorshf gen{RandomState::generator()}; | ||||
| printf("a %f, b %f \n", m_dist.a(), m_dist.b()); | |||||
| auto k = double(m_dist.b() - m_dist.a()) / | auto k = double(m_dist.b() - m_dist.a()) / | ||||
| double(RNGxorshf::max() - RNGxorshf::min() + 1.0); | double(RNGxorshf::max() - RNGxorshf::min() + 1.0); | ||||
| auto b = m_dist.a() - RNGxorshf::min() * k; | auto b = m_dist.a() - RNGxorshf::min() * k; | ||||
| @@ -312,9 +311,8 @@ void UniformFloatWithZeroRNG::fill_fast_float32(dt_float32 *dest, size_t size) { | |||||
| auto pb = 0.f - RNGxorshf::min() * p; | auto pb = 0.f - RNGxorshf::min() * p; | ||||
| for (size_t i = 0; i < size; ++ i) { | for (size_t i = 0; i < size; ++ i) { | ||||
| float rnd = gen() * p + pb; | float rnd = gen() * p + pb; | ||||
| //printf("%.3f \n", rnd); | |||||
| if(rnd < zero_val_proportion_) { | |||||
| dest[i] = 0.f; | |||||
| if(rnd < val_proportion_) { | |||||
| dest[i] = val_; | |||||
| } else { | } else { | ||||
| dest[i] = gen() * k + b; | dest[i] = gen() * k + b; | ||||
| } | } | ||||
| @@ -11,10 +11,10 @@ | |||||
| #pragma once | #pragma once | ||||
| #include "megdnn/dtype.h" | #include "megdnn/dtype.h" | ||||
| #include "test/common/utils.h" | |||||
| #include "test/common/random_state.h" | |||||
| #include <random> | #include <random> | ||||
| #include <set> | #include <set> | ||||
| #include "test/common/random_state.h" | |||||
| #include "test/common/utils.h" | |||||
| namespace megdnn { | namespace megdnn { | ||||
| namespace test { | namespace test { | ||||
| @@ -80,7 +80,8 @@ public: | |||||
| } | } | ||||
| void gen(const TensorND& tensor) override { | void gen(const TensorND& tensor) override { | ||||
| megdnn_assert(tensor.layout.dtype.enumv() == DTypeTrait<dt_bfloat16>::enumv); | |||||
| megdnn_assert(tensor.layout.dtype.enumv() == | |||||
| DTypeTrait<dt_bfloat16>::enumv); | |||||
| size_t nr_elems = tensor.layout.span().dist_elem(); | size_t nr_elems = tensor.layout.span().dist_elem(); | ||||
| auto offset = tensor.layout.span().low_elem; | auto offset = tensor.layout.span().low_elem; | ||||
| for (size_t i = 0; i < nr_elems; ++i) { | for (size_t i = 0; i < nr_elems; ++i) { | ||||
| @@ -185,24 +186,31 @@ public: | |||||
| void fill_fast_float32(dt_float32* dest, size_t size) override; | void fill_fast_float32(dt_float32* dest, size_t size) override; | ||||
| }; | }; | ||||
| class UniformFloatWithZeroRNG final : public UniformFloatRNG { | |||||
| class UniformFloatWithValueRNG : public UniformFloatRNG { | |||||
| public: | public: | ||||
| UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, | |||||
| float zero_val_proportion) | |||||
| : UniformFloatRNG(a, b) { | |||||
| if (zero_val_proportion < 0.f) | |||||
| zero_val_proportion_ = 0.f; | |||||
| else if (zero_val_proportion > 1.f) | |||||
| zero_val_proportion_ = 1.f; | |||||
| UniformFloatWithValueRNG(dt_float32 a, dt_float32 b, float val_proportion, | |||||
| float val) | |||||
| : UniformFloatRNG(a, b), val_(val) { | |||||
| if (val_proportion < 0.f) | |||||
| val_proportion_ = 0.f; | |||||
| else if (val_proportion > 1.f) | |||||
| val_proportion_ = 1.f; | |||||
| else | else | ||||
| zero_val_proportion_ = zero_val_proportion; | |||||
| val_proportion_ = val_proportion; | |||||
| } | } | ||||
| private: | private: | ||||
| float zero_val_proportion_; | |||||
| float val_proportion_, val_; | |||||
| void fill_fast_float32(dt_float32* dest, size_t size) override; | void fill_fast_float32(dt_float32* dest, size_t size) override; | ||||
| }; | }; | ||||
| class UniformFloatWithZeroRNG final : public UniformFloatWithValueRNG { | |||||
| public: | |||||
| UniformFloatWithZeroRNG(dt_float32 a, dt_float32 b, | |||||
| float zero_val_proportion) | |||||
| : UniformFloatWithValueRNG(a, b, zero_val_proportion, 0.f) {} | |||||
| }; | |||||
| class BernoulliRNG final : public IIDRNG { | class BernoulliRNG final : public IIDRNG { | ||||
| public: | public: | ||||
| BernoulliRNG(dt_float32 probability_); | BernoulliRNG(dt_float32 probability_); | ||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * \file dnn/test/cuda/check_has_inf.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "megdnn/oprs.h" | |||||
| #include "test/common/checker.h" | |||||
| #include "test/cuda/fixture.h" | |||||
| namespace megdnn { | |||||
| namespace test { | |||||
| TEST_F(CUDA, CHECK_HAS_INF_BASIC) { | |||||
| Checker<CheckHasInf> checker(handle_cuda()); | |||||
| checker.set_allow_invalid_check(true); | |||||
| const auto inf = std::numeric_limits<float>::infinity(); | |||||
| UniformFloatWithValueRNG rng(-1.0f, 1.0f, 0.1f, inf); | |||||
| checker.set_rng(0, &rng); | |||||
| checker.execs({{512*16}, {1}}); | |||||
| rng = UniformFloatWithValueRNG(-1.0f, 1.0f, 1.f, inf); | |||||
| checker.set_rng(0, &rng); | |||||
| checker.execs({{512*16}, {1}}); | |||||
| } | |||||
| } // namespace test | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * \file dnn/test/naive/check_has_inf.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "test/naive/fixture.h" | |||||
| #include "megdnn/oprs.h" | |||||
| #include "test/common/checker.h" | |||||
| namespace megdnn { | |||||
| namespace test { | |||||
| TEST_F(NAIVE, CHECK_HAS_INF_BASIC) { | |||||
| Checker<CheckHasInf> checker(handle(), false); | |||||
| checker.exect(Testcase{TensorValue({4}, dtype::Float32(), | |||||
| {1.1, 2.2, 3.3, 4.3}), | |||||
| {}}, | |||||
| Testcase{{}, TensorValue({1}, dtype::Int32(), {0})}); | |||||
| checker.exect( | |||||
| Testcase{TensorValue({4}, dtype::Float32(), | |||||
| {1.1f, 2.2f, 3.3f, | |||||
| std::numeric_limits<float>::infinity()}), | |||||
| {}}, | |||||
| Testcase{{}, TensorValue({1}, dtype::Int32(), {1})}); | |||||
| } | |||||
| } // namespace test | |||||
| } // namespace megdnn | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -959,3 +959,16 @@ def svd(inp: Tensor, full_matrices=False, compute_uv=True) -> Tensor: | |||||
| op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | op = builtin.SVD(full_matrices=full_matrices, compute_uv=compute_uv) | ||||
| U, sigma, V = apply(op, inp) | U, sigma, V = apply(op, inp) | ||||
| return U, sigma, V | return U, sigma, V | ||||
| def _has_inf(inp: Tensor) -> Tensor: | |||||
| """ | |||||
| Check whether the input contains any infinite value. | |||||
| :param inp: a tensor to be checked. | |||||
| :return: an int32 scalar tensor, 0 for False and 1 for True. | |||||
| """ | |||||
| op = builtin.CheckHasInf() | |||||
| (oup,) = apply(op, inp.reshape(-1).astype("float32")) | |||||
| oup._setscalar() | |||||
| return oup | |||||
| @@ -157,3 +157,14 @@ def test_sum_neg_axis(): | |||||
| np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) | np.testing.assert_allclose(get.numpy(), ref, rtol=1e-6) | ||||
| with pytest.raises(AssertionError): | with pytest.raises(AssertionError): | ||||
| F.sum(tensor(data), axis=(-1, 1)) | F.sum(tensor(data), axis=(-1, 1)) | ||||
| def test_has_inf(): | |||||
| shape = (32, 3, 32, 32) | |||||
| data = np.random.random(shape).astype(np.float32) | |||||
| rst = F.math._has_inf(tensor(data)) | |||||
| np.testing.assert_equal(rst.numpy(), [0]) | |||||
| data[0][0][0][0] = float("inf") | |||||
| rst = F.math._has_inf(tensor(data)) | |||||
| np.testing.assert_equal(rst.numpy(), [1]) | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * \file imperative/src/impl/ops/tensor_manip.cpp | |||||
| * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
| * | |||||
| * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | |||||
| * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| */ | |||||
| #include "../op_trait.h" | |||||
| #include "megbrain/imperative/ops/autogen.h" | |||||
| #include "megbrain/opr/misc.h" | |||||
| namespace mgb { | |||||
| namespace imperative { | |||||
| namespace check_has_inf { | |||||
| auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { | |||||
| auto&& op = def.cast_final_safe<CheckHasInf>(); | |||||
| mgb_assert(inputs.size() == 1); | |||||
| OperatorNodeConfig config{op.make_name()}; | |||||
| return opr::CheckHasInf::make(inputs[0], {}, config); | |||||
| } | |||||
| OP_TRAIT_REG(CheckHasInf, CheckHasInf) | |||||
| .apply_on_var_node(apply_on_var_node) | |||||
| .fallback(); | |||||
| } // namespace check_has_inf | |||||
| } // namespace imperative | |||||
| } // namespace mgb | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | |||||
| @@ -307,4 +307,6 @@ def CambriconRuntime: MgbHashableOp<"CambriconRuntime"> { | |||||
| def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | def CvtColor: MgbHashableOp<"CvtColor", [CvtColorParam]>; | ||||
| def CheckHasInf: MgbHashableOp<"CheckHasInf", [EmptyParam]>; | |||||
| #endif // MGB_OPS | #endif // MGB_OPS | ||||
| @@ -437,4 +437,19 @@ MGB_IMPL_OPR_GRAD(TopK) { | |||||
| } | } | ||||
| #endif | #endif | ||||
| /* ================= CheckHasInf ================= */ | |||||
| namespace mgb { | |||||
| namespace opr { | |||||
| namespace intl { | |||||
| template<> | |||||
| struct MegDNNOprInitPostCtor<CheckHasInf> { | |||||
| static void apply(cg::OperatorNodeBase &opr) { | |||||
| opr.output(0)->dtype(dtype::Int32()); | |||||
| } | |||||
| }; | |||||
| } | |||||
| } | |||||
| } | |||||
| MGB_DYN_TYPE_OBJ_FINAL_IMPL(CheckHasInf); | |||||
| MEGDNN_OPR_INIT1(CheckHasInf, "check_has_inf") | |||||
| // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} | ||||
| @@ -73,6 +73,7 @@ namespace opr { | |||||
| #if MGB_CUDA | #if MGB_CUDA | ||||
| MGB_SEREG_OPR(NvOf, 1); | MGB_SEREG_OPR(NvOf, 1); | ||||
| #endif | #endif | ||||
| MGB_SEREG_OPR(CheckHasInf, 1); | |||||
| } // namespace opr | } // namespace opr | ||||
| } // namespace mgb | } // namespace mgb | ||||
| @@ -178,6 +178,8 @@ public: | |||||
| const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
| }; | }; | ||||
| MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(CheckHasInf); | |||||
| } // namespace opr | } // namespace opr | ||||
| } // namespace mgb | } // namespace mgb | ||||