| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| int ArithmeticCompare::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto res = Arithmetic::InferShape(inputs_, outputs_); | |||
| if (res == RET_OK) { | |||
| auto output = outputs_.front(); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| return RET_OK; | |||
| } else { | |||
| return res; | |||
| } | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_ | |||
| #define LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class ArithmeticCompare : public Arithmetic { | |||
| public: | |||
| ArithmeticCompare() = default; | |||
| ~ArithmeticCompare() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(ArithmeticCompare, Arithmetic); | |||
| explicit ArithmeticCompare(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| #endif // LITE_MINDSPORE_LITE_C_OPS_ARITHMETIC_COMPARE_H_ | |||
| @@ -35,16 +35,6 @@ int Equal::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers:: | |||
| PrimitiveC *EqualCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Equal>(primitive); } | |||
| Registry EqualRegistry(schema::PrimitiveType_Equal, EqualCreator); | |||
| #endif | |||
| int Equal::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| output->set_format(input->format()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,21 +20,20 @@ | |||
| #include <vector> | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Equal : public Arithmetic { | |||
| class Equal : public ArithmeticCompare { | |||
| public: | |||
| Equal() = default; | |||
| ~Equal() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Equal, PrimitiveC); | |||
| explicit Equal(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| MS_DECLARE_PARENT(Equal, ArithmeticCompare); | |||
| explicit Equal(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -36,16 +36,6 @@ int Greater::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers | |||
| PrimitiveC *GreaterCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<Greater>(primitive); } | |||
| Registry GreaterRegistry(schema::PrimitiveType_Greater, GreaterCreator); | |||
| #endif | |||
| int Greater::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| output->set_format(input->format()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -20,21 +20,20 @@ | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Greater : public Arithmetic { | |||
| class Greater : public ArithmeticCompare { | |||
| public: | |||
| Greater() = default; | |||
| ~Greater() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Greater, Arithmetic); | |||
| explicit Greater(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| MS_DECLARE_PARENT(Greater, ArithmeticCompare); | |||
| explicit Greater(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -38,16 +38,6 @@ PrimitiveC *GreaterEqualCreator(const schema::Primitive *primitive) { | |||
| Registry GreaterEqualRegistry(schema::PrimitiveType_GreaterEqual, GreaterEqualCreator); | |||
| #endif | |||
| int GreaterEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| output->set_format(input->format()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,21 +21,20 @@ | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class GreaterEqual : public Arithmetic { | |||
| class GreaterEqual : public ArithmeticCompare { | |||
| public: | |||
| GreaterEqual() = default; | |||
| ~GreaterEqual() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(GreaterEqual, Arithmetic); | |||
| explicit GreaterEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| MS_DECLARE_PARENT(GreaterEqual, ArithmeticCompare); | |||
| explicit GreaterEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -38,16 +38,6 @@ PrimitiveC *LessCreator(const schema::Primitive *primitive) { return PrimitiveC: | |||
| Registry LessRegistry(schema::PrimitiveType_Less, LessCreator); | |||
| #endif | |||
| int Less::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| output->set_format(input->format()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,21 +21,20 @@ | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class Less : public Arithmetic { | |||
| class Less : public ArithmeticCompare { | |||
| public: | |||
| Less() = default; | |||
| ~Less() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(Less, Arithmetic); | |||
| explicit Less(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| MS_DECLARE_PARENT(Less, ArithmeticCompare); | |||
| explicit Less(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -37,16 +37,6 @@ PrimitiveC *LessEqualCreator(const schema::Primitive *primitive) { | |||
| } | |||
| Registry LessEqualRegistry(schema::PrimitiveType_LessEqual, LessEqualCreator); | |||
| #endif | |||
| int LessEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| output->set_format(input->format()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,21 +21,20 @@ | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class LessEqual : public Arithmetic { | |||
| class LessEqual : public ArithmeticCompare { | |||
| public: | |||
| LessEqual() = default; | |||
| ~LessEqual() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(LessEqual, Arithmetic); | |||
| explicit LessEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| MS_DECLARE_PARENT(LessEqual, ArithmeticCompare); | |||
| explicit LessEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -112,9 +112,17 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp | |||
| input0->set_shape(a_shape); | |||
| } | |||
| if (a_shape.size() < 2 || b_shape.size() < 2) { | |||
| MS_LOG(ERROR) << "inputs shape is invalid"; | |||
| return RET_INPUT_TENSOR_ERROR; | |||
| bool del_start = false; | |||
| bool del_end = false; | |||
| if (a_shape.size() == 1) { | |||
| a_shape.insert(a_shape.begin(), 1); | |||
| input0->set_shape(a_shape); | |||
| del_start = true; | |||
| } | |||
| if (b_shape.size() == 1) { | |||
| b_shape.push_back(1); | |||
| input1->set_shape(b_shape); | |||
| del_end = true; | |||
| } | |||
| for (size_t i = 0; i < (a_shape.size() - 2) && i < (b_shape.size() - 2); ++i) { | |||
| if (a_shape[a_shape.size() - 3 - i] != b_shape[b_shape.size() - 3 - i]) { | |||
| @@ -131,6 +139,12 @@ int MatMul::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outp | |||
| } | |||
| std::vector<int> c_shape(a_shape); | |||
| c_shape[c_shape.size() - 1] = b_shape[b_shape.size() - 1]; | |||
| if (del_start) { | |||
| c_shape.erase(c_shape.begin()); | |||
| } | |||
| if (del_end) { | |||
| c_shape.pop_back(); | |||
| } | |||
| output->set_shape(c_shape); | |||
| return RET_OK; | |||
| } | |||
| @@ -38,16 +38,6 @@ PrimitiveC *NotEqualCreator(const schema::Primitive *primitive) { | |||
| Registry NotEqualRegistry(schema::PrimitiveType_NotEqual, NotEqualCreator); | |||
| #endif | |||
| int NotEqual::InferShape(std::vector<Tensor *> inputs_, std::vector<Tensor *> outputs_) { | |||
| auto input = inputs_.front(); | |||
| MS_ASSERT(input != nullptr); | |||
| auto output = outputs_.front(); | |||
| MS_ASSERT(output != nullptr); | |||
| output->set_shape(input->shape()); | |||
| output->set_data_type(TypeId::kNumberTypeBool); | |||
| output->set_format(input->format()); | |||
| return RET_OK; | |||
| } | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -21,21 +21,20 @@ | |||
| #include <set> | |||
| #include <cmath> | |||
| #include "src/ops/arithmetic.h" | |||
| #include "src/ops/arithmetic_compare.h" | |||
| namespace mindspore { | |||
| namespace lite { | |||
| class NotEqual : public Arithmetic { | |||
| class NotEqual : public ArithmeticCompare { | |||
| public: | |||
| NotEqual() = default; | |||
| ~NotEqual() = default; | |||
| #ifdef PRIMITIVE_WRITEABLE | |||
| MS_DECLARE_PARENT(NotEqual, Arithmetic); | |||
| explicit NotEqual(schema::PrimitiveT *primitive) : Arithmetic(primitive) {} | |||
| MS_DECLARE_PARENT(NotEqual, ArithmeticCompare); | |||
| explicit NotEqual(schema::PrimitiveT *primitive) : ArithmeticCompare(primitive) {} | |||
| #else | |||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||
| #endif | |||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||
| }; | |||
| } // namespace lite | |||
| } // namespace mindspore | |||
| @@ -58,11 +58,11 @@ Registry RealDivParameterRegistry(schema::PrimitiveType_RealDiv, PopulateArithme | |||
| Registry LogicalAndParameterRegistry(schema::PrimitiveType_LogicalAnd, PopulateArithmetic); | |||
| Registry ParameterRegistry(schema::PrimitiveType_LogicalOr, PopulateArithmetic); | |||
| Registry EqualParameterRegistry(schema::PrimitiveType_Equal, PopulateArithmetic); | |||
| Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); | |||
| Registry LessParameterRegistry(schema::PrimitiveType_Less, PopulateArithmetic); | |||
| Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); | |||
| Registry GreaterParameterRegistry(schema::PrimitiveType_Greater, PopulateArithmetic); | |||
| Registry GreaterEqualParameterRegistry(schema::PrimitiveType_GreaterEqual, PopulateArithmetic); | |||
| Registry NotEqualParameterRegistry(schema::PrimitiveType_NotEqual, PopulateArithmetic); | |||
| Registry LessEqualParameterRegistry(schema::PrimitiveType_LessEqual, PopulateArithmetic); | |||
| Registry MaximumParameterRegistry(schema::PrimitiveType_Maximum, PopulateArithmetic); | |||
| Registry MinimumParameterRegistry(schema::PrimitiveType_Minimum, PopulateArithmetic); | |||
| Registry FloorDivParameterRegistry(schema::PrimitiveType_FloorDiv, PopulateArithmetic); | |||
| @@ -28,77 +28,82 @@ using mindspore::schema::PrimitiveType_LessEqual; | |||
| using mindspore::schema::PrimitiveType_NotEqual; | |||
| namespace mindspore::kernel { | |||
| namespace { | |||
| typedef struct { | |||
| int primitive_type_; | |||
| ArithmeticCompareFp32Func func_; | |||
| } TYPE_FUNC_INFO; | |||
| } // namespace | |||
| ArithmeticCompareFp32Func ArithmeticCompareCPUKernel::GetArithmeticCompareFun(int primitive_type) { | |||
| TYPE_FUNC_INFO type_func_table[] = { | |||
| {PrimitiveType_Equal, ElementEqualFp32}, {PrimitiveType_NotEqual, ElementNotEqualFp32}, | |||
| {PrimitiveType_Less, ElementLessFp32}, {PrimitiveType_LessEqual, ElementLessEqualFp32}, | |||
| {PrimitiveType_Greater, ElementGreaterFp32}, {PrimitiveType_GreaterEqual, ElementGreaterEqualFp32}}; | |||
| for (size_t i = 0; i < sizeof(type_func_table); i++) { | |||
| if (type_func_table[i].primitive_type_ == primitive_type) { | |||
| return type_func_table[i].func_; | |||
| int ArithmeticCompareCPUKernel::BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, | |||
| int out_thread_stride) { | |||
| if (dim > break_pos_) { | |||
| if (data_type_ == kDataTypeInt) { | |||
| return func_int32_(reinterpret_cast<int *>(input0) + out_thread_stride, | |||
| reinterpret_cast<int *>(input1) + out_thread_stride, | |||
| reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count); | |||
| } | |||
| return func_fp32_(reinterpret_cast<float *>(input0) + out_thread_stride, | |||
| reinterpret_cast<float *>(input1) + out_thread_stride, | |||
| reinterpret_cast<uint8_t *>(output) + out_thread_stride, out_count); | |||
| } | |||
| return nullptr; | |||
| } | |||
| int ArithmeticCompareCPUKernel::Init() { | |||
| if (!InferShapeDone()) { | |||
| return RET_OK; | |||
| for (int i = 0; i < arithmeticParameter_->out_shape_[dim]; ++i) { | |||
| int pos0_ = arithmeticParameter_->in_shape0_[dim] == 1 ? 0 : i; | |||
| int pos1_ = arithmeticParameter_->in_shape1_[dim] == 1 ? 0 : i; | |||
| int error_code; | |||
| if (data_type_ == kDataTypeInt) { | |||
| error_code = BroadcastRun(reinterpret_cast<int *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], | |||
| reinterpret_cast<int *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], | |||
| reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim], | |||
| dim + 1, out_count, out_thread_stride); | |||
| } else { | |||
| error_code = BroadcastRun(reinterpret_cast<float *>(input0) + pos0_ * arithmeticParameter_->in_strides0_[dim], | |||
| reinterpret_cast<float *>(input1) + pos1_ * arithmeticParameter_->in_strides1_[dim], | |||
| reinterpret_cast<uint8_t *>(output) + i * arithmeticParameter_->out_strides_[dim], | |||
| dim + 1, out_count, out_thread_stride); | |||
| } | |||
| if (error_code != RET_OK) { | |||
| return error_code; | |||
| } | |||
| } | |||
| return ReSize(); | |||
| return RET_OK; | |||
| } | |||
| int ArithmeticCompareCPUKernel::ReSize() { return RET_OK; } | |||
| int ArithmeticCompareCPUKernel::DoArithmetic(int task_id) { | |||
| auto element_num = out_tensors_[0]->ElementsNum(); | |||
| int ArithmeticCompareCPUKernel::DoExecute(int task_id) { | |||
| if (in_tensors_.at(0)->shape() != in_tensors_.at(1)->shape()) { | |||
| MS_LOG(ERROR) << "Compare op must inputs have the same shape, support broadcast later! "; | |||
| return RET_ERROR; | |||
| } | |||
| int elements_num = in_tensors_.at(0)->ElementsNum(); | |||
| int stride = UP_DIV(elements_num, op_parameter_->thread_num_); | |||
| int offset = task_id * stride; | |||
| int count = MSMIN(stride, elements_num - offset); | |||
| if (count <= 0) { | |||
| return RET_OK; | |||
| } | |||
| if (func_ == nullptr) { | |||
| MS_LOG(ERROR) << "Run function is null! "; | |||
| MS_ASSERT(thread_count_ != 0); | |||
| int stride = UP_DIV(element_num, thread_count_); | |||
| int count = MSMIN(stride, element_num - stride * task_id); | |||
| if (func_fp32_ == nullptr) { | |||
| MS_LOG(ERROR) << "func_fp32_ function is nullptr!"; | |||
| return RET_ERROR; | |||
| } | |||
| // two inputs have the same shape, support broadcast later | |||
| auto *input0_ptr = reinterpret_cast<float *>(in_tensors_.at(0)->MutableData()); | |||
| auto *input1_ptr = reinterpret_cast<float *>(in_tensors_.at(1)->MutableData()); | |||
| auto *output_ptr = reinterpret_cast<uint8_t *>(out_tensors_.at(0)->MutableData()); | |||
| auto ret = func_(input0_ptr + offset, input1_ptr + offset, output_ptr + offset, count); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "Run failed, illegal input! "; | |||
| } | |||
| return ret; | |||
| } | |||
| int ArithmeticCompareRun(void *cdata, int task_id) { | |||
| auto kernel = reinterpret_cast<ArithmeticCompareCPUKernel *>(cdata); | |||
| auto ret = kernel->DoExecute(task_id); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticSelfRuns error task_id[" << task_id << "] error_code[" << ret << "]"; | |||
| int error_code; | |||
| if (arithmeticParameter_->broadcasting_) { // need broadcast | |||
| stride = UP_DIV(outside_, thread_count_); | |||
| int out_count = MSMIN(stride, outside_ - stride * task_id); | |||
| int out_thread_stride = stride * task_id; | |||
| if (data_type_ == kDataTypeFloat) { | |||
| error_code = BroadcastRun( | |||
| reinterpret_cast<float *>(in_tensors_[0]->data_c()), reinterpret_cast<float *>(in_tensors_[1]->data_c()), | |||
| reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); | |||
| } else { | |||
| error_code = BroadcastRun( | |||
| reinterpret_cast<int *>(in_tensors_[0]->data_c()), reinterpret_cast<int *>(in_tensors_[1]->data_c()), | |||
| reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()), 0, out_count, out_thread_stride); | |||
| } | |||
| } else { // no broadcast, neither is scalar, two same shape | |||
| if (data_type_ == kDataTypeFloat) { | |||
| error_code = func_fp32_(reinterpret_cast<float *>(in_tensors_[0]->data_c()) + stride * task_id, | |||
| reinterpret_cast<float *>(in_tensors_[1]->data_c()) + stride * task_id, | |||
| reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count); | |||
| } else { | |||
| error_code = func_int32_(reinterpret_cast<int *>(in_tensors_[0]->data_c()) + stride * task_id, | |||
| reinterpret_cast<int *>(in_tensors_[1]->data_c()) + stride * task_id, | |||
| reinterpret_cast<uint8_t *>(out_tensors_[0]->data_c()) + stride * task_id, count); | |||
| } | |||
| } | |||
| return ret; | |||
| } | |||
| int ArithmeticCompareCPUKernel::Run() { | |||
| auto ret = ParallelLaunch(this->context_->thread_pool_, ArithmeticCompareRun, this, op_parameter_->thread_num_); | |||
| if (ret != RET_OK) { | |||
| MS_LOG(ERROR) << "ArithmeticSelfRun error error_code[" << ret << "]"; | |||
| if (error_code != RET_OK) { | |||
| return RET_ERROR; | |||
| } | |||
| return ret; | |||
| return RET_OK; | |||
| } | |||
| kernel::LiteKernel *CpuArithmeticCompareFp32KernelCreator(const std::vector<lite::Tensor *> &inputs, | |||
| @@ -18,27 +18,57 @@ | |||
| #include <vector> | |||
| #include "src/runtime/kernel/arm/fp32/arithmetic_fp32.h" | |||
| #include "nnacl/fp32/arithmetic_compare_fp32.h" | |||
| namespace mindspore::kernel { | |||
| typedef int (*ArithmeticCompareFp32Func)(const float *input0, const float *input1, uint8_t *output, int element_size); | |||
| typedef int (*ArithmeticCompareIntFunc)(const int *input0, const int *input1, uint8_t *output, int element_size); | |||
| class ArithmeticCompareCPUKernel : public ArithmeticCPUKernel { | |||
| public: | |||
| explicit ArithmeticCompareCPUKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs, | |||
| const std::vector<lite::Tensor *> &outputs, const lite::InnerContext *ctx, | |||
| const mindspore::lite::PrimitiveC *primitive) | |||
| : ArithmeticCPUKernel(parameter, inputs, outputs, ctx, primitive) { | |||
| func_ = GetArithmeticCompareFun(parameter->type_); | |||
| switch (parameter->type_) { | |||
| case PrimitiveType_Equal: | |||
| func_fp32_ = ElementEqualFp32; | |||
| func_int32_ = ElementEqualInt32; | |||
| break; | |||
| case PrimitiveType_NotEqual: | |||
| func_fp32_ = ElementNotEqualFp32; | |||
| func_int32_ = ElementNotEqualInt32; | |||
| break; | |||
| case PrimitiveType_Less: | |||
| func_fp32_ = ElementLessFp32; | |||
| func_int32_ = ElementLessInt32; | |||
| break; | |||
| case PrimitiveType_LessEqual: | |||
| func_fp32_ = ElementLessEqualFp32; | |||
| func_int32_ = ElementLessEqualInt32; | |||
| break; | |||
| case PrimitiveType_Greater: | |||
| func_fp32_ = ElementGreaterFp32; | |||
| func_int32_ = ElementGreaterInt32; | |||
| break; | |||
| case PrimitiveType_GreaterEqual: | |||
| func_fp32_ = ElementGreaterEqualFp32; | |||
| func_int32_ = ElementGreaterEqualInt32; | |||
| break; | |||
| default: | |||
| MS_LOG(ERROR) << "Error Operator type " << parameter->type_; | |||
| func_fp32_ = nullptr; | |||
| func_int32_ = nullptr; | |||
| break; | |||
| } | |||
| } | |||
| ~ArithmeticCompareCPUKernel() override = default; | |||
| int Init() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| virtual int DoExecute(int task_id); | |||
| int DoArithmetic(int task_id) override; | |||
| int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride) override; | |||
| private: | |||
| ArithmeticCompareFp32Func GetArithmeticCompareFun(int primitive_type); | |||
| ArithmeticCompareFp32Func func_; | |||
| ArithmeticCompareFp32Func func_fp32_ = nullptr; | |||
| ArithmeticCompareIntFunc func_int32_ = nullptr; | |||
| }; | |||
| int ArithmeticCompareRun(void *cdata, int task_id); | |||
| } // namespace mindspore::kernel | |||
| @@ -175,6 +175,15 @@ int ArithmeticCPUKernel::ReSize() { | |||
| break; | |||
| } | |||
| break; | |||
| case PrimitiveType_Equal: | |||
| case PrimitiveType_Less: | |||
| case PrimitiveType_Greater: | |||
| case PrimitiveType_NotEqual: | |||
| case PrimitiveType_LessEqual: | |||
| case PrimitiveType_GreaterEqual: | |||
| arithmetic_opt_run_ = nullptr; | |||
| arithmetic_opt_run_int_ = nullptr; | |||
| break; | |||
| default: | |||
| break; | |||
| } | |||
| @@ -167,19 +167,21 @@ class ArithmeticCPUKernel : public LiteKernel { | |||
| int PreProcess() override; | |||
| int ReSize() override; | |||
| int Run() override; | |||
| int DoArithmetic(int task_id); | |||
| virtual int DoArithmetic(int task_id); | |||
| virtual int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride); | |||
| private: | |||
| int BroadcastRun(void *input0, void *input1, void *output, int dim, int out_count, int out_thread_stride); | |||
| protected: | |||
| int break_pos_ = 0; | |||
| int outside_ = 0; | |||
| int thread_count_ = 1; | |||
| ArithmeticParameter *arithmeticParameter_ = nullptr; | |||
| LiteDataType data_type_ = kDataTypeFloat; | |||
| private: | |||
| ArithmeticRun arithmetic_run_ = nullptr; | |||
| ArithmeticOptRun arithmetic_opt_run_ = nullptr; | |||
| ArithmeticIntRun arithmetic_run_int_ = nullptr; | |||
| ArithmeticOptIntRun arithmetic_opt_run_int_ = nullptr; | |||
| LiteDataType data_type_ = kDataTypeFloat; | |||
| }; | |||
| } // namespace mindspore::kernel | |||
| #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_FP32_ARITHMETIC_H_ | |||
| @@ -146,4 +146,5 @@ REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Sqrt, CpuArithmeticSelfInt8Kerne | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Rsqrt, CpuArithmeticSelfInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_Square, CpuArithmeticSelfInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeInt8, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) | |||
| REG_KERNEL(kCPU, kNumberTypeBool, PrimitiveType_LogicalNot, CpuArithmeticSelfInt8KernelCreator) | |||
| } // namespace mindspore::kernel | |||
| @@ -135,17 +135,17 @@ STATUS OnnxConvParser::Parse(const onnx::GraphProto &onnx_graph, const onnx::Nod | |||
| std::find_if(onnx_graph.initializer().begin(), onnx_graph.initializer().end(), | |||
| [onnx_conv_weight](const onnx::TensorProto &proto) { return proto.name() == onnx_conv_weight; }); | |||
| if (nodeIter == onnx_graph.initializer().end()) { | |||
| MS_LOG(ERROR) << "not find node: " << onnx_conv_weight; | |||
| return RET_ERROR; | |||
| } | |||
| std::vector<int> weight_shape; | |||
| auto size = (*nodeIter).dims_size(); | |||
| weight_shape.reserve(size); | |||
| for (int i = 0; i < size; ++i) { | |||
| weight_shape.emplace_back((*nodeIter).dims(i)); | |||
| MS_LOG(WARNING) << "not find node: " << onnx_conv_weight; | |||
| } else { | |||
| std::vector<int> weight_shape; | |||
| auto size = (*nodeIter).dims_size(); | |||
| weight_shape.reserve(size); | |||
| for (int i = 0; i < size; ++i) { | |||
| weight_shape.emplace_back((*nodeIter).dims(i)); | |||
| } | |||
| attr->channelOut = weight_shape[0]; | |||
| attr->channelIn = weight_shape[1] * attr->group; | |||
| } | |||
| attr->channelOut = weight_shape[0]; | |||
| attr->channelIn = weight_shape[1] * attr->group; | |||
| } else { | |||
| auto nodeIter = | |||
| std::find_if(onnx_graph.node().begin(), onnx_graph.node().end(), | |||
| @@ -231,15 +231,6 @@ const AnfNodePtr ConstFoldPass::Process(const FuncGraphPtr &func_graph, const An | |||
| output_tensors[m]->AddQuantParam(quant_arg); | |||
| } | |||
| } | |||
| // here, input_tensor's format need to be transposed nhwc according to fmkType, | |||
| // but for the time being, we only transpose the tensor with 0/1/2/3D. | |||
| // Others should be added in future. | |||
| for (auto &input_tensor : input_tensors) { | |||
| input_tensor->set_format(schema::Format::Format_NHWC); | |||
| if (input_tensor->shape().size() == 4) { | |||
| MS_LOG(INFO) << "init input_tensor format to nhwc"; | |||
| } | |||
| } | |||
| lite_primitive->InferShape(input_tensors, output_tensors); | |||
| auto primitive = lite_primitive.get(); | |||
| MS_ASSERT(primitive != nullptr); | |||