@@ -0,0 +1,33 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "src/ops/arithmetic_compare.h"
+namespace mindspore {
+namespace lite {
+int ArithmeticCompare::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
+  auto res = Arithmetic::InferShape(inputs_, outputs_);
+  if (res == RET_OK) {
+    auto output = outputs_.front();
+    output->set_data_type(TypeId::kNumberTypeBool);
+    return RET_OK;
+  } else {
+    return res;
+  }
+}
+}  // namespace lite
+}  // namespace mindspore
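Note for reviewers: the pattern in this new base class is that Arithmetic::InferShape works out the broadcast output shape and format for the two operands, and ArithmeticCompare only overrides the output data type to bool. A standalone sketch of that behaviour follows; the broadcast helper and names here are illustrative only, not MindSpore APIs.

    // Sketch: "reuse arithmetic shape inference, then force a bool output type".
    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // NumPy-style broadcast of two shapes, assumed compatible; mirrors what the
    // base Arithmetic::InferShape is expected to produce for comparison ops.
    std::vector<int> BroadcastShape(const std::vector<int> &a, const std::vector<int> &b) {
      std::vector<int> out(std::max(a.size(), b.size()), 1);
      for (size_t i = 0; i < out.size(); ++i) {
        int da = i < out.size() - a.size() ? 1 : a[i - (out.size() - a.size())];
        int db = i < out.size() - b.size() ? 1 : b[i - (out.size() - b.size())];
        out[i] = std::max(da, db);  // assumes shapes are broadcast-compatible
      }
      return out;
    }

    int main() {
      // Equal({2, 3} float32, {1, 3} float32) -> shape {2, 3}, dtype bool
      auto shape = BroadcastShape({2, 3}, {1, 3});
      for (int d : shape) printf("%d ", d);
      printf("-> kNumberTypeBool\n");
      return 0;
    }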
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
+#define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
+#include <vector>
+#include <set>
+#include <cmath>
+#include "src/ops/arithmetic.h"
+namespace mindspore {
+namespace lite {
+class ArithmeticCompare : public Arithmetic {
+ public:
+  ArithmeticCompare() = default;
+  ~ArithmeticCompare() = default;
+#ifdef PRIMITIVE_WRITEABLE
+  MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic);
+  explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+#endif
+  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
+};
+}  // namespace lite
+}  // namespace mindspore
+#endif  // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_
@@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::
 PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); }
 Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator);
 #endif
-int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
-  auto output = outputs_.front();
-  MS_ASSERT(output != nullptr);
-  output->set_shape(input->shape());
-  output->set_data_type(TypeId::kNumberTypeBool);
-  output->set_format(input->format());
-  return RET_OK;
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -20,21 +20,20 @@
 #include <vector>
 #include <set>
 #include <cmath>
-#include "src/ops/arithmetic.h"
+#include "src/ops/arithmetic_compare.h"
 namespace mindspore {
 namespace lite {
-class Equal : public Arithmetic {
+class Equal : public ArithmeticCompare {
  public:
   Equal() = default;
   ~Equal() = default;
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(Equal, PrimitiveC);
-  explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+  MS_DECLARE_PARENT(Equal, ArithmeticCompare);
+  explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
-  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers
 PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); }
 Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator);
 #endif
-int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
-  auto output = outputs_.front();
-  MS_ASSERT(output != nullptr);
-  output->set_shape(input->shape());
-  output->set_data_type(TypeId::kNumberTypeBool);
-  output->set_format(input->format());
-  return RET_OK;
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -20,21 +20,20 @@
 #include <set>
 #include <cmath>
-#include "src/ops/arithmetic.h"
+#include "src/ops/arithmetic_compare.h"
 namespace mindspore {
 namespace lite {
-class Greater : public Arithmetic {
+class Greater : public ArithmeticCompare {
  public:
   Greater() = default;
   ~Greater() = default;
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(Greater, Arithmetic);
-  explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+  MS_DECLARE_PARENT(Greater, ArithmeticCompare);
+  explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
-  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) {
 Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator);
 #endif
-int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
-  auto output = outputs_.front();
-  MS_ASSERT(output != nullptr);
-  output->set_shape(input->shape());
-  output->set_data_type(TypeId::kNumberTypeBool);
-  output->set_format(input->format());
-  return RET_OK;
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -21,21 +21,20 @@
 #include <set>
 #include <cmath>
-#include "src/ops/arithmetic.h"
+#include "src/ops/arithmetic_compare.h"
 namespace mindspore {
 namespace lite {
-class GreaterEqual : public Arithmetic {
+class GreaterEqual : public ArithmeticCompare {
  public:
   GreaterEqual() = default;
   ~GreaterEqual() = default;
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(GreaterEqual, Arithmetic);
-  explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+  MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare);
+  explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
-  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC:
 Registry LessRegistry(schema::PrimitiveType_Less, LessCreator);
 #endif
-int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
-  auto output = outputs_.front();
-  MS_ASSERT(output != nullptr);
-  output->set_shape(input->shape());
-  output->set_data_type(TypeId::kNumberTypeBool);
-  output->set_format(input->format());
-  return RET_OK;
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -21,21 +21,20 @@
 #include <set>
 #include <cmath>
-#include "src/ops/arithmetic.h"
+#include "src/ops/arithmetic_compare.h"
 namespace mindspore {
 namespace lite {
-class Less : public Arithmetic {
+class Less : public ArithmeticCompare {
  public:
   Less() = default;
   ~Less() = default;
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(Less, Arithmetic);
-  explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+  MS_DECLARE_PARENT(Less, ArithmeticCompare);
+  explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
-  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) {
 }
 Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator);
 #endif
-int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
-  auto output = outputs_.front();
-  MS_ASSERT(output != nullptr);
-  output->set_shape(input->shape());
-  output->set_data_type(TypeId::kNumberTypeBool);
-  output->set_format(input->format());
-  return RET_OK;
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -21,21 +21,20 @@
 #include <set>
 #include <cmath>
-#include "src/ops/arithmetic.h"
+#include "src/ops/arithmetic_compare.h"
 namespace mindspore {
 namespace lite {
-class LessEqual : public Arithmetic {
+class LessEqual : public ArithmeticCompare {
  public:
   LessEqual() = default;
   ~LessEqual() = default;
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(LessEqual, Arithmetic);
-  explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+  MS_DECLARE_PARENT(LessEqual, ArithmeticCompare);
+  explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
-  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
     input0->set_shape(a_shape);
   }
-  if (a_shape.size() < 2 || b_shape.size() < 2) {
-    MS_LOG(ERROR) << "inputs shape is invalid";
-    return RET_INPUT_TENSOR_ERROR;
+  bool del_start = false;
+  bool del_end = false;
+  if (a_shape.size() == 1) {
+    a_shape.insert(a_shape.begin(), 1);
+    input0->set_shape(a_shape);
+    del_start = true;
+  }
+  if (b_shape.size() == 1) {
+    b_shape.push_back(1);
+    input1->set_shape(b_shape);
+    del_end = true;
   }
   for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) {
     if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) {
@@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp
   }
   std::vector<int> c_shape(a_shape);
   c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
+  if (del_start) {
+    c_shape.erase(c_shape.begin());
+  }
+  if (del_end) {
+    c_shape.pop_back();
+  }
   output->set_shape(c_shape);
   return RET_OK;
 }
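With this change MatMul::InferShape accepts 1-D operands in the NumPy style: a 1-D lhs is padded to {1, k}, a 1-D rhs to {k, 1}, and the padded dimension is dropped from the output shape again. A self-contained sketch of the resulting shape rule follows (shape logic only; it ignores the transpose attributes the real op also handles, and MatMulOutShape is an illustrative name, not a MindSpore function).

    #include <cassert>
    #include <vector>

    std::vector<int> MatMulOutShape(std::vector<int> a_shape, std::vector<int> b_shape) {
      bool del_start = false;
      bool del_end = false;
      if (a_shape.size() == 1) {  // {k} is treated as {1, k}
        a_shape.insert(a_shape.begin(), 1);
        del_start = true;
      }
      if (b_shape.size() == 1) {  // {k} is treated as {k, 1}
        b_shape.push_back(1);
        del_end = true;
      }
      std::vector<int> c_shape(a_shape);
      c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1];
      if (del_start) c_shape.erase(c_shape.begin());  // drop the padded row dim
      if (del_end) c_shape.pop_back();                // drop the padded col dim
      return c_shape;
    }

    int main() {
      assert((MatMulOutShape({4}, {4, 5}) == std::vector<int>{5}));    // vector x matrix
      assert((MatMulOutShape({3, 4}, {4}) == std::vector<int>{3}));    // matrix x vector
      assert((MatMulOutShape({4}, {4}).empty()));                      // vector x vector -> scalar
      assert((MatMulOutShape({2, 3, 4}, {2, 4, 5}) == std::vector<int>{2, 3, 5}));
      return 0;
    }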
@@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) {
 Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator);
 #endif
-int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) {
-  auto input = inputs_.front();
-  MS_ASSERT(input != nullptr);
-  auto output = outputs_.front();
-  MS_ASSERT(output != nullptr);
-  output->set_shape(input->shape());
-  output->set_data_type(TypeId::kNumberTypeBool);
-  output->set_format(input->format());
-  return RET_OK;
-}
 }  // namespace lite
 }  // namespace mindspore
@@ -21,21 +21,20 @@
 #include <set>
 #include <cmath>
-#include "src/ops/arithmetic.h"
+#include "src/ops/arithmetic_compare.h"
 namespace mindspore {
 namespace lite {
-class NotEqual : public Arithmetic {
+class NotEqual : public ArithmeticCompare {
  public:
   NotEqual() = default;
   ~NotEqual() = default;
 #ifdef PRIMITIVE_WRITEABLE
-  MS_DECLARE_PARENT(NotEqual, Arithmetic);
-  explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {}
+  MS_DECLARE_PARENT(NotEqual, ArithmeticCompare);
+  explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {}
 #else
   int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override;
 #endif
-  int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override;
 };
 }  // namespace lite
 }  // namespace mindspore
@@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme
 Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic);
 Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic);
 Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic);
-Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
 Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic);
-Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
 Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic);
 Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic);
+Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic);
+Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic);
 Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic);
 Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic);
 Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic);
@@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual;
 using mindspore::schema::PrimitiveType_NotEqual;
 namespace mindspore::kernel {
-namespace {
-typedef struct {
-  int primitive_type_;
-  ArithmeticCompareFp32Func func_;
-} TYPE_FUNC_INFO;
-}  // namespace
-ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) {
-  TYPE_FUNC_INFO type_func_table[] = {
-    {PrimitiveType_Equal, ElementEqualFp32},     {PrimitiveType_NotEqual, ElementNotEqualFp32},
-    {PrimitiveType_Less, ElementLessFp32},       {PrimitiveType_LessEqual, ElementLessEqualFp32},
-    {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}};
-  for (size_t i = 0; i < sizeof(type_func_table); i++) {
-    if (type_func_table[i].primitive_type_ == primitive_type) {
-      return type_func_table[i].func_;
+int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count,
+                                             int out_thread_stride) {
+  if (dim > break_pos_) {
+    if (data_type_ == kDataTypeInt) {
+      return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride,
+                         reinterpret_cast<int *>(input1) + out_thread_stride,
+                         reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
     }
+    return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride,
+                      reinterpret_cast<float *>(input1) + out_thread_stride,
+                      reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count);
   }
-  return nullptr;
-}
-int ArithmeticCompareCPUKernel::Init() {
-  if (!InferShapeDone()) {
-    return RET_OK;
+  for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) {
+    int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i;
+    int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i;
+    int error_code;
+    if (data_type_ == kDataTypeInt) {
+      error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
+                                reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
+                                reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
+                                dim + 1, out_count, out_thread_stride);
+    } else {
+      error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim],
+                                reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim],
+                                reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim],
+                                dim + 1, out_count, out_thread_stride);
+    }
+    if (error_code != RET_OK) {
+      return error_code;
+    }
   }
-  return ReSize();
+  return RET_OK;
 }
-int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; }
+int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) {
+  auto element_num = out_tensors_[0]->ElementsNum();
-int ArithmeticCompareCPUKernel::DoExecute(int task_id) {
-  if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) {
-    MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! ";
-    return RET_ERROR;
-  }
-  int elements_num = in_tensors_.at(0)->ElementsNum();
-  int stride = UP_DIV(elements_num, op_parameter_->thread_num_);
-  int offset = task_id * stride;
-  int count = MSMIN(stride, elements_num - offset);
-  if (count <= 0) {
-    return RET_OK;
-  }
-  if (func_ == nullptr) {
-    MS_LOG(ERROR) << "Run function is null! ";
+  MS_ASSERT(thread_count_ != 0);
+  int stride = UP_DIV(element_num, thread_count_);
+  int count = MSMIN(stride, element_num - stride * task_id);
+  if (func_fp32_ == nullptr) {
+    MS_LOG(ERROR) << "func_fp32_ function is nullptr!";
     return RET_ERROR;
   }
-  // two inputs have the same shape, support broadcast later
-  auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData());
-  auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData());
-  auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData());
-  auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "Run failed, illegal input! ";
-  }
-  return ret;
-}
-int ArithmeticCompareRun(void *cdata, int task_id) {
-  auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata);
-  auto ret = kernel->DoExecute(task_id);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]";
+  int error_code;
+  if (arithmeticParameter_->broadcasting_) {  // need broadcast
+    stride = UP_DIV(outside_, thread_count_);
+    int out_count = MSMIN(stride, outside_ - stride * task_id);
+    int out_thread_stride = stride * task_id;
+    if (data_type_ == kDataTypeFloat) {
+      error_code = BroadcastRun(
+        reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()),
+        reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
+    } else {
+      error_code = BroadcastRun(
+        reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()),
+        reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride);
+    }
+  } else {  // no broadcast, neither is scalar, two same shape
+    if (data_type_ == kDataTypeFloat) {
+      error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id,
+                              reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id,
+                              reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
+    } else {
+      error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id,
+                               reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id,
+                               reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count);
+    }
   }
-  return ret;
-}
-int ArithmeticCompareCPUKernel::Run() {
-  auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_);
-  if (ret != RET_OK) {
-    MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]";
+  if (error_code != RET_OK) {
+    return RET_ERROR;
   }
-  return ret;
+  return RET_OK;
 }
 kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs,
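For context on the rewritten DoArithmetic: the work is split over threads with UP_DIV/MSMIN just like the base arithmetic kernel, and when broadcasting is required each task walks the outer dimensions through the recursive BroadcastRun above. A small sketch of the partitioning arithmetic, assuming UP_DIV(a, b) == (a + b - 1) / b and MSMIN == min as in nnacl's macros:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int element_num = 10;
      const int thread_count = 4;
      const int stride = (element_num + thread_count - 1) / thread_count;  // UP_DIV -> 3
      for (int task_id = 0; task_id < thread_count; ++task_id) {
        int count = std::min(stride, element_num - stride * task_id);  // MSMIN
        // tasks 0..2 handle 3 elements each, task 3 handles the remaining 1
        printf("task %d: offset %d, count %d\n", task_id, stride * task_id, count);
      }
      return 0;
    }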
@@ -18,27 +18,57 @@
 #include <vector>
 #include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h"
+#include "nnacl/fp32/arithmetic_compare_fp32.h"
 namespace mindspore::kernel {
 typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size);
+typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size);
 class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel {
  public:
   explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
                                       const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx,
                                       const mindspore::lite::PrimitiveC *primitive)
       : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) {
-    func_ = GetArithmeticCompareFun(parameter->type_);
+    switch (parameter->type_) {
+      case PrimitiveType_Equal:
+        func_fp32_ = ElementEqualFp32;
+        func_int32_ = ElementEqualInt32;
+        break;
+      case PrimitiveType_NotEqual:
+        func_fp32_ = ElementNotEqualFp32;
+        func_int32_ = ElementNotEqualInt32;
+        break;
+      case PrimitiveType_Less:
+        func_fp32_ = ElementLessFp32;
+        func_int32_ = ElementLessInt32;
+        break;
+      case PrimitiveType_LessEqual:
+        func_fp32_ = ElementLessEqualFp32;
+        func_int32_ = ElementLessEqualInt32;
+        break;
+      case PrimitiveType_Greater:
+        func_fp32_ = ElementGreaterFp32;
+        func_int32_ = ElementGreaterInt32;
+        break;
+      case PrimitiveType_GreaterEqual:
+        func_fp32_ = ElementGreaterEqualFp32;
+        func_int32_ = ElementGreaterEqualInt32;
+        break;
+      default:
+        MS_LOG(ERROR) << "Error Operator type " << parameter->type_;
+        func_fp32_ = nullptr;
+        func_int32_ = nullptr;
+        break;
+    }
   }
   ~ArithmeticCompareCPUKernel() override = default;
-  int Init() override;
-  int ReSize() override;
-  int Run() override;
-  virtual int DoExecute(int task_id);
+  int DoArithmetic(int task_id) override;
+  int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override;
  private:
-  ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type);
-  ArithmeticCompareFp32Func func_;
+  ArithmeticCompareFp32Func func_fp32_ = nullptr;
+  ArithmeticCompareIntFunc func_int32_ = nullptr;
 };
 int ArithmeticCompareRun(void *cdata, int task_id);
 }  // namespace mindspore::kernel
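The constructor now binds an fp32/int32 function pair from nnacl per primitive type; both pointer types declared above write one uint8_t per output element. For readers who don't have nnacl open, an illustrative routine matching the ArithmeticCompareFp32Func signature is sketched below; it is not the actual nnacl ElementEqualFp32, which may differ (SIMD paths, epsilon handling).

    #include <cstdint>

    constexpr int NNACL_OK = 0;  // assumed success code for this sketch

    // One uint8_t result (0 or 1) per compared element.
    int ElementEqualFp32Sketch(const float *input0, const float *input1, uint8_t *output, int element_size) {
      for (int i = 0; i < element_size; ++i) {
        output[i] = static_cast<uint8_t>(input0[i] == input1[i]);
      }
      return NNACL_OK;
    }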
@@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() {
           break;
       }
       break;
+    case PrimitiveType_Equal:
+    case PrimitiveType_Less:
+    case PrimitiveType_Greater:
+    case PrimitiveType_NotEqual:
+    case PrimitiveType_LessEqual:
+    case PrimitiveType_GreaterEqual:
+      arithmetic_opt_run_ = nullptr;
+      arithmetic_opt_run_int_ = nullptr;
+      break;
     default:
       break;
   }
@@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel {
   int PreProcess() override;
   int ReSize() override;
   int Run() override;
-  int DoArithmetic(int task_id);
+  virtual int DoArithmetic(int task_id);
+  virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
- private:
-  int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride);
+ protected:
   int break_pos_ = 0;
   int outside_ = 0;
   int thread_count_ = 1;
   ArithmeticParameter *arithmeticParameter_ = nullptr;
+  LiteDataType data_type_ = kDataTypeFloat;
+ private:
   ArithmeticRun arithmetic_run_ = nullptr;
   ArithmeticOptRun arithmetic_opt_run_ = nullptr;
   ArithmeticIntRun arithmetic_run_int_ = nullptr;
   ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr;
-  LiteDataType data_type_ = kDataTypeFloat;
 };
 }  // namespace mindspore::kernel
 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_
@@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator)
 REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
+REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator)
 }  // namespace mindspore::kernel
@@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod
       std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(),
                    [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; });
     if (nodeIter == onnx_graph.initializer().end()) {
-      MS_LOG(ERROR) << "not find node: " << onnx_conv_weight;
-      return RET_ERROR;
-    }
-    std::vector<int> weight_shape;
-    auto size = (*nodeIter).dims_size();
-    weight_shape.reserve(size);
-    for (int i = 0; i < size; ++i) {
-      weight_shape.emplace_back((*nodeIter).dims(i));
+      MS_LOG(WARNING) << "not find node: " << onnx_conv_weight;
+    } else {
+      std::vector<int> weight_shape;
+      auto size = (*nodeIter).dims_size();
+      weight_shape.reserve(size);
+      for (int i = 0; i < size; ++i) {
+        weight_shape.emplace_back((*nodeIter).dims(i));
+      }
+      attr->channelOut = weight_shape[0];
+      attr->channelIn = weight_shape[1] * attr->group;
     }
-    attr->channelOut = weight_shape[0];
-    attr->channelIn = weight_shape[1] * attr->group;
   } else {
     auto nodeIter =
         std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(),
@@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An
       output_tensors[m]->AddQuantParam(quant_arg);
     }
   }
-  // here, input_tensor's format need to be transposed nhwc according to fmkType,
-  // but for the time being, we only transpose the tensor with 0/1/2/3D.
-  // Others should be added in future.
-  for (auto &input_tensor : input_tensors) {
-    input_tensor->set_format(schema::Format::Format_NHWC);
-    if (input_tensor->shape().size() == 4) {
-      MS_LOG(INFO) << "init input_tensor format to nhwc";
-    }
-  }
   lite_primitive->InferShape(input_tensors, output_tensors);
   auto primitive = lite_primitive.get();
   MS_ASSERT(primitive != nullptr);