| @@ -1,42 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/gelu.h" | |||||
| #include <memory> | |||||
| #include "include/errorcode.h" | |||||
| #include "src/common/log_adapter.h" | |||||
| #include "src/tensor.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| // int GeLU::GetApproximate() const { return this->primitive_->value.AsGeLU()->approximate; } | |||||
| // void GeLU::SetApproximate(bool approximate) { this->primitive_->value.AsGeLU()->approximate = approximate; } | |||||
| #else | |||||
| // int GeLU::GetApproximate() const { return this->primitive_->value_as_GeLU()->approximate(); } | |||||
| // PrimitiveC *GeLUCreator(const schema::Primitive *primitive) { return PrimitiveC::NewPrimitiveC<GeLU>(primitive); } | |||||
| // Registry GeLURegistry(schema::PrimitiveType_GeLU, GeLUCreator); | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,45 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class GeLU : public PrimitiveC { | |||||
| public: | |||||
| GeLU() = default; | |||||
| ~GeLU() = default; | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(GeLU, PrimitiveC); | |||||
| explicit GeLU(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| // int UnPackAttr(const Primitive &prim, const std::vector<AnfNodePtr> &inputs) override; | |||||
| // void SetApproximate(bool approximate); | |||||
| #else | |||||
| // int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int InferShape(std::vector<lite::Tensor *> inputs_, std::vector<lite::Tensor *> outputs_) override; | |||||
| // bool GetApproximate() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_GELU_H_ | |||||
| @@ -1,66 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/matrix_diag.h" | |||||
| #ifndef PRIMITIVE_WRITEABLE | |||||
| #include "src/ops/ops_register.h" | |||||
| #endif | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| int MatrixDiag::GetK() const { return this->primitive_->value.AsMatrixDiag()->k; } | |||||
| int MatrixDiag::GetNumRows() const { return this->primitive_->value.AsMatrixDiag()->numRows; } | |||||
| int MatrixDiag::GetNumCols() const { return this->primitive_->value.AsMatrixDiag()->numCols; } | |||||
| float MatrixDiag::GetPaddingValue() const { return this->primitive_->value.AsMatrixDiag()->paddingValue; } | |||||
| void MatrixDiag::SetK(int k) { this->primitive_->value.AsMatrixDiag()->k = k; } | |||||
| void MatrixDiag::SetNumRows(int num_rows) { this->primitive_->value.AsMatrixDiag()->numRows = num_rows; } | |||||
| void MatrixDiag::SetNumCols(int num_cols) { this->primitive_->value.AsMatrixDiag()->numCols = num_cols; } | |||||
| void MatrixDiag::SetPaddingValue(float padding_value) { | |||||
| this->primitive_->value.AsMatrixDiag()->paddingValue = padding_value; | |||||
| } | |||||
| #else | |||||
| int MatrixDiag::GetK() const { return this->primitive_->value_as_MatrixDiag()->k(); } | |||||
| int MatrixDiag::GetNumRows() const { return this->primitive_->value_as_MatrixDiag()->numRows(); } | |||||
| int MatrixDiag::GetNumCols() const { return this->primitive_->value_as_MatrixDiag()->numCols(); } | |||||
| float MatrixDiag::GetPaddingValue() const { return this->primitive_->value_as_MatrixDiag()->paddingValue(); } | |||||
| int MatrixDiag::UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) { | |||||
| MS_ASSERT(nullptr != primitive); | |||||
| MS_ASSERT(nullptr != fbb); | |||||
| auto attr = primitive->value_as_MatrixDiag(); | |||||
| if (attr == nullptr) { | |||||
| MS_LOG(ERROR) << "value_as_MatrixDiag return nullptr"; | |||||
| return RET_ERROR; | |||||
| } | |||||
| auto val_offset = schema::CreateMatrixDiag(*fbb, attr->k(), attr->numRows(), attr->numCols(), attr->paddingValue()); | |||||
| auto prim_offset = schema::CreatePrimitive(*fbb, schema::PrimitiveType_MatrixDiag, val_offset.o); | |||||
| fbb->Finish(prim_offset); | |||||
| return RET_OK; | |||||
| } | |||||
| PrimitiveC *MatrixDiagCreator(const schema::Primitive *primitive) { | |||||
| return PrimitiveC::NewPrimitiveC<MatrixDiag>(primitive); | |||||
| } | |||||
| Registry MatrixDiagRegistry(schema::PrimitiveType_MatrixDiag, MatrixDiagCreator); | |||||
| #endif | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| @@ -1,50 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ | |||||
| #define LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ | |||||
| #include <vector> | |||||
| #include <set> | |||||
| #include <cmath> | |||||
| #include "src/ops/primitive_c.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| class MatrixDiag : public PrimitiveC { | |||||
| public: | |||||
| MatrixDiag() = default; | |||||
| ~MatrixDiag() = default; | |||||
| #ifdef PRIMITIVE_WRITEABLE | |||||
| MS_DECLARE_PARENT(MatrixDiag, PrimitiveC); | |||||
| explicit MatrixDiag(schema::PrimitiveT *primitive) : PrimitiveC(primitive) {} | |||||
| void SetK(int k); | |||||
| void SetNumRows(int num_rows); | |||||
| void SetNumCols(int num_cols); | |||||
| void SetPaddingValue(float padding_value); | |||||
| #else | |||||
| int UnPackToFlatBuilder(const schema::Primitive *primitive, flatbuffers::FlatBufferBuilder *fbb) override; | |||||
| #endif | |||||
| int GetK() const; | |||||
| int GetNumRows() const; | |||||
| int GetNumCols() const; | |||||
| float GetPaddingValue() const; | |||||
| }; | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||
| #endif // LITE_MINDSPORE_LITE_C_OPS_MATRIX_DIAG_H_ | |||||
| @@ -1,40 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "src/ops/primitive_c.h" | |||||
| #include "src/ops/populate/populate_register.h" | |||||
| #include "src/ops/gelu.h" | |||||
| #include "nnacl/gelu_parameter.h" | |||||
| namespace mindspore { | |||||
| namespace lite { | |||||
| OpParameter *PopulateGeLUParameter(const mindspore::lite::PrimitiveC *primitive) { | |||||
| // const auto param = reinterpret_cast<mindspore::lite::GeLU *>(const_cast<mindspore::lite::PrimitiveC | |||||
| // *>(primitive)); | |||||
| GeLUParameter *gelu_param = reinterpret_cast<GeLUParameter *>(malloc(sizeof(GeLUParameter))); | |||||
| if (gelu_param == nullptr) { | |||||
| MS_LOG(ERROR) << "malloc GeLUParameter failed."; | |||||
| return nullptr; | |||||
| } | |||||
| memset(gelu_param, 0, sizeof(GeLUParameter)); | |||||
| gelu_param->op_parameter_.type_ = primitive->Type(); | |||||
| // gelu_param->approximate_ = param->GetApproximate(); | |||||
| return reinterpret_cast<OpParameter *>(gelu_param); | |||||
| } | |||||
| // Registry GeLUParameterRegistry(schema::PrimitiveType_GeLU, PopulateGeLUParameter); | |||||
| } // namespace lite | |||||
| } // namespace mindspore | |||||