| @@ -102,6 +102,29 @@ void ArithmeticCPUKernel::Div(const T *input1, const T *input2, T *out, size_t s | |||||
| } | } | ||||
| } | } | ||||
| template <typename T> | |||||
| void ArithmeticCPUKernel::FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end) { | |||||
| for (size_t i = start; i < end; i++) { | |||||
| std::vector<size_t> idx; | |||||
| GenIndex(i, &idx); | |||||
| auto dividend = input1[idx[0]]; | |||||
| auto divisor = input2[idx[1]]; | |||||
| if (divisor == 0) { | |||||
| if (dividend == 0) { | |||||
| out[i] = std::numeric_limits<T>::quiet_NaN(); | |||||
| continue; | |||||
| } | |||||
| if (std::numeric_limits<T>::has_infinity) { | |||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::infinity() : -std::numeric_limits<T>::infinity(); | |||||
| } else { | |||||
| out[i] = dividend > 0 ? std::numeric_limits<T>::max() : std::numeric_limits<T>::min(); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| out[i] = floor(dividend / divisor); | |||||
| } | |||||
| } | |||||
| template <typename T> | template <typename T> | ||||
| void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t start, size_t end) { | void ArithmeticCPUKernel::Mod(const T *input1, const T *input2, T *out, size_t start, size_t end) { | ||||
| for (size_t i = start; i < end; i++) { | for (size_t i = start; i < end; i++) { | ||||
| @@ -207,6 +230,8 @@ void ArithmeticCPUKernel::InitKernel(const CNodePtr &kernel_node) { | |||||
| operate_type_ = REALDIV; | operate_type_ = REALDIV; | ||||
| } else if (kernel_name == prim::kPrimDiv->name()) { | } else if (kernel_name == prim::kPrimDiv->name()) { | ||||
| operate_type_ = DIV; | operate_type_ = DIV; | ||||
| } else if (kernel_name == prim::kPrimFloorDiv->name()) { | |||||
| operate_type_ = FLOORDIV; | |||||
| } else if (kernel_name == prim::kPrimMod->name()) { | } else if (kernel_name == prim::kPrimMod->name()) { | ||||
| operate_type_ = MOD; | operate_type_ = MOD; | ||||
| } else if (kernel_name == prim::kPrimPow->name()) { | } else if (kernel_name == prim::kPrimPow->name()) { | ||||
| @@ -389,6 +414,8 @@ void ArithmeticCPUKernel::LaunchKernel(const std::vector<AddressPtr> &inputs, co | |||||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::RealDiv<T>, this, input1, input2, output, start, end)); | threads.emplace_back(std::thread(&ArithmeticCPUKernel::RealDiv<T>, this, input1, input2, output, start, end)); | ||||
| } else if (operate_type_ == DIV) { | } else if (operate_type_ == DIV) { | ||||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end)); | threads.emplace_back(std::thread(&ArithmeticCPUKernel::Div<T>, this, input1, input2, output, start, end)); | ||||
| } else if (operate_type_ == FLOORDIV) { | |||||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::FloorDiv<T>, this, input1, input2, output, start, end)); | |||||
| } else if (operate_type_ == MOD) { | } else if (operate_type_ == MOD) { | ||||
| threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mod<T>, this, input1, input2, output, start, end)); | threads.emplace_back(std::thread(&ArithmeticCPUKernel::Mod<T>, this, input1, input2, output, start, end)); | ||||
| } else if (operate_type_ == POW) { | } else if (operate_type_ == POW) { | ||||
| @@ -50,6 +50,8 @@ class ArithmeticCPUKernel : public CPUKernel { | |||||
| template <typename T> | template <typename T> | ||||
| void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); | void Div(const T *input1, const T *input2, T *out, size_t start, size_t end); | ||||
| template <typename T> | template <typename T> | ||||
| void FloorDiv(const T *input1, const T *input2, T *out, size_t start, size_t end); | |||||
| template <typename T> | |||||
| void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end); | void Mod(const T *input1, const T *input2, T *out, size_t start, size_t end); | ||||
| template <typename T> | template <typename T> | ||||
| void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end); | void Pow(const T *input1, const T *input2, T *out, size_t start, size_t end); | ||||
| @@ -117,6 +119,16 @@ MS_REG_CPU_KERNEL( | |||||
| MS_REG_CPU_KERNEL( | MS_REG_CPU_KERNEL( | ||||
| Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | ||||
| ArithmeticCPUKernel); | ArithmeticCPUKernel); | ||||
| MS_REG_CPU_KERNEL( | |||||
| FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), | |||||
| ArithmeticCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | |||||
| ArithmeticCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | |||||
| FloorDiv, | |||||
| KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), | |||||
| ArithmeticCPUKernel); | |||||
| MS_REG_CPU_KERNEL( | MS_REG_CPU_KERNEL( | ||||
| Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | Mod, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), | ||||
| ArithmeticCPUKernel); | ArithmeticCPUKernel); | ||||
| @@ -67,6 +67,7 @@ enum OperateType { | |||||
| SQRT, | SQRT, | ||||
| POW, | POW, | ||||
| REALDIV, | REALDIV, | ||||
| FLOORDIV, | |||||
| MOD, | MOD, | ||||
| NEG, | NEG, | ||||
| LESS, | LESS, | ||||
| @@ -261,6 +261,7 @@ inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("Inplace | |||||
| inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub"); | ||||
| inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow"); | ||||
| inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv"); | ||||
| inline const PrimitivePtr kPrimFloorDiv = std::make_shared<Primitive>("FloorDiv"); | |||||
| inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt"); | ||||
| inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); | inline const PrimitivePtr kPrimSqrtGrad = std::make_shared<Primitive>("SqrtGrad"); | ||||
| inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal"); | ||||
| @@ -42,6 +42,15 @@ class DivNet(nn.Cell): | |||||
| return self.div(x, y) | return self.div(x, y) | ||||
class FloorDivNet(nn.Cell):
    """Minimal network wrapping the FloorDiv primitive for CPU kernel tests."""

    def __init__(self):
        super().__init__()
        self.floor_div = P.FloorDiv()

    def construct(self, x, y):
        """Return element-wise floor(x / y)."""
        return self.floor_div(x, y)
| class ModNet(nn.Cell): | class ModNet(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(ModNet, self).__init__() | super(ModNet, self).__init__() | ||||
| @@ -156,6 +165,71 @@ def test_div(): | |||||
| assert output7.shape == expect7.shape | assert output7.shape == expect7.shape | ||||
@pytest.mark.level0
@pytest.mark.platform_x86_cpu_training
@pytest.mark.env_onecard
def test_floor_div():
    """FloorDiv CPU kernel matches np.floor_divide for float32, float16,
    int32 (broadcast and mixed-type) and int64 inputs.

    Fixes over the previous version:
    - the int32/int32 expectation was spuriously cast to float16; the
      expected dtype of int32 // int32 is int32.
    - comparisons used one-sided ``diff < error``, which passes for
      arbitrarily large negative deviations; use ``abs(diff) < error``.
    """
    # Same random sign applied to both operands, so all quotients are positive.
    # NOTE(review): consider independent signs per operand to also exercise
    # negative (round-toward-negative-infinity) quotients.
    prop = 1 if np.random.random() < 0.5 else -1
    cases = [
        (np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop,
         np.random.randint(1, 100, (2, 1, 4, 4)).astype(np.float32) * prop),
        (np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.float16) * prop,
         np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float16) * prop),
        (np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int32) * prop,
         np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop),
        # mixed int32 / float32 relies on implicit type promotion
        (np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int32) * prop,
         np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.float32) * prop),
        (np.random.randint(1, 100, (2, 1, 1, 4)).astype(np.int64) * prop,
         np.random.randint(1, 100, (2, 3, 4, 4)).astype(np.int64) * prop),
    ]
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    floor_div = FloorDivNet()
    for x_np, y_np in cases:
        output = floor_div(Tensor(x_np), Tensor(y_np))
        expect = np.floor_divide(x_np, y_np)
        assert output.shape == expect.shape
        # absolute error so that deviations in either direction fail
        assert np.all(np.abs(output.asnumpy() - expect) < 1.0e-5)
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_x86_cpu_training | @pytest.mark.platform_x86_cpu_training | ||||
| @pytest.mark.env_onecard | @pytest.mark.env_onecard | ||||
| @@ -249,6 +323,8 @@ def test_mod(): | |||||
| assert np.all(output7.asnumpy() == expect7) | assert np.all(output7.asnumpy() == expect7) | ||||
| assert output6.shape == expect6.shape | assert output6.shape == expect6.shape | ||||
| test_sub() | test_sub() | ||||
| test_div() | test_div() | ||||
| test_floor_div() | |||||
| test_mod() | test_mod() | ||||