Merge pull request !25078 from 冯钰/crossfeature/build-system-rewrite
| @@ -0,0 +1,174 @@ | |||
| /** | |||
| * Copyright 2021-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "backend/kernel_compiler/cpu/cross_cpu_kernel.h" | |||
| #include "runtime/device/cpu/cpu_device_address.h" | |||
namespace {
// Outputs smaller than this many bytes are processed on the calling thread;
// larger workloads are split via CPUKernelUtils::ParallelFor (see LaunchKernel).
const size_t kDataSizeThreshold = 4 * 1024;
// Expands to one switch case of Launch(): runs the typed kernel for DTYPE and
// stores the result in the enclosing `ret` variable.
#define CROSS_COMPUTE_CASE(DTYPE, TYPE)     \
  case (DTYPE): {                           \
    ret = LaunchKernel<TYPE>(inputs, outputs); \
    break;                                  \
  }
}  // namespace
| namespace mindspore { | |||
| namespace kernel { | |||
| void CrossCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { | |||
| input1_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); | |||
| input2_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 1); | |||
| output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); | |||
| input1_dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); | |||
| dim_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "dim"); | |||
| int64_t default_dim = -65530; | |||
| if (dim_ == default_dim) { | |||
| size_t dim_size_value = 3; | |||
| for (size_t i = 0; i < input1_shape_.size(); i++) { | |||
| if (input1_shape_[i] == dim_size_value) { | |||
| dim_ = i; | |||
| break; | |||
| } | |||
| if (i == input1_shape_.size() - 1 && input1_shape_[i] != dim_size_value) { | |||
| MS_EXCEPTION(ValueError) << "The size of inputs dim should be 3,but got" << input1_shape_[i]; | |||
| } | |||
| } | |||
| } | |||
| if (dim_ < 0) { | |||
| dim_ = static_cast<int64_t>(input1_shape_.size()) + dim_; | |||
| } | |||
| } | |||
| bool CrossCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &, | |||
| const std::vector<kernel::AddressPtr> &outputs) { | |||
| bool ret = true; | |||
| switch (input1_dtype_) { | |||
| CROSS_COMPUTE_CASE(kNumberTypeInt8, int8_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeInt16, int16_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeInt32, int32_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeInt64, int64_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeUInt8, uint8_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeUInt16, uint16_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeUInt32, uint32_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeUInt64, uint64_t) | |||
| CROSS_COMPUTE_CASE(kNumberTypeFloat16, float16) | |||
| CROSS_COMPUTE_CASE(kNumberTypeFloat32, float) | |||
| CROSS_COMPUTE_CASE(kNumberTypeFloat64, double) | |||
| CROSS_COMPUTE_CASE(kNumberTypeComplex64, std::complex<float>) | |||
| CROSS_COMPUTE_CASE(kNumberTypeComplex128, std::complex<double>) | |||
| default: | |||
| MS_EXCEPTION(TypeError) << "Unsupported input data type: " << input1_dtype_; | |||
| ret = false; | |||
| } | |||
| return ret; | |||
| } | |||
// Computes y = x1 x x2: a 3-component cross product taken along dimension
// dim_ for every combination of indices over the remaining dimensions.
// Offsets into the (row-major) buffers are tracked incrementally instead of
// being recomputed per element.
template <typename T>
bool CrossCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                     const std::vector<kernel::AddressPtr> &outputs) {
  auto input1_data_addr = reinterpret_cast<T *>(inputs[0]->addr);
  // Total element count of input1 (product of all dims).
  size_t tmp = 1;
  for (size_t i = 0; i < input1_shape_.size(); i++) {
    tmp = tmp * input1_shape_[i];
  }
  size_t input1_data_num = tmp;
  auto input2_data_addr = reinterpret_cast<T *>(inputs[1]->addr);
  auto output_data_addr = reinterpret_cast<T *>(outputs[0]->addr);
  // Number of 3-vectors to process (the dim_ dimension has size 3).
  size_t total = input1_data_num / 3;
  const size_t n = input1_shape_.size();
  // Row-major strides of input1, computed from the innermost dim outward.
  std::vector<size_t> a_stride(n);
  size_t stride_tmp = 1;
  for (int64_t i = static_cast<int64_t>(n - 1); i >= 0; i--) {
    a_stride[LongToSize(i)] = stride_tmp;
    stride_tmp *= input1_shape_[i];
  }
  // Step between consecutive vector components of input1.
  size_t input1_data_stride = a_stride[dim_];
  // Row-major strides of input2.
  std::vector<size_t> b_stride(n);
  stride_tmp = 1;
  for (int64_t i = static_cast<int64_t>(n - 1); i >= 0; i--) {
    b_stride[LongToSize(i)] = stride_tmp;
    stride_tmp *= input2_shape_[i];
  }
  size_t input2_data_stride = b_stride[dim_];
  // Row-major strides of the output.
  std::vector<size_t> r_stride(n);
  stride_tmp = 1;
  for (int64_t i = static_cast<int64_t>(n - 1); i >= 0; i--) {
    r_stride[LongToSize(i)] = stride_tmp;
    stride_tmp *= output_shape_[i];
  }
  size_t output_data_stride = r_stride[dim_];
  // Index of the third vector component (components are 0, 1, pos).
  const int64_t pos = 2;
  // Shard worker over cross slots [start, end).  `start` is decoded by
  // treating the non-dim_ dimensions as mixed-radix digits (dim 0 least
  // significant); later slots are reached by incrementing those digits.
  auto cross_shard = [this, &a_stride, &b_stride, &r_stride, &output_data_addr, &input1_data_addr, &input2_data_addr,
                      &output_data_stride, &input1_data_stride, &input2_data_stride](size_t start, size_t end) {
    const size_t input1_data_dim = input1_shape_.size();
    std::vector<size_t> position_in_dims(input1_data_dim);
    size_t index_in_curr_dim = start;
    size_t input1_data_start = 0;
    size_t input2_data_start = 0;
    size_t output_data_start = 0;
    // Decode `start` into per-dimension positions and flat base offsets,
    // skipping the cross dimension itself.
    for (int64_t i = 0; i < static_cast<int64_t>(input1_data_dim); i++) {
      if (i == static_cast<int64_t>(dim_)) continue;
      position_in_dims[i] = index_in_curr_dim % input1_shape_[i];
      input1_data_start += (index_in_curr_dim % input1_shape_[i]) * a_stride[i];
      input2_data_start += (index_in_curr_dim % input2_shape_[i]) * b_stride[i];
      output_data_start += (index_in_curr_dim % output_shape_[i]) * r_stride[i];
      index_in_curr_dim = index_in_curr_dim / input1_shape_[i];
    }
    while (start < end) {
      // Standard cross-product formulas:
      //   y0 = a1*b2 - a2*b1,  y1 = a2*b0 - a0*b2,  y2 = a0*b1 - a1*b0.
      output_data_addr[output_data_start + 0 * output_data_stride] =
        input1_data_addr[input1_data_start + 1 * input1_data_stride] *
          input2_data_addr[input2_data_start + pos * input2_data_stride] -
        input1_data_addr[input1_data_start + pos * input1_data_stride] *
          input2_data_addr[input2_data_start + 1 * input2_data_stride];
      output_data_addr[output_data_start + 1 * output_data_stride] =
        input1_data_addr[input1_data_start + pos * input1_data_stride] *
          input2_data_addr[input2_data_start + 0 * input2_data_stride] -
        input1_data_addr[input1_data_start + 0 * input1_data_stride] *
          input2_data_addr[input2_data_start + pos * input2_data_stride];
      output_data_addr[output_data_start + pos * output_data_stride] =
        input1_data_addr[input1_data_start + 0 * input1_data_stride] *
          input2_data_addr[input2_data_start + 1 * input2_data_stride] -
        input1_data_addr[input1_data_start + 1 * input1_data_stride] *
          input2_data_addr[input2_data_start + 0 * input2_data_stride];
      start++;
      // Advance to the next slot: increment the mixed-radix counter over the
      // non-dim_ dimensions, carrying into the next dimension on wrap-around.
      // The last dimension is intentionally never reset (the `start < end`
      // bound ends iteration before it would need to wrap).
      for (int64_t i = 0; i < static_cast<int64_t>(input1_data_dim); i++) {
        if (i == static_cast<int64_t>(dim_)) {
          continue;
        }
        position_in_dims[i]++;
        input1_data_start += a_stride[i];
        input2_data_start += b_stride[i];
        output_data_start += r_stride[i];
        if (position_in_dims[i] == input1_shape_[i] && i != static_cast<int64_t>(input1_shape_.size()) - 1) {
          // Wrapped: rewind this dimension's contribution and carry upward.
          input1_data_start -= position_in_dims[i] * a_stride[i];
          input2_data_start -= position_in_dims[i] * b_stride[i];
          output_data_start -= position_in_dims[i] * r_stride[i];
          position_in_dims[i] = 0;
        } else {
          break;
        }
      }
    }
  };
  // Small workloads run inline; larger ones are parallelized across threads.
  if (total * sizeof(T) < kDataSizeThreshold) {
    cross_shard(0, total);
  } else {
    CPUKernelUtils::ParallelFor(cross_shard, total);
  }
  return true;
}
| } // namespace kernel | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,110 @@ | |||
| /** | |||
| * Copyright 2021-2022 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CROSS_CPU_KERNEL_H_ | |||
| #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CROSS_CPU_KERNEL_H_ | |||
| #include <vector> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <complex> | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel.h" | |||
| #include "backend/kernel_compiler/cpu/cpu_kernel_factory.h" | |||
| namespace mindspore { | |||
| namespace kernel { | |||
| class CrossCpuKernelMod : public NativeCpuKernelMod { | |||
| public: | |||
| CrossCpuKernelMod() = default; | |||
| ~CrossCpuKernelMod() override = default; | |||
| void InitKernel(const CNodePtr &kernel_node) override; | |||
| bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, | |||
| const std::vector<AddressPtr> &outputs) override; | |||
| template <typename T> | |||
| bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs); | |||
| private: | |||
| std::vector<size_t> input1_shape_; | |||
| std::vector<size_t> input2_shape_; | |||
| std::vector<size_t> output_shape_; | |||
| int64_t dim_; | |||
| TypeId input1_dtype_; | |||
| }; | |||
// Register the Cross CPU kernel for every supported dtype.  Both inputs and
// the output always share one dtype; one registration per supported type.
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross,
  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(
  Cross,
  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(Cross,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeComplex64)
                    .AddInputAttr(kNumberTypeComplex64)
                    .AddOutputAttr(kNumberTypeComplex64),
                  CrossCpuKernelMod);
MS_REG_CPU_KERNEL(Cross,
                  KernelAttr()
                    .AddInputAttr(kNumberTypeComplex128)
                    .AddInputAttr(kNumberTypeComplex128)
                    .AddOutputAttr(kNumberTypeComplex128),
                  CrossCpuKernelMod);
| } // namespace kernel | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_CROSS_CPU_KERNEL_H_ | |||
| @@ -83,6 +83,9 @@ constexpr auto kImag = "Imag"; | |||
| constexpr auto kConj = "Conj"; | |||
| constexpr auto kGer = "Ger"; | |||
| // Math | |||
| constexpr auto kCross = "Cross"; | |||
| // Arrays | |||
| constexpr auto kDynamicShape = "DynamicShape"; | |||
| constexpr auto kStack = "Stack"; | |||
| @@ -596,6 +599,7 @@ MS_CORE_API inline const PrimitivePtr kPrimTensorListStack = std::make_shared<Pr | |||
| MS_CORE_API inline const PrimitivePtr kPrimTensorListSetItem = std::make_shared<Primitive>("TensorListSetItem"); | |||
| // Maths | |||
| MS_CORE_API inline const PrimitivePtr kPrimCross = std::make_shared<Primitive>(kCross); | |||
| MS_CORE_API inline const PrimitivePtr kPrimBesselI0 = std::make_shared<Primitive>("BesselI0"); | |||
| MS_CORE_API inline const PrimitivePtr kPrimBesselI1 = std::make_shared<Primitive>("BesselI1"); | |||
| MS_CORE_API inline const PrimitivePtr kPrimGer = std::make_shared<Primitive>("Ger"); | |||
| @@ -0,0 +1,102 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ops/cross.h" | |||
| #include <string> | |||
| #include <algorithm> | |||
| #include <memory> | |||
| #include <set> | |||
| #include <vector> | |||
| #include "ops/op_utils.h" | |||
| #include "utils/check_convert_utils.h" | |||
| #include "abstract/primitive_infer_map.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
| namespace { | |||
| abstract::ShapePtr CrossInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| auto x1_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; | |||
| auto x2_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape]; | |||
| auto dim = GetValue<int64_t>(primitive->GetAttr("dim")); | |||
| if (x1_shape.size() != x2_shape.size()) { | |||
| MS_EXCEPTION(ValueError) << "The shape of two inputs must have the same size."; | |||
| } | |||
| for (size_t i = 0; i < x1_shape.size(); ++i) { | |||
| if (x1_shape[i] != x2_shape[i]) { | |||
| MS_EXCEPTION(ValueError) << "x1 and x2 must have the same shape."; | |||
| } | |||
| } | |||
| if (x1_shape.size() <= 0 || x2_shape.size() <= 0) { | |||
| MS_EXCEPTION(ValueError) << "Inputs should not be a " << x1_shape.size() << " dimensional tensor."; | |||
| } | |||
| int64_t default_dim = -65530; | |||
| if (dim == default_dim) { | |||
| int64_t dim_size_value = 3; | |||
| for (size_t i = 0; i < x1_shape.size(); i++) { | |||
| if (x1_shape[i] == dim_size_value) { | |||
| dim = i; | |||
| break; | |||
| } | |||
| if (i == x1_shape.size() - 1 && x1_shape[i] != dim_size_value) { | |||
| MS_EXCEPTION(ValueError) << "The size of inputs dim should be 3,but got " << x1_shape[i]; | |||
| } | |||
| } | |||
| } | |||
| if ((dim < -static_cast<int64_t>(x1_shape.size()) || dim > static_cast<int64_t>(x1_shape.size()) - 1) && | |||
| dim != default_dim) { | |||
| MS_EXCEPTION(ValueError) << "dim should be between " << -static_cast<int64_t>(x1_shape.size()) << " and " | |||
| << static_cast<int64_t>(x1_shape.size()) - 1 << " ,but got " << dim; | |||
| } | |||
| if (dim < 0 && dim != default_dim) { | |||
| dim = static_cast<int64_t>(x1_shape.size()) + dim; | |||
| } | |||
| int64_t dim_size = 3; | |||
| if (x1_shape[dim] != dim_size && x2_shape[dim] != dim_size && dim != default_dim) { | |||
| MS_EXCEPTION(ValueError) << "The size of inputs dim should be 3,but got " << x1_shape[dim]; | |||
| } | |||
| return std::make_shared<abstract::Shape>(x1_shape); | |||
| } | |||
| TypePtr CrossInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) { | |||
| auto op_name = primitive->name(); | |||
| const int64_t input_num = 2; | |||
| CheckAndConvertUtils::CheckInteger("Cross infer", input_args.size(), kEqual, input_num, op_name); | |||
| for (const auto &item : input_args) { | |||
| MS_EXCEPTION_IF_NULL(item); | |||
| } | |||
| const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kFloat16, kFloat32, | |||
| kFloat64, kUInt16, kUInt32, kUInt64, kComplex64, kComplex128}; | |||
| auto x1_type = input_args[0]->BuildType(); | |||
| auto x2_type = input_args[1]->BuildType(); | |||
| auto tensor_type = x2_type->cast<TensorTypePtr>(); | |||
| auto element = tensor_type->element(); | |||
| CheckAndConvertUtils::CheckTensorTypeValid("x2", x2_type, valid_types, primitive->name()); | |||
| return CheckAndConvertUtils::CheckTensorTypeValid("x1", x1_type, {element}, primitive->name()); | |||
| } | |||
// Combined infer entry point for the Cross primitive: validates argument
// count, then builds the abstract value from the inferred type and shape.
// Type inference runs first, so dtype errors are reported before shape errors.
AbstractBasePtr CrossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int64_t input_num = 2;
  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
  auto infer_type = CrossInferType(primitive, input_args);
  auto infer_shape = CrossInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}
// Hook CrossInfer into the frontend's primitive evaluation machinery.
REGISTER_PRIMITIVE_EVAL_IMPL(Cross, prim::kPrimCross, CrossInfer, nullptr, true);
| } // namespace | |||
| } // namespace ops | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef MINDSPORE_CORE_OPS_CROSS_H_ | |||
| #define MINDSPORE_CORE_OPS_CROSS_H_ | |||
| #include <map> | |||
| #include <vector> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "ops/primitive_c.h" | |||
| #include "abstract/abstract_value.h" | |||
| #include "utils/check_convert_utils.h" | |||
| namespace mindspore { | |||
| namespace ops { | |||
constexpr auto kNameCross = "Cross";
// Frontend primitive for the Cross operator: y = x1 x x2, computed along a
// dimension of size 3 chosen by the "dim" attribute (-65530 = "first
// dimension of size 3"; see CrossInferShape in cross.cc).
class Cross : public PrimitiveC {
 public:
  // Registers the input names {x1, x2} and output name {y}.
  Cross() : PrimitiveC(kNameCross) { InitIOName({"x1", "x2"}, {"y"}); }
  ~Cross() = default;
  MS_DECLARE_PARENT(Cross, PrimitiveC);
};
// Implemented in cross.cc; performs shape and dtype inference.
AbstractBasePtr CrossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                           const std::vector<AbstractBasePtr> &input_args);
using PrimCrossPtr = std::shared_ptr<Cross>;
| } // namespace ops | |||
| } // namespace mindspore | |||
| #endif // MINDSPORE_CORE_OPS_CROSS_H_ | |||
| @@ -312,3 +312,14 @@ def get_bprop_ger(self): | |||
| return dx, dy | |||
| return bprop | |||
@bprop_getters.register(P.Cross)
def get_bprop_cross(self):
    """Grad definition for 'Cross' operation.

    For c = a x b with incoming gradient g, the gradients are
    da = b x g and db = g x a (from c_i = eps_ijk * a_j * b_k).
    The backward cross products reuse the forward op's `dim` attribute.
    """
    cross = P.Cross(dim=self.dim)

    def bprop(input1, input2, out, dout):
        # d(input1) = input2 x dout,  d(input2) = dout x input1
        return cross(input2, dout), cross(dout, input1)

    return bprop
| @@ -99,3 +99,4 @@ from .square import _square_aicpu | |||
| from .lower_bound import _lower_bound_aicpu | |||
| from .grid_sampler_3d import _grid_sampler_3d_aicpu | |||
| from .grid_sampler_3d_grad import _grid_sampler_3d_grad_aicpu | |||
| from .cross import _cross_aicpu | |||
| @@ -0,0 +1,42 @@ | |||
| # Copyright 2021 Huawei Technologies Co., Ltd | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| """Cross op""" | |||
| from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType | |||
# AiCPU registration info for Cross: two required tensor inputs (x1, x2), one
# output (y), an integer "dim" attribute, and one dtype row per supported
# type (all three tensors share the same dtype).
cross_op_info = AiCPURegOp("Cross") \
    .fusion_type("OPAQUE") \
    .attr("dim", "int")\
    .input(0, "x1", "required") \
    .input(1, "x2", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.I8_Default, DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default, DataType.U64_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
    .dtype_format(DataType.C64_Default, DataType.C64_Default, DataType.C64_Default) \
    .dtype_format(DataType.C128_Default, DataType.C128_Default, DataType.C128_Default)\
    .get_op_info()


@op_info_register(cross_op_info)
def _cross_aicpu():
    """Cross aicpu register"""
    return
| @@ -53,8 +53,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A | |||
| BitwiseAnd, BitwiseOr, Ger, | |||
| BitwiseXor, Inv, Invert, ApproximateEqual, InplaceAdd, InplaceSub, | |||
| ReduceMax, ReduceMin, ReduceMean, ReduceSum, ReduceAll, ReduceProd, CumProd, Cdist, ReduceAny, | |||
| Cos, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, Ceil, | |||
| Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod, | |||
| Cos, Cross, Div, DivNoNan, Equal, EqualCount, Exp, Expm1, Erf, Erfc, Floor, FloorDiv, FloorMod, | |||
| Ceil, Acosh, Greater, GreaterEqual, Lerp, Less, LessEqual, Log, Log1p, LogicalAnd, Mod, | |||
| LogicalNot, LogicalOr, LpNorm, MatMul, Maximum, MulNoNan, | |||
| MatrixDeterminant, LogMatrixDeterminant, Minimum, Mul, Neg, NMSWithMask, NotEqual, | |||
| NPUAllocFloatStatus, NPUClearFloatStatus, LinSpace, Einsum, | |||
| @@ -518,6 +518,7 @@ __all__ = [ | |||
| ] | |||
| __sponge__ = [ | |||
| "Cross", | |||
| "BondForce", | |||
| "BondEnergy", | |||
| "BondAtomEnergy", | |||
| @@ -5978,3 +5978,51 @@ class CholeskyInverse(Primitive): | |||
| """Initialize CholeskyInverse""" | |||
| validator.check_value_type("upper", upper, [bool], self.name) | |||
| self.upper = upper | |||
class Cross(Primitive):
    """
    Returns the cross product of vectors in dimension `dim` of x1 and x2.

    x1 and x2 must have the same shape and the same type, and the size of
    their `dim` dimension should be 3. If `dim` is not given, it defaults to
    the first dimension found with the size 3.

    Args:
        dim (int): The dimension along which to take the cross product.
            Default: -65530, a sentinel meaning "use the first dimension
            whose size is 3".

    Inputs:
        - **x1** (Tensor) - x1 is a tensor.
          x1 and x2 must have the same shape and the same type, and the size
          of their `dim` dimension should be 3.
        - **x2** (Tensor) - x2 is a tensor.

    Outputs:
        Tensor, has the same shape and type as input.

    Raises:
        TypeError: If `x1` is not a Tensor.
        TypeError: If `x2` is not a Tensor.
        TypeError: If the type of `x1` is not the same as that of `x2`.
        ValueError: If `x1` and `x2` not have the same size, and the size of
            their `dim` dimension not be 3.
        ValueError: If `x1` and `x2` not have the same shape.
        ValueError: If `dim` is out of range, `dim` should be
            [-len(x1.shape), len(x1.shape)-1].

    Supported Platforms:
        ``CPU``

    Examples:
        >>> import mindspore
        >>> import numpy as np
        >>> from mindspore import Tensor
        >>> from mindspore.common import dtype as mstype
        >>> import mindspore.ops as ops
        >>> cross = ops.Cross(dim = 0)
        >>> x1 = Tensor([1, 2, 3], mstype.int8)
        >>> x2 = Tensor([1, 2, 3], mstype.int8)
        >>> output = cross(x1, x2)
        >>> print(output)
        [0 0 0]
    """

    @prim_attr_register
    def __init__(self, dim=-65530):
        """Initialize Cross."""
        validator.check_value_type('dim', dim, [int], self.name)
        self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['y'])
| @@ -1084,6 +1084,11 @@ class ApplyKerasMomentumNet(nn.Cell): | |||
| test_case_math_ops = [ | |||
| ('Cross', { | |||
| 'block': P.Cross(dim=1), | |||
| 'desc_inputs': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8), | |||
| Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)], | |||
| 'desc_bprop': [Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mstype.int8)]}), | |||
| ('Ger', { | |||
| 'block': P.Ger(), | |||
| 'desc_inputs': [[3,], [4,]], | |||