GitOrigin-RevId: 5e2acd4052
tags/v1.0.0-rc1
@@ -78,9 +78,9 @@ class Graph(_imperative_rt.ComputingGraph):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
         return opnode.outputs[0]
 
-    def make_h2d(self, *, dtype, device):
+    def make_h2d(self, *, dtype, device, shape=None, name=None):
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
+        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))
 
 
 def dump(*args):
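Note: the old two-argument form of make_h2d keeps working because both new
keyword-only parameters default to None and are always forwarded positionally
to the C++ binding. A minimal usage sketch, assuming the internal module path
megengine.core.tensor.megbrain_graph as of v1.0.0-rc1 (internal APIs may move
between releases):

    import numpy as np
    from megengine.core.tensor import megbrain_graph as G

    graph = G.Graph()
    # old call: shape and name are simply None
    x = graph.make_h2d(dtype=np.float32, device="xpux")
    # new call: pin a static shape and a readable node name
    y = graph.make_h2d(dtype=np.float32, device="xpux", shape=(64, 2), name="data")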
@@ -51,6 +51,7 @@ class TensorInfo:
         "value_read",
         "device",
         "dtype",
+        "shape",
         "bound_data",
         # resources for execution
         "varnode",
@@ -107,8 +108,8 @@ class trace:
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
-        self._args_bindings = None
-        self._kwargs_bindings = None
+        self._arg_bindings = None
+        self._kwarg_bindings = None
         self._output_bindings = None
         self._output_names = None
 
@@ -329,9 +330,7 @@ class trace:
 
         links = ()
         if self._capture_as_const:
-            for h in itertools.chain(
-                self._args_bindings, self._kwargs_bindings.values()
-            ):
+            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                 info = self._tinfo[h]
                 opnode = info.data_setter = G.InputNode(
                     device=info.device, dtype=info.dtype, graph=graph
@@ -434,15 +433,19 @@ class trace:
         h2v = {}
         graph = G.Graph()
 
-        for i, h in enumerate(self._args_bindings):
+        for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            if arg_names:
-                h2v[h].name = arg_names[i]
-        for k, h in self._kwargs_bindings.items():
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype,
+                device=info.device,
+                shape=info.shape,
+                name=arg_names[i] if arg_names else None,
+            )
+        for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            h2v[h].name = k
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype, device=info.device, shape=info.shape, name=k
+            )
 
         for op, ihandles, ohandles in self._seq:
             ivars = []
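Note: with shape now recorded in TensorInfo, dump() can materialize each traced
input as a Host2DeviceCopy node that already carries a static shape and, when
arg_names is given, a stable name. A hedged end-to-end sketch (function and
file names below are illustrative, not taken from this commit):

    import numpy as np
    import megengine.functional as F
    from megengine import tensor
    from megengine.jit import trace

    @trace(symbolic=True, capture_as_const=True)
    def pred_fun(data):
        return F.softmax(data)

    pred_fun(tensor(np.zeros((64, 2), np.float32)))  # first call records dtype/shape
    pred_fun.dump("model.mge", arg_names=["data"])   # H2D input var is named "data"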
| @@ -479,11 +482,12 @@ class trace: | |||||
| info.external = False | info.external = False | ||||
| info.device = x.device | info.device = x.device | ||||
| info.dtype = x.dtype | info.dtype = x.dtype | ||||
| info.shape = x.shape | |||||
| TraceMixin._TraceMixin__inject(x, h) | TraceMixin._TraceMixin__inject(x, h) | ||||
| self._inputs_to_restore.append(x) | self._inputs_to_restore.append(x) | ||||
| return h | return h | ||||
| self._args_bindings = [] | |||||
| self._arg_bindings = [] | |||||
| for i, x in enumerate(args): | for i, x in enumerate(args): | ||||
| x = find_raw_tensor(x) | x = find_raw_tensor(x) | ||||
| if x is None: | if x is None: | ||||
@@ -491,20 +495,20 @@ class trace:
                         "positional arguments should all be tensor "
                         "but args[%d] cannot be recognized as one" % i
                     )
-                self._args_bindings.append(record_input(x))
+                self._arg_bindings.append(record_input(x))
 
-            self._kwargs_bindings = {}
+            self._kwarg_bindings = {}
             for k, x in kwargs.items():
                 x = find_raw_tensor(x)
                 if x is not None:
-                    self._kwargs_bindings[k] = record_input(x)
+                    self._kwarg_bindings[k] = record_input(x)
 
         else:
-            if len(args) != len(self._args_bindings):
+            if len(args) != len(self._arg_bindings):
                 raise TraceMismatchError("positional argument length mismatch")
 
             self._tensor_remaps = {}
-            for i, (h, x) in enumerate(zip(self._args_bindings, args)):
+            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                 x = find_raw_tensor(x)
                 if x is None:
                     raise TypeError(
@@ -524,9 +528,9 @@ class trace:
                 x = find_raw_tensor(x)
                 if x is not None:
                     kwargs_tensors[k] = x
-            if set(kwargs_tensors) != set(self._kwargs_bindings):
-                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
-                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+            if set(kwargs_tensors) != set(self._kwarg_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
+                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                 if too_many:
                     raise TraceMismatchError(
                         "keyword arguments found to be tensor this time "
@@ -537,7 +541,7 @@ class trace:
                         "keyword arguments found to be non-tensor this time "
                         "but were tensor previously: %s" % " ".join(too_few)
                     )
-            for k, h in self._kwargs_bindings.items():
+            for k, h in self._kwarg_bindings.items():
                 x = kwargs_tensors[k]
                 info = self._tinfo[h]
                 if x.dtype != info.dtype:
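Note: the mismatch diagnostic above reduces to two plain set differences over
the keyword names. A pure-Python sketch with illustrative keys:

    recorded = {"mask", "bias"}    # kwargs that were tensors when tracing
    current = {"mask", "scale"}    # kwargs that are tensors on this call
    too_many = current - recorded  # {"scale"}: tensor now, was not before
    too_few = recorded - current   # {"bias"}: was tensor, is not now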
@@ -237,7 +237,7 @@ void init_graph_rt(py::module m) {
         return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });
 
-    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
+    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {
             throw py::type_error("device must be valid");
         }
@@ -248,8 +248,8 @@ void init_graph_rt(py::module m) {
         if (name) {
             config.name(*name);
         }
-        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
-    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
+        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());
 
     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
@@ -0,0 +1,136 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+import contextlib
+import os
+import tempfile
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+from megengine import tensor
+from megengine.jit import trace
+
+
+@contextlib.contextmanager
+def mkstemp():
+    fd, path = tempfile.mkstemp()
+    try:
+        os.close(fd)
+        yield path
+    finally:
+        os.remove(path)
+
+
+def minibatch_generator(batch_size):
+    while True:
+        inp_data = np.zeros((batch_size, 2))
+        label = np.zeros(batch_size, dtype=np.int32)
+        for i in range(batch_size):
+            inp_data[i, :] = np.random.rand(2) * 2 - 1
+            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
+        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
+
+
+class XORNet(M.Module):
+    def __init__(self):
+        self.mid_dim = 14
+        self.num_class = 2
+        super().__init__()
+        self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
+
+    def forward(self, x):
+        x = self.fc0(x)
+        x = F.tanh(x)
+        x = self.fc1(x)
+        x = F.tanh(x)
+        x = self.fc2(x)
+        return x
+
+
+def test_xornet_trace_dump():
+    net = XORNet()
+    opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    batch_size = 64
+    train_dataset = minibatch_generator(batch_size)
+    val_dataset = minibatch_generator(batch_size)
+
+    @trace
+    def train_fun(data, label):
+        with opt.record():
+            net.train()
+            pred = net(data)
+            loss = F.cross_entropy_with_softmax(pred, label)
+            opt.backward(loss)
+        return pred, loss
+
+    @trace
+    def val_fun(data, label):
+        net.eval()
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        return pred, loss
+
+    @trace(symbolic=True, capture_as_const=True)
+    def pred_fun(data):
+        net.eval()
+        pred = net(data)
+        pred_normalized = F.softmax(pred)
+        return pred_normalized
+
+    train_loss = []
+    val_loss = []
+    for step, minibatch in enumerate(train_dataset):
+        if step > 100:
+            break
+        data = tensor(minibatch["data"])
+        label = tensor(minibatch["label"])
+        opt.zero_grad()
+        _, loss = train_fun(data, label)
+        train_loss.append((step, loss.numpy()))
+        if step % 50 == 0:
+            minibatch = next(val_dataset)
+            # validate on the val minibatch, not the training batch just used
+            _, loss = val_fun(tensor(minibatch["data"]), tensor(minibatch["label"]))
+            loss = loss.numpy()[0]
+            val_loss.append((step, loss))
+            print("Step: {} loss={}".format(step, loss))
+        opt.step()
+
+    test_data = np.array(
+        [
+            (0.5, 0.5),
+            (0.3, 0.7),
+            (0.1, 0.9),
+            (-0.5, -0.5),
+            (-0.3, -0.7),
+            (-0.9, -0.1),
+            (0.5, -0.5),
+            (0.3, -0.7),
+            (0.9, -0.1),
+            (-0.5, 0.5),
+            (-0.3, 0.7),
+            (-0.1, 0.9),
+        ]
+    )
+
+    data = tensor(test_data.astype(np.float32))
+    out = pred_fun(data)
+    pred_output = out.numpy()
+    pred_label = np.argmax(pred_output, 1)
+
+    with np.printoptions(precision=4, suppress=True):
+        print("Predicted probability:")
+        print(pred_output)
+
+    with mkstemp() as out:
+        pred_fun.dump(out, arg_names=["data"], output_names=["label"])
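Note: the test only asserts that dumping does not crash. A minimal extra
sanity check (illustrative, not part of this commit) would be that the
serialized graph file comes out non-empty:

    with mkstemp() as out:
        pred_fun.dump(out, arg_names=["data"], output_names=["label"])
        assert os.path.getsize(out) > 0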