|
|
|
@@ -132,6 +132,8 @@ void pass_level1(const torch::jit::Module& mod, const std::shared_ptr<torch::jit |
|
|
|
// sub_mod.dump(true, true, true); |
|
|
|
|
|
|
|
op->attrs["data"] = sub_mod.attr(name).toTensor(); |
|
|
|
op->outputs[0]->type = op->attrs["data"].type; |
|
|
|
op->outputs[0]->shape = op->attrs["data"].shape; |
|
|
|
} |
|
|
|
} |
|
|
|
else if (n->kind() == c10::prim::Constant) // || n->kind() == c10::prim::ListConstruct) |
|
|
|
|