"""Check that a (B, C, A) tensor survives a flatten-to-(B*A, C) round trip.

The tensor is permuted so dim 1 is last, then flattened into rows of length C.
It is then reshaped back and permuted to its original layout. The result must
equal the input exactly.

NOTE(review): the original variable names (``cls_score``) suggest per-class
scores laid out as (batch, num_classes, positions). That is an assumption
from the names only; nothing here depends on it.
"""
import torch


def flatten_scores(scores):
    """Move dim 1 to the end and flatten the rest: (B, C, A) -> (B*A, C).

    Row ``b * A + a`` of the result holds ``scores[b, :, a]``.
    """
    num_channels = scores.shape[1]
    # permute() only changes strides. reshape() copies when a view is
    # impossible. .contiguous() is then a no-op, and it guarantees a dense
    # result in the cases where reshape returned a view.
    return scores.permute(0, 2, 1).reshape(-1, num_channels).contiguous()


def unflatten_scores(flat, batch):
    """Invert :func:`flatten_scores`: (B*A, C) -> (B, C, A).

    ``batch`` is B. A is inferred from ``flat.shape[0] // batch``.
    """
    num_channels = flat.shape[1]
    return flat.reshape(batch, -1, num_channels).permute(0, 2, 1)


def main(*, batch=8, channels=11, length=9, device=None):
    """Run the round trip on random data, print each stage, return success.

    ``device`` defaults to CUDA when available and falls back to CPU. The
    original script required a GPU.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    original = torch.rand(batch, channels, length, device=device)
    print(original)
    flat = flatten_scores(original)
    print(flat)
    restored = unflatten_scores(flat, batch)
    print(restored)
    # A single boolean is easier to read than an element-wise B*C*A mask.
    same = torch.equal(original, restored)
    print(same)
    return same


if __name__ == "__main__":
    main()