Browse Source

!25078 [assistant][ops] Add Cross

Merge pull request !25078 from 冯钰/cross
feature/build-system-rewrite
i-robot Gitee 4 years ago
parent
commit
bff5afe2bb
No known key found for this signature in database GPG Key ID: 173E9B9CA92EEF8F
11 changed files with 544 additions and 2 deletions
  1. +174
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cross_cpu_kernel.cc
  2. +110
    -0
      mindspore/ccsrc/backend/kernel_compiler/cpu/cross_cpu_kernel.h
  3. +4
    -0
      mindspore/core/base/core_ops.h
  4. +102
    -0
      mindspore/core/ops/cross.cc
  5. +44
    -0
      mindspore/core/ops/cross.h
  6. +11
    -0
      mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py
  7. +1
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py
  8. +42
    -0
      mindspore/python/mindspore/ops/_op_impl/aicpu/cross.py
  9. +3
    -2
      mindspore/python/mindspore/ops/operations/__init__.py
  10. +48
    -0
      mindspore/python/mindspore/ops/operations/math_ops.py
  11. +5
    -0
      tests/ut/python/ops/test_ops.py

+ 174
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cross_cpu_kernel.cc View File

@@ -0,0 +1,174 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/cross_cpu_kernel.h"
#include "runtime/device/cpu/cpu_device_address.h"

namespace {
// Workloads whose vector count times element size is below this byte threshold
// run single-threaded in Launch; thread dispatch overhead would dominate.
const size_t kDataSizeThreshold = 4 * 1024;

// Expands to one switch case that instantiates LaunchKernel<TYPE> for the
// matching device dtype DTYPE and stores the result in the local `ret`.
#define CROSS_COMPUTE_CASE(DTYPE, TYPE)      \
  case (DTYPE): {                            \
    ret = LaunchKernel<TYPE>(inputs, outputs); \
    break;                                   \
  }
}  // namespace

namespace mindspore {
namespace kernel {
void CrossCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
input1_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
input2_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
input1_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
dim_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "dim");
int64_t default_dim = -65530;
if (dim_ == default_dim) {
size_t dim_size_value = 3;
for (size_t i = 0; i < input1_shape_.size(); i++) {
if (input1_shape_[i] == dim_size_value) {
dim_ = i;
break;
}
if (i == input1_shape_.size() - 1 && input1_shape_[i] != dim_size_value) {
MS_EXCEPTION(ValueError) << "The size of inputs dim should be 3,but got" << input1_shape_[i];
}
}
}
if (dim_ < 0) {
dim_ = static_cast<int64_t>(input1_shape_.size()) + dim_;
}
}

// Dispatches to the typed implementation based on the input dtype cached by
// InitKernel(). Throws TypeError for unsupported dtypes.
bool CrossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &,
                               const std::vector<kernel::AddressPtr> &outputs) {
  bool ret = true;
  switch (input1_dtype_) {
    CROSS_COMPUTE_CASE(kNumberTypeInt8, int8_t)
    CROSS_COMPUTE_CASE(kNumberTypeInt16, int16_t)
    CROSS_COMPUTE_CASE(kNumberTypeInt32, int32_t)
    CROSS_COMPUTE_CASE(kNumberTypeInt64, int64_t)
    CROSS_COMPUTE_CASE(kNumberTypeUInt8, uint8_t)
    CROSS_COMPUTE_CASE(kNumberTypeUInt16, uint16_t)
    CROSS_COMPUTE_CASE(kNumberTypeUInt32, uint32_t)
    CROSS_COMPUTE_CASE(kNumberTypeUInt64, uint64_t)
    CROSS_COMPUTE_CASE(kNumberTypeFloat16, float16)
    CROSS_COMPUTE_CASE(kNumberTypeFloat32, float)
    CROSS_COMPUTE_CASE(kNumberTypeFloat64, double)
    CROSS_COMPUTE_CASE(kNumberTypeComplex64, std::complex<float>)
    CROSS_COMPUTE_CASE(kNumberTypeComplex128, std::complex<double>)
    default:
      // MS_EXCEPTION throws, so nothing after it runs; the dead `ret = false;`
      // that previously followed this line has been removed.
      MS_EXCEPTION(TypeError) << "Unsupported input data type: " << input1_dtype_;
  }
  return ret;
}

