|
|
|
@@ -167,6 +167,32 @@ static void solve_batch_index_forward(Operand* operand) |
|
|
|
solve_batch_index_backward(r); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (op->type == "torch.transpose") |
|
|
|
{ |
|
|
|
const int dim0 = op->params.at("dim0").i; |
|
|
|
const int dim1 = op->params.at("dim1").i; |
|
|
|
|
|
|
|
int batch_index_transposed = batch_index; |
|
|
|
if (dim0 == batch_index) |
|
|
|
{ |
|
|
|
batch_index_transposed = dim1; |
|
|
|
} |
|
|
|
else if (dim1 == batch_index) |
|
|
|
{ |
|
|
|
batch_index_transposed = dim0; |
|
|
|
} |
|
|
|
|
|
|
|
for (Operand* r : op->outputs) |
|
|
|
{ |
|
|
|
if (r->params.find("__batch_index") != r->params.end()) |
|
|
|
continue; |
|
|
|
|
|
|
|
r->params["__batch_index"] = batch_index_transposed; |
|
|
|
|
|
|
|
solve_batch_index_forward(r); |
|
|
|
solve_batch_index_backward(r); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") |
|
|
|
{ |
|
|
|
if (op->params.find("shape") == op->params.end()) |
|
|
|
@@ -241,6 +267,32 @@ static void solve_batch_index_backward(Operand* operand) |
|
|
|
solve_batch_index_forward(r); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (op->type == "torch.transpose") |
|
|
|
{ |
|
|
|
const int dim0 = op->params.at("dim0").i; |
|
|
|
const int dim1 = op->params.at("dim1").i; |
|
|
|
|
|
|
|
int batch_index_transposed = batch_index; |
|
|
|
if (dim0 == batch_index) |
|
|
|
{ |
|
|
|
batch_index_transposed = dim1; |
|
|
|
} |
|
|
|
else if (dim1 == batch_index) |
|
|
|
{ |
|
|
|
batch_index_transposed = dim0; |
|
|
|
} |
|
|
|
|
|
|
|
for (Operand* r : op->inputs) |
|
|
|
{ |
|
|
|
if (r->params.find("__batch_index") != r->params.end()) |
|
|
|
continue; |
|
|
|
|
|
|
|
r->params["__batch_index"] = batch_index_transposed; |
|
|
|
|
|
|
|
solve_batch_index_backward(r); |
|
|
|
solve_batch_index_forward(r); |
|
|
|
} |
|
|
|
} |
|
|
|
else if (op->type == "Tensor.reshape" || op->type == "Tensor.view") |
|
|
|
{ |
|
|
|
if (op->params.find("shape") == op->params.end()) |
|
|
|
|