Browse Source

support BatchMatMul

pull/56/head
lingyunli63 5 years ago
parent
commit
509e12c346
5 changed files with 143 additions and 37 deletions
  1. +54
    -22
      src/composite/composite_topi.cc
  2. +3
    -0
      src/composite/optimize/optimize.cc
  3. +39
    -0
      src/composite/optimize/rename_matmul.cc
  4. +28
    -0
      src/composite/optimize/rename_matmul.h
  5. +19
    -15
      src/composite/optimize/reshape_tensor.cc

+ 54
- 22
src/composite/composite_topi.cc View File

@@ -624,7 +624,7 @@ TVM_REGISTER_GLOBAL("BroadcastTo").set_body([](TVMArgs args, TVMRetValue *rv) {
}
});

TVM_REGISTER_GLOBAL("BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) {
TVM_REGISTER_GLOBAL("cuda_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_GE(args.size(), 2);
auto inputs = args[0].operator Array<NodeRef>();
auto attrs = args[1].operator OpAttr();
@@ -718,7 +718,7 @@ TVM_REGISTER_GLOBAL("BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) {
});

// only support fractal_zN: [ko mo mi ki] * [no ko ki ni] = [no mo mi ni]
TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv) {
TVM_REGISTER_GLOBAL("aicore_BatchMatMul").set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_GE(args.size(), 2);
auto attrs = args[1].operator OpAttr();
CHECK(attrs.count("transpose_a"));
@@ -743,7 +743,7 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv)
auto left_shape = left_matrix->shape;
auto right_shape = right_matrix->shape;
CHECK_EQ(left_shape.size(), right_shape.size());
CHECK_EQ(left_shape.size(), 4);
CHECK_GE(left_shape.size(), 4);

