| @@ -624,7 +624,7 @@ TVM_REGISTER_GLOBAL("BroadcastTo").set_body([](TVMArgs args, TVMRetValue *rv) { | |||
| } | |||
| }); | |||
| TVM_REGISTER_GLOBAL("BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { | |||
| TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { | |||
| CHECK_GE(args.size(), 2); | |||
| auto inputs = args[0].operator Array<NodeRef>(); | |||
| auto attrs = args[1].operator OpAttr(); | |||
| @@ -718,7 +718,7 @@ TVM_REGISTER_GLOBAL("BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { | |||
| }); | |||
| // only support fractal_zN: [ko mo mi ki] * [no ko ki ni] = [no mo mi ni] | |||
| TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) { | |||
| TVM_REGISTER_GLOBAL("aicore_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) { | |||
| CHECK_GE(args.size(), 2); | |||
| auto attrs = args[1].operator OpAttr(); | |||
| CHECK(attrs.count("transpose_a")); | |||
| @@ -743,7 +743,7 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) | |||
| auto left_shape = left_matrix->shape; | |||
| auto right_shape = right_matrix->shape; | |||
| CHECK_EQ(left_shape.size(), right_shape.size()); | |||
| CHECK_EQ(left_shape.size(), 4); | |||
| CHECK_GE(left_shape.size(), 4); | |||
| auto type_checker = [](const Tensor &input_data, const std::string name, const air::DataType type) { | |||
| if (input_data->dtype != type) { | |||
| @@ -757,26 +757,33 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) | |||
| Array<Expr> output_shape; | |||
| Array<Expr> k; | |||
| auto compute_mnk = [&output_shape, &k, &left_shape, &right_shape, transpose_a, transpose_b]() { | |||
| size_t dim = left_shape.size(); | |||
| Expr mo, mi, no, ni, ko, ki; | |||
| if (transpose_a) { | |||
| mo = left_shape[0]; | |||
| ko = left_shape[1]; | |||
| ki = left_shape[2]; | |||
| mi = left_shape[3]; | |||
| mo = left_shape[dim - 4]; | |||
| ko = left_shape[dim - 3]; | |||
| ki = left_shape[dim - 2]; | |||
| mi = left_shape[dim - 1]; | |||
| } else { | |||
| ko = left_shape[0]; | |||
| mo = left_shape[1]; | |||
| mi = left_shape[2]; | |||
| ki = left_shape[3]; | |||
| ko = left_shape[dim - 4]; | |||
| mo = left_shape[dim - 3]; | |||
| mi = left_shape[dim - 2]; | |||
| ki = left_shape[dim - 1]; | |||
| } | |||
| if (transpose_b) { | |||
| no = right_shape[1]; | |||
| ni = right_shape[2]; | |||
| no = right_shape[dim - 3]; | |||
| ni = right_shape[dim - 2]; | |||
| } else { | |||
| no = right_shape[0]; | |||
| ni = right_shape[3]; | |||
| no = right_shape[dim - 4]; | |||
| ni = right_shape[dim - 1]; | |||
| } | |||
| output_shape = {no, mo, mi, ni}; | |||
| for (size_t i = 0; i < dim - 4; ++i) { | |||
| output_shape.push_back(left_shape[i]); | |||
| } | |||
| output_shape.push_back(no); | |||
| output_shape.push_back(mo); | |||
| output_shape.push_back(mi); | |||
| output_shape.push_back(ni); | |||
| k = {ko, ki}; | |||
| }; | |||
| @@ -795,22 +802,47 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) | |||
| IterVar reduce_ki = air::reduce_axis(Range(0, k[1]), "ki"); | |||
| Array<IterVar> reduces = {reduce_ko, reduce_ki}; | |||
| auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduces, &output_shape, | |||
| auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduces, | |||
| &Mmad](const Array<Var> &indices) { | |||
| Array<Expr> left_indice = {reduces[0], indices[1], indices[2], reduces[1]}; | |||
| Array<Expr> right_indice = {indices[0], reduces[0], reduces[1], indices[3]}; | |||
| size_t dim = indices.size(); | |||
| Array<Expr> left_indice; | |||
| for (size_t i = 0; i < dim - 4; ++i) { | |||
| left_indice.push_back(indices[i]); | |||
| } | |||
| if (transpose_a) { | |||
| left_indice = {indices[1], reduces[0], reduces[1], indices[2]}; | |||
| left_indice.push_back(indices[dim - 3]); | |||
| left_indice.push_back(reduces[0]); | |||
| left_indice.push_back(reduces[1]); | |||
| left_indice.push_back(indices[dim - 2]); | |||
| } else { | |||
| left_indice.push_back(reduces[0]); | |||
| left_indice.push_back(indices[dim - 3]); | |||
| left_indice.push_back(indices[dim - 2]); | |||
| left_indice.push_back(reduces[1]); | |||
| } | |||
| Array<Expr> right_indice; | |||
| for (size_t i = 0; i < dim - 4; ++i) { | |||
| right_indice.push_back(indices[i]); | |||
| } | |||
| if (transpose_b) { | |||
| right_indice = {reduces[0], indices[0], indices[3], reduces[1]}; | |||
| right_indice.push_back(reduces[0]); | |||
| right_indice.push_back(indices[dim - 4]); | |||
| right_indice.push_back(indices[dim - 1]); | |||
| right_indice.push_back(reduces[1]); | |||
| } else { | |||
| right_indice.push_back(indices[dim - 4]); | |||
| right_indice.push_back(reduces[0]); | |||
| right_indice.push_back(reduces[1]); | |||
| right_indice.push_back(indices[dim - 1]); | |||
| } | |||
| Expr res = Mmad(Cast::make(Float(32), left_matrix(left_indice) * right_matrix(right_indice)), reduces); | |||
| return res; | |||
| }; | |||
| // set output name | |||
| auto name = "T_matmul_" + left_matrix->op->name + "_" + right_matrix->op->name; | |||
| auto name = "T_batchmatmul_" + left_matrix->op->name + "_" + right_matrix->op->name; | |||
| // set compute attrs | |||
| auto set_compute_attrs_zN = [&left_matrix, &right_matrix, &inputs, transpose_a, transpose_b, attrs]() { | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "composite/optimize/optimize.h" | |||
| #include <memory> | |||
| #include "composite/optimize/rename_matmul.h" | |||
| #include "composite/optimize/reshape_tensor.h" | |||
| #include "composite/optimize/elim_transform_op.h" | |||
| #include "composite/optimize/inplace_assign_mutator.h" | |||
| @@ -51,6 +52,8 @@ Stmt Optimize(Stmt &s, BuildInfo &info) { | |||
| if (info.opt.target == "aicore") { | |||
| pm.RegisterPass(std::make_shared<TypeCastInserter>()); | |||
| } | |||
| // rename MatMul to BatchMatMul | |||
| pm.RegisterPass(std::make_shared<RenameMatmul>()); | |||
| s = pm.Run(s); | |||
| return s; | |||
| } | |||
| @@ -0,0 +1,39 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "composite/optimize/rename_matmul.h" | |||
| namespace akg { | |||
| // rename MatMul to BatchMatMul | |||
| // IRMutator that rewrites every Provide whose value is a Call named | |||
| // "MatMul" into an identical Provide whose Call is named "BatchMatMul". | |||
| // Call args, value dtype, and the Provide's func/args are preserved. | |||
| class RenameMatmulMutator : public IRMutator { | |||
| public: | |||
| explicit RenameMatmulMutator() {} | |||
| ~RenameMatmulMutator() override = default; | |||
| // Only "MatMul" calls are rewritten; everything else falls through to the | |||
| // default mutation so the rest of the statement tree is still traversed. | |||
| Stmt Mutate_(const Provide *op, const Stmt &s) { | |||
| auto call = op->value.as<Call>(); | |||
| if (call == nullptr || call->name != "MatMul") { | |||
| return IRMutator::Mutate_(op, s); | |||
| } | |||
| // value_index is hard-coded to 0 — assumes a single-output Provide. | |||
| // NOTE(review): confirm a multi-output MatMul never reaches this pass. | |||
| return Provide::make(op->func, 0, | |||
| Call::make(op->value.type(), "BatchMatMul", call->args, Call::CallType::PureIntrinsic), | |||
| op->args); | |||
| } | |||
| }; | |||
| // Pass entry point: applies RenameMatmulMutator to the whole statement. | |||
| Stmt RenameMatmul::Run(const Stmt &s) { | |||
| return RenameMatmulMutator().Mutate(s); | |||
| } | |||
| } // namespace akg | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_ | |||
| #define COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_ | |||
| #include "composite/optimize/optimize.h" | |||
| namespace akg { | |||
| // Composite-optimize pass that renames "MatMul" calls to "BatchMatMul" | |||
| // (implementation in rename_matmul.cc). | |||
| class RenameMatmul : public CompositeOptPass { | |||
| public: | |||
| // pass_name_ is taken from __FUNCTION__, i.e. "RenameMatmul". | |||
| RenameMatmul() { pass_name_ = __FUNCTION__; } | |||
| ~RenameMatmul() = default; | |||
| Stmt Run(const Stmt &s) override; | |||
| }; | |||
| } // namespace akg | |||
| #endif // COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_ | |||
| @@ -52,7 +52,8 @@ class ReshapeTensorMutator : public IRMutator { | |||
| } | |||
| Stmt Mutate_(const Provide *op, const Stmt &s) { | |||
| static std::unordered_set<std::string> check_list = {"TensorAdd", "Add", "RealDiv", "Mul", "Minimum", "Maximum", "Sub"}; | |||
| static std::unordered_set<std::string> check_list = {"TensorAdd", "Add", "RealDiv", "Mul", | |||
| "Minimum", "Maximum", "Sub"}; | |||
| auto call = op->value.as<Call>(); | |||
| if (call == nullptr || check_list.find(call->name) == check_list.end()) { | |||
| return IRMutator::Mutate_(op, s); | |||
| @@ -212,7 +213,7 @@ class ReshapeTensorMutator : public IRMutator { | |||
| } | |||
| auto call = op->value.as<Call>(); | |||
| return Provide::make(op->func, 0, Call::make(op->value.type(), call->name, input, Call::CallType::PureIntrinsic), | |||
| op->args); | |||
| op->args); | |||
| } | |||
| Stmt ModifyAttrMap(const AttrStmt *op, const Stmt &stmt, const Map<std::string, NodeRef> &attr_map) { | |||
| @@ -252,9 +253,9 @@ class ReshapeTensorMutator : public IRMutator { | |||
| for (const auto &it : reshape_) { | |||
| auto arg = | |||
| Call::make(it.first->dtype, it.first->op->name, it.first->shape, Call::CallType::Halide, it.first->op); | |||
| auto reshape_stmt = Provide::make( | |||
| it.second->op, 0, Call::make(it.first->dtype, "Reshape", {arg}, Call::CallType::PureIntrinsic), | |||
| it.second->shape); | |||
| auto reshape_stmt = | |||
| Provide::make(it.second->op, 0, Call::make(it.first->dtype, "Reshape", {arg}, Call::CallType::PureIntrinsic), | |||
| it.second->shape); | |||
| Map<std::string, NodeRef> attrs; | |||
| attrs.Set("shape", it.second->shape); | |||
| auto reshape_attr = AttrStmt::make(attrs, "attrs", Expr(1), reshape_stmt); | |||
| @@ -353,12 +354,11 @@ class ReshapeTensorMutator : public IRMutator { | |||
| } | |||
| return std::make_tuple(shape_long, shape_tmp, shape_out); | |||
| } | |||
| }; | |||
| // When Matmul has DefaultFormat bias, reshape bias to FRACTAL_NZ format | |||
| // If the bias needs padding, pad it as | |||
| // input_2_reshape(1,1,1,16) = Reshape(input_2(2)):float16:PI | |||
| // input_2_reshape(1,1,1,16) = Reshape(input_2(2)):float16:PI | |||
| class ReshapeMatmul : public ReshapeTensorMutator { | |||
| public: | |||
| explicit ReshapeMatmul() {} | |||
| @@ -383,7 +383,7 @@ class ReshapeMatmul : public ReshapeTensorMutator { | |||
| } | |||
| Stmt Mutate_(const Provide *op, const Stmt &s) { | |||
| static std::unordered_set<std::string> check_list = {"MatMul"}; | |||
| static std::unordered_set<std::string> check_list = {"MatMul", "BatchMatMul"}; | |||
| auto call = op->value.as<Call>(); | |||
| if (call == nullptr || check_list.find(call->name) == check_list.end()) { | |||
| return IRMutator::Mutate_(op, s); | |||
| @@ -468,9 +468,9 @@ class ReshapeMatmul : public ReshapeTensorMutator { | |||
| return orig_shape; | |||
| } | |||
| Array<Expr> InferShapeToFractalNz(const Array<Expr> &shape0, const Array<Expr> &shape1, | |||
| const Array<Expr> &shape_out, const Array<Expr> &shape_fractal, | |||
| const std::string &op_name, const Array<Expr> &shape_default) override { | |||
| Array<Expr> InferShapeToFractalNz(const Array<Expr> &shape0, const Array<Expr> &shape1, const Array<Expr> &shape_out, | |||
| const Array<Expr> &shape_fractal, const std::string &op_name, | |||
| const Array<Expr> &shape_default) override { | |||
| auto dims = shape_out.size(); | |||
| auto batch = dims - 2; | |||
| Array<Expr> shape_new; | |||
| @@ -491,8 +491,8 @@ class ReshapeMatmul : public ReshapeTensorMutator { | |||
| shape_new.push_back(shape_fractal[shape_fractal.size() - 1]); | |||
| } | |||
| } else { | |||
| LOG(FATAL) << "[" << op_name << "] " << shape_fractal << " (FRACTAL_NZ) and " << shape_default | |||
| << " (DefaultFormat) may need data format transformation for "; | |||
| LOG(FATAL) << "[" << op_name << "] " << shape_fractal << " (FRACTAL_NZ) and " << shape_default | |||
| << " (DefaultFormat) may need data format transformation for "; | |||
| } | |||
| return shape_new; | |||
| } | |||
| @@ -512,9 +512,13 @@ class ReshapeMatmul : public ReshapeTensorMutator { | |||
| std::stack<bool> transpose_b; | |||
| void PadBias(Array<Expr> &shape_default) { | |||
| if (shape_default.size() != 1) { return; } | |||
| if (shape_default.size() != 1) { | |||
| return; | |||
| } | |||
| auto bias_length = (shape_default[0].as<IntImm>())->value; | |||
| if (bias_length % 16 == 0) { return; } | |||
| if (bias_length % 16 == 0) { | |||
| return; | |||
| } | |||
| int64_t pad_length = (bias_length / 16) * 16 + 16; | |||
| shape_default.Set(0, Expr(pad_length)); | |||
| LOG(INFO) << "Pad bias length from " << bias_length << " to " << pad_length; | |||