|
|
|
@@ -23,9 +23,6 @@ from mindspore.common.tensor import Tensor |
|
|
|
from mindspore.common.parameter import Parameter |
|
|
|
|
|
|
|
|
|
|
|
# from mindspore.ops.primitive import constexpr |
|
|
|
# import IPython |
|
|
|
|
|
|
|
class MultiheadAttention(nn.Cell): |
|
|
|
""" |
|
|
|
Apply multi-headed attention from "from_tensor" to "to_tensor". |
|
|
|
@@ -85,7 +82,7 @@ class MultiheadAttention(nn.Cell): |
|
|
|
self.shape_q_2d = (-1, q_tensor_width) |
|
|
|
self.shape_k_2d = (-1, k_tensor_width) |
|
|
|
self.shape_v_2d = (-1, v_tensor_width) |
|
|
|
self.hidden_width = hidden_width |
|
|
|
self.hidden_width = int(hidden_width) |
|
|
|
# units = num_attention_heads * self.size_per_head |
|
|
|
if self.same_dim: |
|
|
|
self.in_proj_layer = \ |
|
|
|
@@ -132,46 +129,49 @@ class MultiheadAttention(nn.Cell): |
|
|
|
self.softmax_cast = P.Cast() |
|
|
|
self.matmul_dense = P.MatMul(transpose_b=True) |
|
|
|
self.split = P.Split(0, 3) |
|
|
|
self.equal = P.Equal() |
|
|
|
self.shape = P.Shape() |
|
|
|
|
|
|
|
def construct(self, tensor_q, tensor_k, tensor_v, batch_size, seq_length, attention_mask=None): |
|
|
|
def construct(self, tensor_q, tensor_k, tensor_v, attention_mask=None): |
|
|
|
"""Apply multihead attention.""" |
|
|
|
self.batch_size = batch_size |
|
|
|
shape_qkv = (self.batch_size, -1, |
|
|
|
batch_size, seq_length, _ = self.shape(tensor_q) |
|
|
|
shape_qkv = (batch_size, -1, |
|
|
|
self.num_attention_heads, self.size_per_head) |
|
|
|
shape_linear = (self.batch_size * seq_length, |
|
|
|
shape_linear = (batch_size * seq_length, |
|
|
|
self.num_attention_heads * self.size_per_head) |
|
|
|
if self.do_return_2d_tensor: |
|
|
|
shape_return = (self.batch_size * seq_length, |
|
|
|
if self.do_return_2d_tensor is True: |
|
|
|
shape_return = (batch_size * seq_length, |
|
|
|
self.num_attention_heads * self.size_per_head) |
|
|
|
if seq_length == -1: |
|
|
|
shape_return = (-1, self.num_attention_heads * |
|
|
|
self.size_per_head) |
|
|
|
else: |
|
|
|
shape_return = (self.batch_size, seq_length, |
|
|
|
shape_return = (batch_size, seq_length, |
|
|
|
self.num_attention_heads * self.size_per_head) |
|
|
|
|
|
|
|
tensor_q_2d = self.reshape(tensor_q, self.shape_q_2d) |
|
|
|
tensor_k_2d = self.reshape(tensor_k, self.shape_k_2d) |
|
|
|
tensor_v_2d = self.reshape(tensor_v, self.shape_v_2d) |
|
|
|
|
|
|
|
if P.Equal()(tensor_q_2d, tensor_v_2d)[0][0]: |
|
|
|
if self.equal(tensor_q_2d, tensor_v_2d) is True: |
|
|
|
x = self.matmul_dense(self.in_proj_layer, tensor_q_2d) |
|
|
|
query_out, key_out, value_out = self.split(x) |
|
|
|
|
|
|
|
elif self.same_dim: |
|
|
|
_start = int(0) |
|
|
|
_end = int(self.hidden_width) |
|
|
|
elif self.same_dim is True: |
|
|
|
_start = 0 |
|
|
|
_end = self.hidden_width |
|
|
|
_w = self.in_proj_layer[_start:_end, :] |
|
|
|
# _b = None |
|
|
|
query_out = self.matmul_dense(_w, tensor_q_2d) |
|
|
|
|
|
|
|
_start = int(self.hidden_width) |
|
|
|
_end = int(self.hidden_width * 2) |
|
|
|
_start = self.hidden_width |
|
|
|
_end = self.hidden_width * 2 |
|
|
|
_w = self.in_proj_layer[_start:_end, :] |
|
|
|
# _b = None |
|
|
|
key_out = self.matmul_dense(_w, tensor_k_2d) |
|
|
|
|
|
|
|
_start = int(self.hidden_width * 2) |
|
|
|
_start = self.hidden_width * 2 |
|
|
|
|
|
|
|
_end = None |
|
|
|
_w = self.in_proj_layer[_start:] |
|
|
|
# _b = None |
|
|
|
@@ -247,7 +247,7 @@ class TransformerEncoderLayer(nn.Cell): |
|
|
|
permute_recover = (b, n, d) |
|
|
|
src2 = self.norm1(src) |
|
|
|
q = k = self.with_pos_embed(src2, pos) |
|
|
|
src2 = self.self_attn(q, k, src2, batch_size=b, seq_length=n) |
|
|
|
src2 = self.self_attn(q, k, src2) |
|
|
|
src = src + self.dropout1(src2) |
|
|
|
src2 = self.norm2(src) |
|
|
|
src2 = self.reshape(src2, permute_linear) |
|
|
|
@@ -301,13 +301,12 @@ class TransformerDecoderLayer(nn.Cell): |
|
|
|
permute_recover = (b, n, d) |
|
|
|
tgt2 = self.norm1(tgt) |
|
|
|
q = k = self.with_pos_embed(tgt2, query_pos) |
|
|
|
tgt2 = self.self_attn(q, k, tensor_v=tgt2, batch_size=b, seq_length=n) |
|
|
|
tgt2 = self.self_attn(q, k, tensor_v=tgt2) |
|
|
|
tgt = tgt + self.dropout1(tgt2) |
|
|
|
tgt2 = self.norm2(tgt) |
|
|
|
tgt2 = self.multihead_attn(tensor_q=self.with_pos_embed(tgt2, query_pos), |
|
|
|
tensor_k=self.with_pos_embed(memory, pos), |
|
|
|
tensor_v=memory, |
|
|
|
batch_size=b, seq_length=n) |
|
|
|
tensor_v=memory,) |
|
|
|
tgt = tgt + self.dropout2(tgt2) |
|
|
|
tgt2 = self.norm3(tgt) |
|
|
|
tgt2 = self.reshape(tgt2, permute_linear) |
|
|
|
@@ -393,6 +392,7 @@ class VisionTransformer(nn.Cell): |
|
|
|
num_layers, |
|
|
|
hidden_dim, |
|
|
|
num_queries, |
|
|
|
idx, |
|
|
|
positional_encoding_type="learned", |
|
|
|
dropout_rate=0, |
|
|
|
norm=False, |
|
|
|
@@ -422,7 +422,7 @@ class VisionTransformer(nn.Cell): |
|
|
|
self.no_pos = no_pos |
|
|
|
|
|
|
|
self.unf = _unfold_(patch_dim) |
|
|
|
self.fold = _fold_(patch_dim) |
|
|
|
self.fold = _fold_(patch_dim, output_shape=(img_dim, img_dim)) |
|
|
|
|
|
|
|
if self.mlp is not True: |
|
|
|
self.linear_encoding = nn.Dense( |
|
|
|
@@ -437,7 +437,6 @@ class VisionTransformer(nn.Cell): |
|
|
|
|
|
|
|
self.query_embed = nn.Embedding( |
|
|
|
num_queries, embedding_dim * self.seq_length) |
|
|
|
|
|
|
|
encoder_layer = TransformerEncoderLayer( |
|
|
|
embedding_dim, num_heads, hidden_dim, dropout_rate) |
|
|
|
self.encoder = TransformerEncoder(encoder_layer, num_layers) |
|
|
|
@@ -455,30 +454,31 @@ class VisionTransformer(nn.Cell): |
|
|
|
) |
|
|
|
|
|
|
|
self.dropout_layer1 = nn.Dropout(1. - dropout_rate) |
|
|
|
|
|
|
|
def construct(self, x, query_idx): |
|
|
|
self.query_idx = idx |
|
|
|
self.query_idx_tensor = Tensor(idx, mstype.int32) |
|
|
|
def construct(self, x): |
|
|
|
"""ipt""" |
|
|
|
B, _, _, _ = x.shape |
|
|
|
x = self.unf(x) |
|
|
|
B, N, _ = x.shape |
|
|
|
|
|
|
|
if self.mlp is not True: |
|
|
|
x = self.reshape(x, (int(B * N), -1)) |
|
|
|
x = self.reshape(x, (B * N, -1)) |
|
|
|
x = self.dropout_layer1(self.linear_encoding(x)) + x |
|
|
|
x = self.reshape(x, (B, N, -1)) |
|
|
|
query_embed = self.tile( |
|
|
|
self.reshape(self.query_embed.embedding_table[int( |
|
|
|
query_idx)], (1, self.seq_length, self.embedding_dim)), |
|
|
|
self.reshape(self.query_embed(self.query_idx_tensor), (1, self.seq_length, self.embedding_dim)), |
|
|
|
(B, 1, 1)) |
|
|
|
|
|
|
|
if not self.no_pos: |
|
|
|
pos = self.position_encoding(x) |
|
|
|
|
|
|
|
x = self.encoder(x + pos) |
|
|
|
x = self.encoder(x + pos) |
|
|
|
else: |
|
|
|
x = self.encoder(x) |
|
|
|
x = self.decoder(x, x, query_pos=query_embed) |
|
|
|
|
|
|
|
if self.mlp is not True: |
|
|
|
x = self.reshape(x, (int(B * N), -1)) |
|
|
|
x = self.reshape(x, (B * N, -1)) |
|
|
|
x = self.mlp_head(x) + x |
|
|
|
x = self.reshape(x, (B, N, -1)) |
|
|
|
x = self.fold(x) |
|
|
|
@@ -542,9 +542,9 @@ class ResBlock(nn.Cell): |
|
|
|
def _pixelsf_(x, scale): |
|
|
|
"""ipt""" |
|
|
|
N, C, iH, iW = x.shape |
|
|
|
oH = int(iH * scale) |
|
|
|
oW = int(iW * scale) |
|
|
|
oC = int(C // (scale ** 2)) |
|
|
|
oH = iH * scale |
|
|
|
oW = iW * scale |
|
|
|
oC = C // (scale ** 2) |
|
|
|
|
|
|
|
output = P.Reshape()(x, (N, oC, scale, scale, iH, iW)) |
|
|
|
|
|
|
|
@@ -565,11 +565,12 @@ class SmallUpSampler(nn.Cell): |
|
|
|
self.conv = conv(n_feats, upsize * upsize * n_feats, 3, bias) |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.upsize = upsize |
|
|
|
self.pixelsf = _pixelsf_ |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
"""ipt""" |
|
|
|
x = self.conv(x) |
|
|
|
output = _pixelsf_(x, self.upsize) |
|
|
|
output = self.pixelsf(x, self.upsize) |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
@@ -628,7 +629,8 @@ class IPT(nn.Cell): |
|
|
|
dropout_rate=args.dropout_rate, |
|
|
|
mlp=args.no_mlp, |
|
|
|
pos_every=args.pos_every, |
|
|
|
no_pos=args.no_pos) |
|
|
|
no_pos=args.no_pos, |
|
|
|
idx=self.scale_idx) |
|
|
|
|
|
|
|
self.tail = nn.CellList([ |
|
|
|
nn.SequentialCell( |
|
|
|
@@ -645,7 +647,7 @@ class IPT(nn.Cell): |
|
|
|
"""ipt""" |
|
|
|
x = self.sub_mean(x) |
|
|
|
x = self.head[self.scale_idx](x) |
|
|
|
res = self.body(x, self.scale_idx) |
|
|
|
res = self.body(x) |
|
|
|
res += x |
|
|
|
x = self.tail[self.scale_idx](res) |
|
|
|
x = self.add_mean(x) |
|
|
|
@@ -654,30 +656,43 @@ class IPT(nn.Cell): |
|
|
|
|
|
|
|
def set_scale(self, scale_idx): |
|
|
|
"""ipt""" |
|
|
|
self.body.query_idx = scale_idx |
|
|
|
self.scale_idx = scale_idx |
|
|
|
|
|
|
|
def infrc(self, x): |
|
|
|
"""ipt""" |
|
|
|
forward_function = self.forward_chop_new |
|
|
|
|
|
|
|
return forward_function(x) |
|
|
|
class IPT_post(): |
|
|
|
"""ipt""" |
|
|
|
def __init__(self, model, args): |
|
|
|
super(IPT_post, self).__init__() |
|
|
|
self.model = model |
|
|
|
self.args = args |
|
|
|
self.scale_idx = 0 |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.tile = P.Tile() |
|
|
|
self.transpose = P.Transpose() |
|
|
|
self.cc_0 = P.Concat(axis=0) |
|
|
|
self.cc_2 = P.Concat(axis=2) |
|
|
|
self.cc_3 = P.Concat(axis=3) |
|
|
|
|
|
|
|
def forward_chop_new(self, x, shave=12, batchsize=64): |
|
|
|
def set_scale(self, scale_idx): |
|
|
|
"""ipt""" |
|
|
|
self.body.query_idx = scale_idx |
|
|
|
self.scale_idx = scale_idx |
|
|
|
|
|
|
|
def forward(self, x, shave=12, batchsize=64): |
|
|
|
"""ipt""" |
|
|
|
h, w = x.shape[-2:] |
|
|
|
padsize = int(self.args.patch_size) |
|
|
|
shave = int(self.args.patch_size / 4) |
|
|
|
scale = self.args.scale[self.scale_idx] |
|
|
|
|
|
|
|
h_cut = (h - padsize) % (padsize - shave) |
|
|
|
w_cut = (w - padsize) % (padsize - shave) |
|
|
|
|
|
|
|
unf_1 = _stride_unfold_(padsize, stride=padsize - shave) |
|
|
|
x_unfold = unf_1(x) |
|
|
|
x_unfold = unf_1.compute(x) |
|
|
|
x_unfold = self.transpose(x_unfold, (1, 0, 2)) # transpose(0,2) |
|
|
|
|
|
|
|
x_hw_cut = x[:, :, (h - padsize):, (w - padsize):] |
|
|
|
y_hw_cut = self.construct(x_hw_cut) |
|
|
|
y_hw_cut = self.model(x_hw_cut) |
|
|
|
|
|
|
|
x_h_cut = x[:, :, (h - padsize):, :] |
|
|
|
x_w_cut = x[:, :, :, (w - padsize):] |
|
|
|
@@ -696,66 +711,71 @@ class IPT(nn.Cell): |
|
|
|
x_unfold, (x_unfold.shape[0], -1, padsize, padsize)) |
|
|
|
x_range = x_unfold.shape[0] // batchsize + \ |
|
|
|
(x_unfold.shape[0] % batchsize != 0) |
|
|
|
|
|
|
|
cc_0 = P.Concat(axis=0) |
|
|
|
for i in range(x_range): |
|
|
|
if i == 0: |
|
|
|
y_unfold = self.construct( |
|
|
|
y_unfold = self.model( |
|
|
|
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]) |
|
|
|
else: |
|
|
|
y_unfold = cc_0((y_unfold, self.construct( |
|
|
|
y_unfold = self.cc_0((y_unfold, self.model( |
|
|
|
x_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) |
|
|
|
y_unf_shape_0 = y_unfold.shape[0] |
|
|
|
fold_1 = \ |
|
|
|
_stride_fold_(padsize * scale, output_shape=((h - h_cut) * scale, (w - w_cut) * scale), |
|
|
|
stride=padsize * scale - shave * scale) |
|
|
|
y = fold_1(self.transpose(self.reshape( |
|
|
|
y = fold_1.compute(self.transpose(self.reshape( |
|
|
|
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1))) |
|
|
|
cc_2 = P.Concat(axis=2) |
|
|
|
cc_3 = P.Concat(axis=3) |
|
|
|
y = cc_2((y_h_top, y[:, :, padsize * scale:, :])) |
|
|
|
y = cc_3((y_w_top, y[:, :, :, padsize * scale:])) |
|
|
|
if y[:, :, padsize * scale:, :].shape[2] == 0: |
|
|
|
y = y_h_top |
|
|
|
else: |
|
|
|
y = self.cc_2((y_h_top, y[:, :, padsize * scale:, :])) |
|
|
|
if y[:, :, :, padsize * scale:].shape[3] == 0: |
|
|
|
y = y_w_top |
|
|
|
else: |
|
|
|
y = self.cc_3((y_w_top, y[:, :, :, padsize * scale:])) |
|
|
|
y_unfold = y_unfold[:, :, int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale), |
|
|
|
int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)] |
|
|
|
fold_2 = _stride_fold_(padsize * scale - shave * scale, |
|
|
|
output_shape=((h - h_cut - shave) * |
|
|
|
scale, (w - w_cut - shave) * scale), |
|
|
|
stride=padsize * scale - shave * scale) |
|
|
|
y_inter = fold_2(self.transpose(self.reshape( |
|
|
|
y_inter = fold_2.compute(self.transpose(self.reshape( |
|
|
|
y_unfold, (y_unf_shape_0, -1, 1)), (2, 0, 1))) |
|
|
|
y = cc_3((cc_3((y[:, :, :, :int(shave / 2 * scale)], cc_2((cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)), y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])))), y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long |
|
|
|
y = cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], |
|
|
|
y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) |
|
|
|
y_w_cat = cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :], |
|
|
|
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) |
|
|
|
y = cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)], |
|
|
|
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):])) |
|
|
|
concat1 = self.cc_2((y[:, :, :int(shave / 2 * scale), int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)], y_inter)) #pylint: disable=line-too-long |
|
|
|
concat2 = self.cc_2((concat1, y[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)])) #pylint: disable=line-too-long |
|
|
|
concat3 = self.cc_3((y[:, :, :, :int(shave / 2 * scale)], concat2)) |
|
|
|
y = self.cc_3((concat3, y[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) #pylint: disable=line-too-long |
|
|
|
y = self.cc_2((y[:, :, :y.shape[2] - int((padsize - h_cut) / 2 * scale), :], y_h_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) #pylint: disable=line-too-long |
|
|
|
|
|
|
|
y_w_cat = self.cc_2((y_w_cut[:, :, :y_w_cut.shape[2] - int((padsize - h_cut) / 2 * scale), :], |
|
|
|
y_hw_cut[:, :, int((padsize - h_cut) / 2 * scale + 0.5):, :])) |
|
|
|
y = self.cc_3((y[:, :, :, :y.shape[3] - int((padsize - w_cut) / 2 * scale)], |
|
|
|
y_w_cat[:, :, :, int((padsize - w_cut) / 2 * scale + 0.5):])) |
|
|
|
|
|
|
|
return y |
|
|
|
|
|
|
|
def cut_h_new(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize): |
|
|
|
"""ipt""" |
|
|
|
unf_1 = _stride_unfold_(padsize, stride=padsize - shave) |
|
|
|
x_h_cut_unfold = unf_1(x_h_cut) |
|
|
|
x_h_cut_unfold = unf_1.compute(x_h_cut) |
|
|
|
x_h_cut_unfold = self.transpose(x_h_cut_unfold, (1, 0, 2)) |
|
|
|
|
|
|
|
x_h_cut_unfold = self.reshape( |
|
|
|
x_h_cut_unfold, (x_h_cut_unfold.shape[0], -1, padsize, padsize)) |
|
|
|
x_range = x_h_cut_unfold.shape[0] // batchsize + \ |
|
|
|
(x_h_cut_unfold.shape[0] % batchsize != 0) |
|
|
|
cc_0 = P.Concat(axis=0) |
|
|
|
for i in range(x_range): |
|
|
|
if i == 0: |
|
|
|
y_h_cut_unfold = self.construct( |
|
|
|
y_h_cut_unfold = self.model( |
|
|
|
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]) |
|
|
|
else: |
|
|
|
y_h_cut_unfold = \ |
|
|
|
cc_0((y_h_cut_unfold, self.construct( |
|
|
|
self.cc_0((y_h_cut_unfold, self.model( |
|
|
|
x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) |
|
|
|
y_h_cut_unfold_shape_0 = y_h_cut_unfold.shape[0] |
|
|
|
fold_1 = \ |
|
|
|
_stride_fold_(padsize * scale, output_shape=(padsize * scale, (w - w_cut) * scale), |
|
|
|
stride=padsize * scale - shave * scale) |
|
|
|
y_h_cut = fold_1(self.transpose(self.reshape( |
|
|
|
y_h_cut = fold_1.compute(self.transpose(self.reshape( |
|
|
|
y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1))) |
|
|
|
y_h_cut_unfold = y_h_cut_unfold[:, :, :, int( |
|
|
|
shave / 2 * scale):padsize * scale - int(shave / 2 * scale)] |
|
|
|
@@ -763,37 +783,35 @@ class IPT(nn.Cell): |
|
|
|
output_shape=(padsize * scale, |
|
|
|
(w - w_cut - shave) * scale), |
|
|
|
stride=padsize * scale - shave * scale) |
|
|
|
y_h_cut_inter = fold_2(self.transpose(self.reshape( |
|
|
|
y_h_cut_inter = fold_2.compute(self.transpose(self.reshape( |
|
|
|
y_h_cut_unfold, (y_h_cut_unfold_shape_0, -1, 1)), (2, 0, 1))) |
|
|
|
cc_3 = P.Concat(axis=3) |
|
|
|
y_h_cut = cc_3((cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)], |
|
|
|
y_h_cut_inter)), y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) |
|
|
|
concat1 = self.cc_3((y_h_cut[:, :, :, :int(shave / 2 * scale)], y_h_cut_inter)) |
|
|
|
y_h_cut = self.cc_3((concat1, y_h_cut[:, :, :, (w - w_cut) * scale - int(shave / 2 * scale):])) |
|
|
|
return y_h_cut |
|
|
|
|
|
|
|
def cut_w_new(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize): |
|
|
|
"""ipt""" |
|
|
|
unf_1 = _stride_unfold_(padsize, stride=padsize - shave) |
|
|
|
x_w_cut_unfold = unf_1(x_w_cut) |
|
|
|
x_w_cut_unfold = unf_1.compute(x_w_cut) |
|
|
|
x_w_cut_unfold = self.transpose(x_w_cut_unfold, (1, 0, 2)) |
|
|
|
|
|
|
|
x_w_cut_unfold = self.reshape( |
|
|
|
x_w_cut_unfold, (x_w_cut_unfold.shape[0], -1, padsize, padsize)) |
|
|
|
x_range = x_w_cut_unfold.shape[0] // batchsize + \ |
|
|
|
(x_w_cut_unfold.shape[0] % batchsize != 0) |
|
|
|
cc_0 = P.Concat(axis=0) |
|
|
|
for i in range(x_range): |
|
|
|
if i == 0: |
|
|
|
y_w_cut_unfold = self.construct( |
|
|
|
y_w_cut_unfold = self.model( |
|
|
|
x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]) |
|
|
|
else: |
|
|
|
y_w_cut_unfold = cc_0((y_w_cut_unfold, |
|
|
|
self.construct(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) |
|
|
|
y_w_cut_unfold = self.cc_0((y_w_cut_unfold, |
|
|
|
self.model(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, :, :, :]))) |
|
|
|
y_w_cut_unfold_shape_0 = y_w_cut_unfold.shape[0] |
|
|
|
fold_1 = _stride_fold_(padsize * scale, |
|
|
|
output_shape=((h - h_cut) * scale, |
|
|
|
padsize * scale), |
|
|
|
stride=padsize * scale - shave * scale) |
|
|
|
y_w_cut = fold_1(self.transpose(self.reshape( |
|
|
|
y_w_cut = fold_1.compute(self.transpose(self.reshape( |
|
|
|
y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1))) |
|
|
|
y_w_cut_unfold = y_w_cut_unfold[:, :, int( |
|
|
|
shave / 2 * scale):padsize * scale - int(shave / 2 * scale), :] |
|
|
|
@@ -801,19 +819,18 @@ class IPT(nn.Cell): |
|
|
|
output_shape=((h - h_cut - shave) |
|
|
|
* scale, padsize * scale), |
|
|
|
stride=padsize * scale - shave * scale) |
|
|
|
y_w_cut_inter = fold_2(self.transpose(self.reshape( |
|
|
|
y_w_cut_inter = fold_2.compute(self.transpose(self.reshape( |
|
|
|
y_w_cut_unfold, (y_w_cut_unfold_shape_0, -1, 1)), (2, 0, 1))) |
|
|
|
cc_2 = P.Concat(axis=2) |
|
|
|
y_w_cut = cc_2((cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :], |
|
|
|
y_w_cut_inter)), y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :])) |
|
|
|
concat1 = self.cc_2((y_w_cut[:, :, :int(shave / 2 * scale), :], y_w_cut_inter)) |
|
|
|
y_w_cut = self.cc_2((concat1, y_w_cut[:, :, (h - h_cut) * scale - int(shave / 2 * scale):, :])) |
|
|
|
return y_w_cut |
|
|
|
|
|
|
|
class _stride_unfold_(): |
|
|
|
'''stride''' |
|
|
|
|
|
|
|
class _stride_unfold_(nn.Cell): |
|
|
|
"""ipt""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, kernel_size, stride=-1): |
|
|
|
def __init__(self, |
|
|
|
kernel_size, |
|
|
|
stride=-1): |
|
|
|
|
|
|
|
super(_stride_unfold_, self).__init__() |
|
|
|
if stride == -1: |
|
|
|
@@ -821,28 +838,24 @@ class _stride_unfold_(nn.Cell): |
|
|
|
else: |
|
|
|
self.stride = stride |
|
|
|
self.kernel_size = kernel_size |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.transpose = P.Transpose() |
|
|
|
|
|
|
|
self.unfold = _unfold_(kernel_size) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
"""ipt""" |
|
|
|
def compute(self, x): |
|
|
|
"""stride""" |
|
|
|
x = x.asnumpy() |
|
|
|
N, C, H, W = x.shape |
|
|
|
leftup_idx_x = [] |
|
|
|
leftup_idx_y = [] |
|
|
|
nh = int((H - self.kernel_size) / self.stride + 1) |
|
|
|
nw = int((W - self.kernel_size) / self.stride + 1) |
|
|
|
nh = (H - self.kernel_size) // self.stride + 1 |
|
|
|
nw = (W - self.kernel_size) // self.stride + 1 |
|
|
|
for i in range(nh): |
|
|
|
leftup_idx_x.append(i * self.stride) |
|
|
|
for i in range(nw): |
|
|
|
leftup_idx_y.append(i * self.stride) |
|
|
|
NumBlock_x = len(leftup_idx_x) |
|
|
|
NumBlock_y = len(leftup_idx_y) |
|
|
|
zeroslike = P.ZerosLike() |
|
|
|
cc_2 = P.Concat(axis=2) |
|
|
|
cc_3 = P.Concat(axis=3) |
|
|
|
unf_x = P.Zeros()((N, C, NumBlock_x * self.kernel_size, |
|
|
|
NumBlock_y * self.kernel_size), mstype.float32) |
|
|
|
unf_x = np.zeros((N, C, NumBlock_x * self.kernel_size, NumBlock_y * self.kernel_size), dtype=np.float32) |
|
|
|
N, C, H, W = unf_x.shape |
|
|
|
for i in range(NumBlock_x): |
|
|
|
for j in range(NumBlock_y): |
|
|
|
@@ -852,23 +865,28 @@ class _stride_unfold_(nn.Cell): |
|
|
|
org_j = leftup_idx_y[j] |
|
|
|
fills = x[:, :, org_i:org_i + self.kernel_size, |
|
|
|
org_j:org_j + self.kernel_size] |
|
|
|
unf_x += cc_3((cc_3((zeroslike(unf_x[:, :, :, :unf_j]), |
|
|
|
cc_2( |
|
|
|
(cc_2((zeroslike(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size]), fills)), |
|
|
|
zeroslike(unf_x[:, :, unf_i + self.kernel_size:, |
|
|
|
unf_j:unf_j + self.kernel_size]))))), |
|
|
|
zeroslike(unf_x[:, :, :, unf_j + self.kernel_size:]))) |
|
|
|
zeros2 = np.zeros(unf_x[:, :, :unf_i, unf_j:unf_j + self.kernel_size].shape) |
|
|
|
concat1 = np.concatenate((zeros2, fills), axis=2) |
|
|
|
zeros3 = np.zeros(unf_x[:, :, unf_i + self.kernel_size:, unf_j:unf_j + self.kernel_size].shape) |
|
|
|
concat2 = np.concatenate((concat1, zeros3), axis=2) |
|
|
|
zeros1 = np.zeros(unf_x[:, :, :, :unf_j].shape) |
|
|
|
concat3 = np.concatenate((zeros1, concat2), axis=3) |
|
|
|
zeros4 = np.zeros(unf_x[:, :, :, unf_j + self.kernel_size:].shape) |
|
|
|
concat4 = np.concatenate((concat3, zeros4), axis=3) |
|
|
|
unf_x += concat4 |
|
|
|
unf_x = Tensor(unf_x, mstype.float32) |
|
|
|
y = self.unfold(unf_x) |
|
|
|
return y |
|
|
|
|
|
|
|
class _stride_fold_(nn.Cell): |
|
|
|
"""ipt""" |
|
|
|
class _stride_fold_(): |
|
|
|
'''stride''' |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, kernel_size, output_shape=(-1, -1), stride=-1): |
|
|
|
def __init__(self, |
|
|
|
kernel_size, |
|
|
|
output_shape=(-1, -1), |
|
|
|
stride=-1): |
|
|
|
|
|
|
|
super(_stride_fold_, self).__init__() |
|
|
|
|
|
|
|
if isinstance(kernel_size, (list, tuple)): |
|
|
|
self.kernel_size = kernel_size |
|
|
|
else: |
|
|
|
@@ -880,66 +898,49 @@ class _stride_fold_(nn.Cell): |
|
|
|
self.stride = stride |
|
|
|
|
|
|
|
self.output_shape = output_shape |
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.transpose = P.Transpose() |
|
|
|
self.fold = _fold_(kernel_size) |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
"""ipt""" |
|
|
|
cc_2 = P.Concat(axis=2) |
|
|
|
cc_3 = P.Concat(axis=3) |
|
|
|
zeroslike = P.ZerosLike() |
|
|
|
if self.output_shape[0] == -1: |
|
|
|
large_x = self.fold(x) |
|
|
|
N, C, H, _ = large_x.shape |
|
|
|
leftup_idx = [] |
|
|
|
for i in range(0, H, self.kernel_size[0]): |
|
|
|
leftup_idx.append(i) |
|
|
|
NumBlock = len(leftup_idx) |
|
|
|
fold_x = P.Zeros()((N, C, (NumBlock - 1) * self.stride + self.kernel_size[0], |
|
|
|
(NumBlock - 1) * self.stride + self.kernel_size[0]), mstype.float32) |
|
|
|
|
|
|
|
for i in range(NumBlock): |
|
|
|
for j in range(NumBlock): |
|
|
|
fold_i = i * self.stride |
|
|
|
fold_j = j * self.stride |
|
|
|
org_i = leftup_idx[i] |
|
|
|
org_j = leftup_idx[j] |
|
|
|
fills = large_x[:, :, org_i:org_i + self.kernel_size[0], |
|
|
|
org_j:org_j + self.kernel_size[1]] |
|
|
|
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long |
|
|
|
y = fold_x |
|
|
|
else: |
|
|
|
NumBlock_x = int( |
|
|
|
(self.output_shape[0] - self.kernel_size[0]) / self.stride + 1) |
|
|
|
NumBlock_y = int( |
|
|
|
(self.output_shape[1] - self.kernel_size[1]) / self.stride + 1) |
|
|
|
large_shape = [NumBlock_x * self.kernel_size[0], |
|
|
|
NumBlock_y * self.kernel_size[1]] |
|
|
|
self.fold = _fold_(self.kernel_size, large_shape) |
|
|
|
large_x = self.fold(x) |
|
|
|
N, C, H, _ = large_x.shape |
|
|
|
leftup_idx_x = [] |
|
|
|
leftup_idx_y = [] |
|
|
|
for i in range(NumBlock_x): |
|
|
|
leftup_idx_x.append(i * self.kernel_size[0]) |
|
|
|
for i in range(NumBlock_y): |
|
|
|
leftup_idx_y.append(i * self.kernel_size[1]) |
|
|
|
fold_x = P.Zeros()((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], |
|
|
|
(NumBlock_y - 1) * self.stride + self.kernel_size[1]), mstype.float32) |
|
|
|
for i in range(NumBlock_x): |
|
|
|
for j in range(NumBlock_y): |
|
|
|
fold_i = i * self.stride |
|
|
|
fold_j = j * self.stride |
|
|
|
org_i = leftup_idx_x[i] |
|
|
|
org_j = leftup_idx_y[j] |
|
|
|
fills = large_x[:, :, org_i:org_i + self.kernel_size[0], |
|
|
|
org_j:org_j + self.kernel_size[1]] |
|
|
|
fold_x += cc_3((cc_3((zeroslike(fold_x[:, :, :, :fold_j]), cc_2((cc_2((zeroslike(fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]]), fills)), zeroslike(fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]]))))), zeroslike(fold_x[:, :, :, fold_j + self.kernel_size[1]:]))) #pylint: disable=line-too-long |
|
|
|
y = fold_x |
|
|
|
self.NumBlock_x = (self.output_shape[0] - self.kernel_size[0]) // self.stride + 1 |
|
|
|
self.NumBlock_y = (self.output_shape[1] - self.kernel_size[1]) // self.stride + 1 |
|
|
|
self.large_shape = [self.NumBlock_x * self.kernel_size[0], self.NumBlock_y * self.kernel_size[1]] |
|
|
|
self.fold = _fold_(self.kernel_size, self.large_shape) |
|
|
|
|
|
|
|
def compute(self, x): |
|
|
|
'''stride''' |
|
|
|
NumBlock_x = self.NumBlock_x |
|
|
|
NumBlock_y = self.NumBlock_y |
|
|
|
large_x = self.fold(x) |
|
|
|
large_x = large_x.asnumpy() |
|
|
|
N, C, _, _ = large_x.shape |
|
|
|
leftup_idx_x = [] |
|
|
|
leftup_idx_y = [] |
|
|
|
for i in range(NumBlock_x): |
|
|
|
leftup_idx_x.append(i * self.kernel_size[0]) |
|
|
|
for i in range(NumBlock_y): |
|
|
|
leftup_idx_y.append(i * self.kernel_size[1]) |
|
|
|
fold_x = np.zeros((N, C, (NumBlock_x - 1) * self.stride + self.kernel_size[0], (NumBlock_y - 1) * self.stride + self.kernel_size[1]), dtype=np.float32) #pylint: disable=line-too-long |
|
|
|
for i in range(NumBlock_x): |
|
|
|
for j in range(NumBlock_y): |
|
|
|
fold_i = i * self.stride |
|
|
|
fold_j = j * self.stride |
|
|
|
org_i = leftup_idx_x[i] |
|
|
|
org_j = leftup_idx_y[j] |
|
|
|
fills = large_x[:, :, org_i:org_i + self.kernel_size[0], org_j:org_j + self.kernel_size[1]] |
|
|
|
t2 = fold_x[:, :, :fold_i, fold_j:fold_j + self.kernel_size[1]] |
|
|
|
zeros2 = np.zeros(t2.shape) |
|
|
|
concat1 = np.concatenate((zeros2, fills), axis=2) |
|
|
|
t3 = fold_x[:, :, fold_i + self.kernel_size[0]:, fold_j:fold_j + self.kernel_size[1]] |
|
|
|
zeros3 = np.zeros(t3.shape) |
|
|
|
concat2 = np.concatenate((concat1, zeros3), axis=2) |
|
|
|
t1 = fold_x[:, :, :, :fold_j] |
|
|
|
zeros1 = np.zeros(t1.shape) |
|
|
|
concat3 = np.concatenate((zeros1, concat2), axis=3) |
|
|
|
t4 = fold_x[:, :, :, fold_j + self.kernel_size[1]:] |
|
|
|
zeros4 = np.zeros(t4.shape) |
|
|
|
concat4 = np.concatenate((concat3, zeros4), axis=3) |
|
|
|
fold_x += concat4 |
|
|
|
y = Tensor(fold_x, mstype.float32) |
|
|
|
return y |
|
|
|
|
|
|
|
|
|
|
|
class _unfold_(nn.Cell): |
|
|
|
"""ipt""" |
|
|
|
|
|
|
|
@@ -957,20 +958,16 @@ class _unfold_(nn.Cell): |
|
|
|
def construct(self, x): |
|
|
|
"""ipt""" |
|
|
|
N, C, H, W = x.shape |
|
|
|
numH = int(H / self.kernel_size) |
|
|
|
numW = int(W / self.kernel_size) |
|
|
|
numH = H // self.kernel_size |
|
|
|
numW = W // self.kernel_size |
|
|
|
if numH * self.kernel_size != H or numW * self.kernel_size != W: |
|
|
|
x = x[:, :, :numH * self.kernel_size, :, numW * self.kernel_size] |
|
|
|
output_img = self.reshape(x, (N, C, numH, self.kernel_size, W)) |
|
|
|
|
|
|
|
output_img = self.transpose(output_img, (0, 1, 2, 4, 3)) |
|
|
|
|
|
|
|
output_img = self.reshape(output_img, (N, C, int( |
|
|
|
numH * numW), self.kernel_size, self.kernel_size)) |
|
|
|
|
|
|
|
output_img = self.transpose(output_img, (0, 2, 1, 4, 3)) |
|
|
|
|
|
|
|
output_img = self.reshape(output_img, (N, int(numH * numW), -1)) |
|
|
|
output_img = self.reshape(output_img, (N, C, numH, -1, self.kernel_size, self.kernel_size)) |
|
|
|
output_img = self.transpose(output_img, (0, 2, 3, 1, 5, 4)) |
|
|
|
output_img = self.reshape(output_img, (N, numH * numW, -1)) |
|
|
|
return output_img |
|
|
|
|
|
|
|
|
|
|
|
@@ -994,22 +991,17 @@ class _fold_(nn.Cell): |
|
|
|
|
|
|
|
self.reshape = P.Reshape() |
|
|
|
self.transpose = P.Transpose() |
|
|
|
self.sqrt = P.Sqrt() |
|
|
|
self.cast = P.Cast() |
|
|
|
|
|
|
|
def construct(self, x): |
|
|
|
"""ipt""" |
|
|
|
N, C, L = x.shape |
|
|
|
org_C = int(L / self.kernel_size[0] / self.kernel_size[1]) |
|
|
|
if self.output_shape[0] == -1: |
|
|
|
numH = int(np.sqrt(C)) |
|
|
|
numW = int(np.sqrt(C)) |
|
|
|
org_H = int(numH * self.kernel_size[0]) |
|
|
|
org_W = org_H |
|
|
|
else: |
|
|
|
org_H = int(self.output_shape[0]) |
|
|
|
org_W = int(self.output_shape[1]) |
|
|
|
numH = int(org_H / self.kernel_size[0]) |
|
|
|
numW = int(org_W / self.kernel_size[1]) |
|
|
|
|
|
|
|
org_C = L // (self.kernel_size[0] * self.kernel_size[1]) |
|
|
|
org_H = self.output_shape[0] |
|
|
|
org_W = self.output_shape[1] |
|
|
|
numH = org_H // self.kernel_size[0] |
|
|
|
numW = org_W // self.kernel_size[1] |
|
|
|
output_img = self.reshape( |
|
|
|
x, (N, C, org_C, self.kernel_size[0], self.kernel_size[1])) |
|
|
|
|
|
|
|
|