/**
 * \file imperative/src/test/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "./helper.h"
#include "megbrain/imperative/ops/rng.h"

using namespace mgb;
using namespace imperative;
using namespace imperative::rng;

namespace {

/**
 * Smoke-test an RNG op on every (shape, comp node) combination below.
 *
 * For each combination: create an RNG handle on the comp node, build the op
 * via `Op::make(args..., handle)`, apply it to a tensor holding the target
 * shape, and check that the first output has that shape and lives on the
 * same comp node.
 *
 * \tparam Op   op type exposing `static make(Args..., Handle)`.
 * \param args  leading arguments forwarded to `Op::make` on every iteration.
 *              They are passed as lvalues (not std::forward-ed) because they
 *              are reused across iterations; forwarding an rvalue more than
 *              once would read a moved-from object.
 */
template <typename Op, typename... Args>
void check_rng_basic(Args&&... args) {
    for (auto&& tshape : {TensorShape{2, 3, 4, 5}, {3, 4, 5, 6}, {2333}})
        for (auto&& cn : {CompNode::load("xpu0"), CompNode::load("xpu1")}) {
            // NOTE(review): 123 is presumably the handle's seed — confirm
            // against new_handle's declaration.
            Handle h = new_handle(cn, 123);
            auto op = Op::make(args..., h);

            // The op's input is the requested output shape, stored as a tensor.
            DeviceTensorND tshape_dev;
            cg::copy_shape_to_tensor_value(tshape_dev, tshape);
            SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
            auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);

            // Sync before deleting the handle.
            for (auto&& p : outputs) {
                p->get_value();
            }

            // Evaluate the checks before releasing the handle. A failing fatal
            // ASSERT returns from this function, so asserting first would skip
            // delete_handle and leak the handle.
            const bool has_output = !outputs.empty();
            const bool shape_ok =
                    has_output && outputs[0]->layout().eq_shape(tshape);
            const bool cn_ok = has_output && cn == outputs[0]->comp_node();
            delete_handle(h);

            ASSERT_TRUE(has_output);
            ASSERT_TRUE(shape_ok);
            ASSERT_TRUE(cn_ok);
        }
}

}  // namespace

// NOTE(review): REQUIRE_XPU(2) presumably skips the test when fewer than two
// devices are available — confirm in helper.h.
TEST(TestImperative, UniformRNGBasic) {
    REQUIRE_XPU(2);
    check_rng_basic<UniformRNG>(123);
}

// NOTE(review): (123, 2.f, 3.f) presumably map to (seed, mean, std) of
// GaussianRNG::make — confirm against ops/rng.h.
TEST(TestImperative, GaussianRNGBasic) {
    REQUIRE_XPU(2);
    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}