// Computes y = x1 x x2 element-wise over every 3-vector lying along axis dim_.
// The tensors are traversed in row-major (C) order; each of the `total`
// 3-vectors is addressed through precomputed strides so no data is copied.
// Assumes the dim_ axis has extent 3 in all three tensors (enforced by the
// frontend infer shape — confirm if callers change).
template <typename T>
bool CrossCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  auto input1_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  // Total element count of input1 = product of all extents.
  size_t tmp = 1;
  for (size_t i = 0; i < input1_shape_.size(); i++) {
    tmp = tmp * input1_shape_[i];
  }
  size_t input1_data_num = tmp;
  auto input2_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  // Number of independent 3-vectors to process (dim_ axis has extent 3).
  size_t total = input1_data_num / 3;
  const size_t n = input1_shape_.size();
  // Row-major strides of input1: a_stride[i] = product of extents after axis i.
  std::vector<size_t> a_stride(n);
  size_t stride_tmp = 1;
  for (int64_t i = static_cast<int64_t>(n - 1); i >= 0; i--) {
    a_stride[LongToSize(i)] = stride_tmp;
    stride_tmp *= input1_shape_[i];
  }
  // Distance (in elements) between consecutive components of one 3-vector in input1.
  size_t input1_data_stride = a_stride[dim_];
  // Row-major strides of input2.
  std::vector<size_t> b_stride(n);
  stride_tmp = 1;
  for (int64_t i = static_cast<int64_t>(n - 1); i >= 0; i--) {
    b_stride[LongToSize(i)] = stride_tmp;
    stride_tmp *= input2_shape_[i];
  }
  size_t input2_data_stride = b_stride[dim_];
  // Row-major strides of the output.
  std::vector<size_t> r_stride(n);
  stride_tmp = 1;
  for (int64_t i = static_cast<int64_t>(n - 1); i >= 0; i--) {
    r_stride[LongToSize(i)] = stride_tmp;
    stride_tmp *= output_shape_[i];
  }
  size_t output_data_stride = r_stride[dim_];
  // Index of the third vector component (avoids a magic number below).
  const int64_t pos = 2;
  // Processes 3-vectors with flat indices in [start, end). Used both serially
  // and as the per-shard body for ParallelFor.
  auto cross_shard = [this, &a_stride, &b_stride, &r_stride, &output_data_addr, &input1_data_addr, &input2_data_addr,
                      &output_data_stride, &input1_data_stride, &input2_data_stride](size_t start, size_t end) {
    const size_t input1_data_dim = input1_shape_.size();
    std::vector<size_t> position_in_dims(input1_data_dim);
    // Decode the flat vector index `start` into per-axis coordinates over all
    // axes except dim_ (axis 0 varies fastest), accumulating the flat element
    // offset of the first component for each tensor as we go.
    size_t index_in_curr_dim = start;
    size_t input1_data_start = 0;
    size_t input2_data_start = 0;
    size_t output_data_start = 0;
    for (int64_t i = 0; i < static_cast<int64_t>(input1_data_dim); i++) {
      if (i == static_cast<int64_t>(dim_)) continue;
      position_in_dims[i] = index_in_curr_dim % input1_shape_[i];
      input1_data_start += (index_in_curr_dim % input1_shape_[i]) * a_stride[i];
      input2_data_start += (index_in_curr_dim % input2_shape_[i]) * b_stride[i];
      output_data_start += (index_in_curr_dim % output_shape_[i]) * r_stride[i];
      index_in_curr_dim = index_in_curr_dim / input1_shape_[i];
    }
    while (start < end) {
      // Standard cross-product formula: y0 = a1*b2 - a2*b1.
      output_data_addr[output_data_start + 0 * output_data_stride] =
      input1_data_addr[input1_data_start + 1 * input1_data_stride] *
      input2_data_addr[input2_data_start + pos * input2_data_stride] -
      input1_data_addr[input1_data_start + pos * input1_data_stride] *
      input2_data_addr[input2_data_start + 1 * input2_data_stride];
      // y1 = a2*b0 - a0*b2.
      output_data_addr[output_data_start + 1 * output_data_stride] =
      input1_data_addr[input1_data_start + pos * input1_data_stride] *
      input2_data_addr[input2_data_start + 0 * input2_data_stride] -
      input1_data_addr[input1_data_start + 0 * input1_data_stride] *
      input2_data_addr[input2_data_start + pos * input2_data_stride];
      // y2 = a0*b1 - a1*b0.
      output_data_addr[output_data_start + pos * output_data_stride] =
      input1_data_addr[input1_data_start + 0 * input1_data_stride] *
      input2_data_addr[input2_data_start + 1 * input2_data_stride] -
      input1_data_addr[input1_data_start + 1 * input1_data_stride] *
      input2_data_addr[input2_data_start + 0 * input2_data_stride];
      start++;
      // Odometer-style advance over the non-dim_ axes (axis 0 fastest): bump an
      // axis, and on wrap-around reset it and carry into the next axis.
      for (int64_t i = 0; i < static_cast<int64_t>(input1_data_dim); i++) {
        if (i == static_cast<int64_t>(dim_)) {
          continue;
        }
        position_in_dims[i]++;
        input1_data_start += a_stride[i];
        input2_data_start += b_stride[i];
        output_data_start += r_stride[i];
        if (position_in_dims[i] == input1_shape_[i] && i != static_cast<int64_t>(input1_shape_.size()) - 1) {
          input1_data_start -= position_in_dims[i] * a_stride[i];
          input2_data_start -= position_in_dims[i] * b_stride[i];
          output_data_start -= position_in_dims[i] * r_stride[i];
          position_in_dims[i] = 0;
        } else {
          break;
        }
      }
    }
  };
  // Small workloads run inline; larger ones are sharded across worker threads.
  if (total * sizeof(T) < kDataSizeThreshold) {
    cross_shard(0, total);
  } else {
    CPUKernelUtils::ParallelFor(cross_shard, total);
  }
  return true;
}
} // namespace kernel
} // namespace mindspore