auto type_checker = [](const Tensor &input_data, const std::string name, const air::DataType type) {
if (input_data->dtype != type) {
@@ -757,26 +757,33 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv)
Array<Expr> output_shape;
Array<Expr> k;
auto compute_mnk = [&output_shape, &k, &left_shape, &right_shape, transpose_a, transpose_b]() {
size_t dim = left_shape.size();
Expr mo, mi, no, ni, ko, ki;
if (transpose_a) {
mo = left_shape[0];
ko = left_shape[1];
ki = left_shape[2];
mi = left_shape[3];
mo = left_shape[dim - 4];
ko = left_shape[dim - 3];
ki = left_shape[dim - 2];
mi = left_shape[dim - 1];
} else {
ko = left_shape[0];
mo = left_shape[1];
mi = left_shape[2];
ki = left_shape[3];
ko = left_shape[dim - 4];
mo = left_shape[dim - 3];
mi = left_shape[dim - 2];
ki = left_shape[dim - 1];
}
if (transpose_b) {
no = right_shape[1];
ni = right_shape[2];
no = right_shape[dim - 3];
ni = right_shape[dim - 2];
} else {
no = right_shape[0];
ni = right_shape[3];
no = right_shape[dim - 4];
ni = right_shape[dim - 1];
}
output_shape = {no, mo, mi, ni};
for (size_t i = 0; i < dim - 4; ++i) {
output_shape.push_back(left_shape[i]);
}
output_shape.push_back(no);
output_shape.push_back(mo);
output_shape.push_back(mi);
output_shape.push_back(ni);
k = {ko, ki};
};

@@ -795,22 +802,47 @@ TVM_REGISTER_GLOBAL("aicore_MatMul").set_body([](TVMArgs args, TVMRetValue *rv)
IterVar reduce_ki = air::reduce_axis(Range(0, k[1]), "ki");
Array<IterVar> reduces = {reduce_ko, reduce_ki};

auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduces, &output_shape,
auto fcompute = [&left_matrix, &right_matrix, &transpose_a, &transpose_b, &reduces,
&Mmad](const Array<Var> &indices) {
Array<Expr> left_indice = {reduces[0], indices[1], indices[2], reduces[1]};
Array<Expr> right_indice = {indices[0], reduces[0], reduces[1], indices[3]};
size_t dim = indices.size();
Array<Expr> left_indice;
for (size_t i = 0; i < dim - 4; ++i) {
left_indice.push_back(indices[i]);
}
if (transpose_a) {
left_indice = {indices[1], reduces[0], reduces[1], indices[2]};
left_indice.push_back(indices[dim - 3]);
left_indice.push_back(reduces[0]);
left_indice.push_back(reduces[1]);
left_indice.push_back(indices[dim - 2]);
} else {
left_indice.push_back(reduces[0]);
left_indice.push_back(indices[dim - 3]);
left_indice.push_back(indices[dim - 2]);
left_indice.push_back(reduces[1]);
}

Array<Expr> right_indice;
for (size_t i = 0; i < dim - 4; ++i) {
right_indice.push_back(indices[i]);
}
if (transpose_b) {
right_indice = {reduces[0], indices[0], indices[3], reduces[1]};
right_indice.push_back(reduces[0]);
right_indice.push_back(indices[dim - 4]);
right_indice.push_back(indices[dim - 1]);
right_indice.push_back(reduces[1]);
} else {
right_indice.push_back(indices[dim - 4]);
right_indice.push_back(reduces[0]);
right_indice.push_back(reduces[1]);
right_indice.push_back(indices[dim - 1]);
}

Expr res = Mmad(Cast::make(Float(32), left_matrix(left_indice) * right_matrix(right_indice)), reduces);
return res;
};

// set output name
auto name = "T_matmul_" + left_matrix->op->name + "_" + right_matrix->op->name;
auto name = "T_batchmatmul_" + left_matrix->op->name + "_" + right_matrix->op->name;

// set compute attrs
auto set_compute_attrs_zN = [&left_matrix, &right_matrix, &inputs, transpose_a, transpose_b, attrs]() {


+ 3
- 0
src/composite/optimize/optimize.cc View File

@@ -15,6 +15,7 @@
*/
#include "composite/optimize/optimize.h"
#include <memory>
#include "composite/optimize/rename_matmul.h"
#include "composite/optimize/reshape_tensor.h"
#include "composite/optimize/elim_transform_op.h"
#include "composite/optimize/inplace_assign_mutator.h"
@@ -51,6 +52,8 @@ Stmt Optimize(Stmt &s, BuildInfo &info) {
if (info.opt.target == "aicore") {
pm.RegisterPass(std::make_shared<TypeCastInserter>());
}
// rename MatMul to BatchMatMul
pm.RegisterPass(std::make_shared<RenameMatmul>());
s = pm.Run(s);
return s;
}


+ 39
- 0
src/composite/optimize/rename_matmul.cc View File

@@ -0,0 +1,39 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "composite/optimize/rename_matmul.h"

namespace akg {
// rename MatMul to BatchMatMul
// IR pass helper: rewrites every Provide whose value is a call to "MatMul"
// into an equivalent call to "BatchMatMul", leaving all other statements
// untouched. Used so downstream passes/topi lookups only need to handle the
// BatchMatMul name.
class RenameMatmulMutator : public IRMutator {
 public:
  // Zero-argument constructor: `= default` instead of a meaningless
  // `explicit` empty body (explicit only matters for converting ctors).
  RenameMatmulMutator() = default;
  ~RenameMatmulMutator() override = default;

  // Mark the mutator hook `override` for consistency with the destructor and
  // so the compiler verifies it actually overrides IRMutator::Mutate_.
  Stmt Mutate_(const Provide *op, const Stmt &s) override {
    auto call = op->value.as<Call>();
    if (call == nullptr || call->name != "MatMul") {
      // Not a MatMul provide: keep traversing the rest of the IR unchanged.
      return IRMutator::Mutate_(op, s);
    }
    // Re-emit the same provide with only the call name replaced; the call
    // arguments, result type, destination func and store indices are kept.
    return Provide::make(op->func, 0,
                         Call::make(op->value.type(), "BatchMatMul", call->args, Call::CallType::PureIntrinsic),
                         op->args);
  }
};

Stmt RenameMatmul::Run(const Stmt &s) {
  // Walk the whole statement tree once, renaming MatMul -> BatchMatMul.
  RenameMatmulMutator mutator;
  return mutator.Mutate(s);
}
} // namespace akg

+ 28
- 0
src/composite/optimize/rename_matmul.h View File

@@ -0,0 +1,28 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_
#define COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_
#include "composite/optimize/optimize.h"

namespace akg {
// Composite-optimization pass declaration: renames "MatMul" calls in the IR
// to "BatchMatMul". Implementation lives in rename_matmul.cc.
class RenameMatmul : public CompositeOptPass {
 public:
  // __FUNCTION__ inside the constructor yields "RenameMatmul", which becomes
  // the pass name reported by the pass manager.
  RenameMatmul() { pass_name_ = __FUNCTION__; }
  ~RenameMatmul() = default;
  // Entry point invoked by the pass manager; returns the rewritten statement.
  Stmt Run(const Stmt &s) override;
};
} // namespace akg
#endif // COMPOSITE_OPTIMIZE_RENAME_MATMUL_H_

+ 19
- 15
src/composite/optimize/reshape_tensor.cc View File

@@ -52,7 +52,8 @@ class ReshapeTensorMutator : public IRMutator {
}

Stmt Mutate_(const Provide *op, const Stmt &s) {
static std::unordered_set<std::string> check_list = {"TensorAdd", "Add", "RealDiv", "Mul", "Minimum", "Maximum", "Sub"};
static std::unordered_set<std::string> check_list = {"TensorAdd", "Add", "RealDiv", "Mul",
"Minimum", "Maximum", "Sub"};
auto call = op->value.as<Call>();
if (call == nullptr || check_list.find(call->name) == check_list.end()) {
return IRMutator::Mutate_(op, s);
@@ -212,7 +213,7 @@ class ReshapeTensorMutator : public IRMutator {
}
auto call = op->value.as<Call>();
return Provide::make(op->func, 0, Call::make(op->value.type(), call->name, input, Call::CallType::PureIntrinsic),
op->args);
op->args);
}

Stmt ModifyAttrMap(const AttrStmt *op, const Stmt &stmt, const Map<std::string, NodeRef> &attr_map) {
@@ -252,9 +253,9 @@ class ReshapeTensorMutator : public IRMutator {
for (const auto &it : reshape_) {
auto arg =
Call::make(it.first->dtype, it.first->op->name, it.first->shape, Call::CallType::Halide, it.first->op);
auto reshape_stmt = Provide::make(
it.second->op, 0, Call::make(it.first->dtype, "Reshape", {arg}, Call::CallType::PureIntrinsic),
it.second->shape);
auto reshape_stmt =
Provide::make(it.second->op, 0, Call::make(it.first->dtype, "Reshape", {arg}, Call::CallType::PureIntrinsic),
it.second->shape);
Map<std::string, NodeRef> attrs;
attrs.Set("shape", it.second->shape);
auto reshape_attr = AttrStmt::make(attrs, "attrs", Expr(1), reshape_stmt);
@@ -353,12 +354,11 @@ class ReshapeTensorMutator : public IRMutator {
}
return std::make_tuple(shape_long, shape_tmp, shape_out);
}

};

// When Matmul has DefaultFormat bias, reshape bias to FRACTAL_NZ format
// If bias need pad, do pad as
// input_2_reshape(1,1,1,16) = Reshape(input_2(2)):float16:PI
// input_2_reshape(1,1,1,16) = Reshape(input_2(2)):float16:PI
class ReshapeMatmul : public ReshapeTensorMutator {
public:
explicit ReshapeMatmul() {}
@@ -383,7 +383,7 @@ class ReshapeMatmul : public ReshapeTensorMutator {
}

Stmt Mutate_(const Provide *op, const Stmt &s) {
static std::unordered_set<std::string> check_list = {"MatMul"};
static std::unordered_set<std::string> check_list = {"MatMul", "BatchMatMul"};
auto call = op->value.as<Call>();
if (call == nullptr || check_list.find(call->name) == check_list.end()) {
return IRMutator::Mutate_(op, s);
@@ -468,9 +468,9 @@ class ReshapeMatmul : public ReshapeTensorMutator {
return orig_shape;
}

Array<Expr> InferShapeToFractalNz(const Array<Expr> &shape0, const Array<Expr> &shape1,
const Array<Expr> &shape_out, const Array<Expr> &shape_fractal,
const std::string &op_name, const Array<Expr> &shape_default) override {
Array<Expr> InferShapeToFractalNz(const Array<Expr> &shape0, const Array<Expr> &shape1, const Array<Expr> &shape_out,
const Array<Expr> &shape_fractal, const std::string &op_name,
const Array<Expr> &shape_default) override {
auto dims = shape_out.size();
auto batch = dims - 2;
Array<Expr> shape_new;
@@ -491,8 +491,8 @@ class ReshapeMatmul : public ReshapeTensorMutator {
shape_new.push_back(shape_fractal[shape_fractal.size() - 1]);
}
} else {
LOG(FATAL) << "[" << op_name << "] " << shape_fractal << " (FRACTAL_NZ) and " << shape_default
<< " (DefaultFormat) may need data format transformation for ";
LOG(FATAL) << "[" << op_name << "] " << shape_fractal << " (FRACTAL_NZ) and " << shape_default
<< " (DefaultFormat) may need data format transformation for ";
}
return shape_new;
}
@@ -512,9 +512,13 @@ class ReshapeMatmul : public ReshapeTensorMutator {
std::stack<bool> transpose_b;

void PadBias(Array<Expr> &shape_default) {
if (shape_default.size() != 1) { return; }
if (shape_default.size() != 1) {
return;
}
auto bias_length = (shape_default[0].as<IntImm>())->value;
if (bias_length % 16 == 0) { return; }
if (bias_length % 16 == 0) {
return;
}
int64_t pad_length = (bias_length / 16) * 16 + 16;
shape_default.Set(0, Expr(pad_length));
LOG(INFO) << "Pad bias length from " << bias_length << " to " << pad_length;


Loading…
Cancel
Save