| @@ -39,6 +39,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, cudaStream_t stream) { | |||||
| #define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t); | #define INST(T) template void exec_internal<T>(T*, size_t, size_t, int, cudaStream_t); | ||||
| #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| } // namespace eye | } // namespace eye | ||||
| } // namespace cuda | } // namespace cuda | ||||
| @@ -26,6 +26,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
| cuda_stream(handle())); \ | cuda_stream(handle())); \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| @@ -31,6 +31,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
| MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \ | MEGDNN_DISPATCH_CPU_KERN_OPR(exec_internal<ctype>(dst.ptr<ctype>(), m, n)); \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| @@ -44,7 +44,7 @@ void exec_internal(T* dst, size_t m, size_t n, int k, hipStream_t stream) { | |||||
| template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t); | template void exec_internal<T>(T*, size_t, size_t, int, hipStream_t); | ||||
| #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | #define cb(DType) INST(typename DTypeTrait<DType>::ctype) | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| } // namespace eye | } // namespace eye | ||||
| } // namespace rocm | } // namespace rocm | ||||
| } // namespace megdnn | } // namespace megdnn | ||||
| @@ -27,6 +27,7 @@ void EyeImpl::exec(_megdnn_tensor_out dst, _megdnn_workspace workspace) { | |||||
| hip_stream(handle())); \ | hip_stream(handle())); \ | ||||
| } | } | ||||
| MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | MEGDNN_FOREACH_COMPUTING_DTYPE(cb) | ||||
| cb(::megdnn::dtype::Bool) | |||||
| #undef cb | #undef cb | ||||
| } | } | ||||
| @@ -24,21 +24,22 @@ from megengine.utils.network_node import VarNode | |||||
| def test_eye(): | def test_eye(): | ||||
| dtype = np.float32 | |||||
| dtypes = [np.float32, np.bool] | |||||
| cases = [{"input": [10, 20]}, {"input": [30]}] | cases = [{"input": [10, 20]}, {"input": [30]}] | ||||
| for case in cases: | |||||
| np.testing.assert_allclose( | |||||
| F.eye(case["input"], dtype=dtype).numpy(), | |||||
| np.eye(*case["input"]).astype(dtype), | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| F.eye(*case["input"], dtype=dtype).numpy(), | |||||
| np.eye(*case["input"]).astype(dtype), | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| F.eye(tensor(case["input"]), dtype=dtype).numpy(), | |||||
| np.eye(*case["input"]).astype(dtype), | |||||
| ) | |||||
| for dtype in dtypes: | |||||
| for case in cases: | |||||
| np.testing.assert_allclose( | |||||
| F.eye(case["input"], dtype=dtype).numpy(), | |||||
| np.eye(*case["input"]).astype(dtype), | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| F.eye(*case["input"], dtype=dtype).numpy(), | |||||
| np.eye(*case["input"]).astype(dtype), | |||||
| ) | |||||
| np.testing.assert_allclose( | |||||
| F.eye(tensor(case["input"]), dtype=dtype).numpy(), | |||||
| np.eye(*case["input"]).astype(dtype), | |||||
| ) | |||||
| def test_full(): | def test_full(): | ||||