|
|
|
@@ -55,12 +55,15 @@ pnnx.Output output 1 0 out |
|
|
|
|
|
|
|
op->params["copy"] = captured_params.at("copy"); |
|
|
|
|
|
|
|
if (captured_params.at("memory_format").i == 0) |
|
|
|
op->params["memory_format"] = "torch.contiguous_format"; |
|
|
|
if (captured_params.at("memory_format").i == 1) |
|
|
|
op->params["memory_format"] = "torch.preserve_format"; |
|
|
|
if (captured_params.at("memory_format").i == 2) |
|
|
|
op->params["memory_format"] = "torch.channels_last"; |
|
|
|
if (captured_params.at("memory_format").type == 2) |
|
|
|
{ |
|
|
|
if (captured_params.at("memory_format").i == 0) |
|
|
|
op->params["memory_format"] = "torch.contiguous_format"; |
|
|
|
if (captured_params.at("memory_format").i == 1) |
|
|
|
op->params["memory_format"] = "torch.preserve_format"; |
|
|
|
if (captured_params.at("memory_format").i == 2) |
|
|
|
op->params["memory_format"] = "torch.channels_last"; |
|
|
|
} |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
@@ -83,7 +86,29 @@ pnnx.Output output 1 0 out |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
class Tensor_to_2 : public Tensor_to |
|
|
|
{ |
|
|
|
public: |
|
|
|
const char* match_pattern_graph() const |
|
|
|
{ |
|
|
|
return R"PNNXIR(7767517 |
|
|
|
10 9 |
|
|
|
pnnx.Input input_0 0 1 input |
|
|
|
prim::Constant op_0 0 1 dtype value=%dtype |
|
|
|
prim::Constant op_1 0 1 layout value=* |
|
|
|
prim::Constant op_2 0 1 device value=* |
|
|
|
prim::Constant op_3 0 1 pin_memory value=* |
|
|
|
prim::Constant op_4 0 1 non_blocking value=* |
|
|
|
prim::Constant op_5 0 1 copy value=%copy |
|
|
|
prim::Constant op_6 0 1 memory_format value=%memory_format |
|
|
|
aten::to op_7 8 1 input dtype layout device pin_memory non_blocking copy memory_format out |
|
|
|
pnnx.Output output 1 0 out |
|
|
|
)PNNXIR"; |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20) |
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20) |
|
|
|
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_2, 20) |
|
|
|
|
|
|
|
} // namespace pnnx |