+ 110
- 0
mindspore/ccsrc/backend/kernel_compiler/cpu/cross_cpu_kernel.h View File

@@ -0,0 +1,110 @@
/**
* Copyright 2021-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CROSS_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CROSS_CPU_KERNEL_H_

#include <vector>
#include <memory>
#include <string>
#include <complex>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// CPU kernel implementing the Cross operator: the vector cross product of the
// 3-element slices of x1 and x2 along attribute `dim`.
class CrossCpuKernelMod : public NativeCpuKernelMod {
 public:
  CrossCpuKernelMod() = default;
  ~CrossCpuKernelMod() override = default;
  // Caches input/output shapes, input dtype, and resolves the `dim` attribute.
  void InitKernel(const CNodePtr &kernel_node) override;
  // Dispatches on the cached dtype to the typed LaunchKernel<T>.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;
  // Typed implementation; computes every 3-vector cross product along dim_.
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);

 private:
  std::vector<size_t> input1_shape_;  // device shape of x1
  std::vector<size_t> input2_shape_;  // device shape of x2 (same as x1's)
  std::vector<size_t> output_shape_;  // device shape of y
  int64_t dim_;                       // resolved non-negative cross-product axis
  TypeId input1_dtype_;               // dtype used for Launch dispatch
};

// Kernel registrations: Cross accepts two same-typed inputs and produces one
// same-typed output for every supported integer, unsigned, floating-point and
// complex dtype.
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(
  Cross,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(Cross,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeComplex64)
                    .AddInputAttr(kNumberTypeComplex64)
                    .AddOutputAttr(kNumberTypeComplex64),
                  CrossCpuKernelMod);

MS_REG_CPU_KERNEL(Cross,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeComplex128)
                    .AddInputAttr(kNumberTypeComplex128)
                    .AddOutputAttr(kNumberTypeComplex128),
                  CrossCpuKernelMod);
} // namespace kernel
} // namespace mindspore

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CROSS_CPU_KERNEL_H_

+ 4
- 0
mindspore/core/base/core_ops.h View File

@@ -83,6 +83,9 @@ constexpr auto kImag = "Imag";
constexpr auto kConj = "Conj";
constexpr auto kGer = "Ger";

// Math
constexpr auto kCross = "Cross";

// Arrays
constexpr auto kDynamicShape = "DynamicShape";
constexpr auto kStack = "Stack";
@@ -596,6 +599,7 @@ MS_CORE_API inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Pr
MS_CORE_API inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem");

// Maths
MS_CORE_API inline const PrimitivePtr kPrimCross = std::make_shared<Primitive>(kCross);
MS_CORE_API inline const PrimitivePtr kPrimBesselI0 = std::make_shared<Primitive>("BesselI0");
MS_CORE_API inline const PrimitivePtr kPrimBesselI1 = std::make_shared<Primitive>("BesselI1");
MS_CORE_API inline const PrimitivePtr kPrimGer = std::make_shared<Primitive>("Ger");


+ 102
- 0
mindspore/core/ops/cross.cc View File

@@ -0,0 +1,102 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ops/cross.h"
#include <string>
#include <algorithm>
#include <memory>
#include <set>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
// Infers the output shape of Cross: identical to the (validated) input shape.
// Validates that x1 and x2 have the same shape, that the inputs are not rank-0,
// that `dim` (or its auto-detected default) is in range, and that the extent of
// the chosen axis is 3.
abstract::ShapePtr CrossInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto x2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  auto dim = GetValue<int64_t>(primitive->GetAttr("dim"));
  // Reject rank-0 inputs up front so every later index into the shapes is safe.
  if (x1_shape.empty() || x2_shape.empty()) {
    MS_EXCEPTION(ValueError) << "Inputs should not be a " << x1_shape.size() << " dimensional tensor.";
  }
  if (x1_shape.size() != x2_shape.size()) {
    MS_EXCEPTION(ValueError) << "The shape of two inputs must have the same size.";
  }
  for (size_t i = 0; i < x1_shape.size(); ++i) {
    if (x1_shape[i] != x2_shape[i]) {
      MS_EXCEPTION(ValueError) << "x1 and x2 must have the same shape.";
    }
  }
  constexpr int64_t kDefaultDim = -65530;  // frontend sentinel: "dim not given"
  constexpr int64_t kCrossDimSize = 3;     // a cross product needs 3-vectors
  if (dim == kDefaultDim) {
    // Default behavior: use the first dimension whose extent is 3.
    bool found = false;
    for (size_t i = 0; i < x1_shape.size(); i++) {
      if (x1_shape[i] == kCrossDimSize) {
        dim = static_cast<int64_t>(i);
        found = true;
        break;
      }
    }
    if (!found) {
      MS_EXCEPTION(ValueError) << "The size of inputs dim should be 3, but got " << x1_shape[x1_shape.size() - 1];
    }
  } else {
    // User-specified dim must lie in [-rank, rank - 1]; normalize negatives.
    if (dim < -static_cast<int64_t>(x1_shape.size()) || dim > static_cast<int64_t>(x1_shape.size()) - 1) {
      MS_EXCEPTION(ValueError) << "dim should be between " << -static_cast<int64_t>(x1_shape.size()) << " and "
                               << static_cast<int64_t>(x1_shape.size()) - 1 << " ,but got " << dim;
    }
    if (dim < 0) {
      dim = static_cast<int64_t>(x1_shape.size()) + dim;
    }
  }
  // Shapes are already known equal, so checking x1 alone suffices (the original
  // `x1 != 3 && x2 != 3` conjunction was equivalent but misleading).
  if (x1_shape[dim] != kCrossDimSize) {
    MS_EXCEPTION(ValueError) << "The size of inputs dim should be 3, but got " << x1_shape[dim];
  }
  return std::make_shared<abstract::Shape>(x1_shape);
}

// Infers the output type of Cross: both inputs must be tensors of one of the
// supported dtypes, x1's element type must match x2's, and the result type is
// x1's tensor type.
TypePtr CrossInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInteger("Cross infer", input_args.size(), kEqual, input_num, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const std::set<TypePtr> valid_types = {kInt8,    kInt16,  kInt32,  kInt64,  kUInt8,     kFloat16,   kFloat32,
                                         kFloat64, kUInt16, kUInt32, kUInt64, kComplex64, kComplex128};
  auto x1_type = input_args[0]->BuildType();
  auto x2_type = input_args[1]->BuildType();
  // Validate x2 FIRST: this throws a proper TypeError if x2 is not a tensor of
  // a valid dtype. The original code dereferenced the TensorType cast before
  // this check, crashing on non-tensor x2 instead of reporting the error.
  (void)CheckAndConvertUtils::CheckTensorTypeValid("x2", x2_type, valid_types, op_name);
  auto tensor_type = x2_type->cast<TensorTypePtr>();
  MS_EXCEPTION_IF_NULL(tensor_type);
  auto element = tensor_type->element();
  MS_EXCEPTION_IF_NULL(element);
  // x1 must carry exactly x2's element type.
  return CheckAndConvertUtils::CheckTensorTypeValid("x1", x1_type, {element}, op_name);
}
}  // namespace

// Combined infer entry point registered for the Cross primitive.
// NOTE: this definition must live OUTSIDE the anonymous namespace above — it is
// declared in ops/cross.h with external linkage; defining it with internal
// linkage (as the original placement did) breaks that declaration.
AbstractBasePtr CrossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  // Type checking runs first so a non-tensor input fails with a TypeError
  // before shape extraction is attempted.
  auto infer_type = CrossInferType(primitive, input_args);
  auto infer_shape = CrossInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Cross, prim::kPrimCross, CrossInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore

+ 44
- 0
mindspore/core/ops/cross.h View File

@@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef MINDSPORE_CORE_OPS_CROSS_H_
#define MINDSPORE_CORE_OPS_CROSS_H_
#include <map>
#include <vector>
#include <string>
#include <memory>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameCross = "Cross";

// Frontend primitive for the Cross operator (cross product of the 3-vectors
// along a chosen dimension of x1 and x2).
class Cross : public PrimitiveC {
 public:
  // Registers the primitive name and its IO names: inputs ("x1", "x2") -> output ("y").
  Cross() : PrimitiveC(kNameCross) { InitIOName({"x1", "x2"}, {"y"}); }
  ~Cross() = default;
  MS_DECLARE_PARENT(Cross, PrimitiveC);
};

// Shape/type inference entry point, implemented in ops/cross.cc.
AbstractBasePtr CrossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args);
using PrimCrossPtr = std::shared_ptr<Cross>;
} // namespace ops
} // namespace mindspore

#endif // MINDSPORE_CORE_OPS_CROSS_H_

+ 11
- 0
mindspore/python/mindspore/ops/_grad_experimental/grad_math_ops.py View File

@@ -312,3 +312,14 @@ def get_bprop_ger(self):
return dx, dy

return bprop


@bprop_getters.register(P.Cross)
def get_bprop_cross(self):
    """Grad definition for 'Cross' operation."""
    # Reuse the forward primitive with the same `dim` attribute: for
    # y = x1 x x2 the gradients are dx1 = x2 x dy and dx2 = dy x x1.
    cross_op = P.Cross(dim=self.dim)

    def bprop(input1, input2, out, dout):
        grad_input1 = cross_op(input2, dout)
        grad_input2 = cross_op(dout, input1)
        return grad_input1, grad_input2

    return bprop

+ 1
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/__init__.py View File

@@ -99,3 +99,4 @@ from .square import _square_aicpu
from .lower_bound import _lower_bound_aicpu
from .grid_sampler_3d import _grid_sampler_3d_aicpu
from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu
from .cross import _cross_aicpu

+ 42
- 0
mindspore/python/mindspore/ops/_op_impl/aicpu/cross.py View File

@@ -0,0 +1,42 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Cross op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

# AICPU registration for Cross: two required tensor inputs (x1, x2), one output
# (y), and an int attribute `dim`. Each dtype_format row declares one supported
# (x1, x2, y) dtype combination; all three always share the same dtype.
cross_op_info = AiCPURegOp("Cross") \
    .fusion_type("OPAQUE") \
    .attr("dim", "int") \
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default) \
    .get_op_info()


@op_info_register(cross_op_info)
def _cross_aicpu():
    """Cross aicpu register"""
    return

+ 3
- 2
mindspore/python/mindspore/ops/operations/__init__.py View File

@@ -53,8 +53,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
BitwiseAnd, BitwiseOr, Ger,
BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub,
ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny,
Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil,
Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod,
Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod,
LogicalNot, LogicalOr, LpNorm, MatMul, Maximum, MulNoNan,
MatrixDeterminant, LogMatrixDeterminant, Minimum, Mul, Neg, NMSWithMask, NotEqual,
NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace, Einsum,
@@ -518,6 +518,7 @@ __all__ = [
]

__sponge__ = [
"Cross",
"BondForce",
"BondEnergy",
"BondAtomEnergy",


+ 48
- 0
mindspore/python/mindspore/ops/operations/math_ops.py View File

@@ -5978,3 +5978,51 @@ class CholeskyInverse(Primitive):
"""Initialize CholeskyInverse"""
validator.check_value_type("upper", upper, [bool], self.name)
self.upper = upper


class Cross(Primitive):
    """
    Returns the cross product of vectors in dimension `dim` of input `x1` and `x2`.

    `x1` and `x2` must have the same shape and the same dtype, and the size of
    their `dim` dimension must be 3. If `dim` is not given (i.e. left at the
    sentinel default), the first dimension found with size 3 is used.

    Args:
        dim (int): The dimension along which the cross product is computed.
            The default value -65530 is an internal sentinel meaning
            "not specified"; in that case the first dimension of size 3 is used.

    Inputs:
        - **x1** (Tensor) - A tensor whose `dim` dimension has size 3.
        - **x2** (Tensor) - A tensor with the same shape and dtype as `x1`.

    Outputs:
        Tensor, has the same shape and type as the inputs.

    Raises:
        TypeError: If `x1` is not a Tensor.
        TypeError: If `x2` is not a Tensor.
        TypeError: If the type of `x1` is not the same as that of `x2`.
        ValueError: If `x1` and `x2` do not have the same shape.
        ValueError: If the size of the `dim` dimension of the inputs is not 3.
        ValueError: If `dim` is out of range; `dim` should be in
            [-len(x1.shape), len(x1.shape) - 1].

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.common import dtype as mstype
        >>> import mindspore.ops as ops
        >>> cross = ops.Cross(dim = 0)
        >>> x1 = Tensor([1, 2, 3], mstype.int8)
        >>> x2 = Tensor([1, 2, 3], mstype.int8)
        >>> output = cross(x1, x2)
        >>> print(output)
        [0, 0, 0]
    """

    @prim_attr_register
    def __init__(self, dim=-65530):
        """Initialize Cross."""
        # `dim` must be a Python int; range checking happens during shape infer.
        validator.check_value_type('dim', dim, [int], self.name)
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])

+ 5
- 0
tests/ut/python/ops/test_ops.py View File

@@ -1084,6 +1084,11 @@ class ApplyKerasMomentumNet(nn.Cell):


test_case_math_ops = [
('Cross', {
'block': P.Cross(dim=1),
'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8),
Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)],
'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)]}),
('Ger', {
'block': P.Ger(),
'desc_inputs': [[3,], [4,]],


Loading…
Cancel
Save