GitOrigin-RevId: 5e2acd4052
@@ -78,9 +78,9 @@ class Graph(_imperative_rt.ComputingGraph):
         opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
         return opnode.outputs[0]

-    def make_h2d(self, *, dtype, device):
+    def make_h2d(self, *, dtype, device, shape=None, name=None):
         device = as_device(device).to_c()
-        return self._wrap(_imperative_rt.make_h2d(self, device, dtype))
+        return self._wrap(_imperative_rt.make_h2d(self, device, dtype, shape, name))

     def dump(*args):
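For orientation, a hedged sketch of how the widened Python wrapper could be called after this change; the module path and device string are assumptions, not part of the diff:

```python
# Sketch only: the module path and device string are assumed, not from the diff.
import megengine.core.tensor.megbrain_graph as G

g = G.Graph()
# Pre-change call sites keep working because shape and name default to None:
legacy = g.make_h2d(dtype="float32", device="xpux")
# New form: shape and name are forwarded to the underlying Host2DeviceCopy node:
named = g.make_h2d(dtype="float32", device="xpux", shape=(64, 2), name="data")
```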
@@ -51,6 +51,7 @@ class TensorInfo:
         "value_read",
         "device",
         "dtype",
+        "shape",
         "bound_data",
         # resources for execution
         "varnode",
@@ -107,8 +108,8 @@ class trace:
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
-        self._args_bindings = None
-        self._kwargs_bindings = None
+        self._arg_bindings = None
+        self._kwarg_bindings = None
         self._output_bindings = None
         self._output_names = None
@@ -329,9 +330,7 @@ class trace:
         links = ()
         if self._capture_as_const:
-            for h in itertools.chain(
-                self._args_bindings, self._kwargs_bindings.values()
-            ):
+            for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
                 info = self._tinfo[h]
                 opnode = info.data_setter = G.InputNode(
                     device=info.device, dtype=info.dtype, graph=graph
@@ -434,15 +433,19 @@ class trace:
         h2v = {}
         graph = G.Graph()

-        for i, h in enumerate(self._args_bindings):
+        for i, h in enumerate(self._arg_bindings):
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            if arg_names:
-                h2v[h].name = arg_names[i]
-        for k, h in self._kwargs_bindings.items():
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype,
+                device=info.device,
+                shape=info.shape,
+                name=arg_names[i] if arg_names else None,
+            )
+        for k, h in self._kwarg_bindings.items():
             info = self._tinfo[h]
-            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
-            h2v[h].name = k
+            h2v[h] = graph.make_h2d(
+                dtype=info.dtype, device=info.device, shape=info.shape, name=k
+            )
         for op, ihandles, ohandles in self._seq:
             ivars = []
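The user-visible effect of this hunk: a graph dumped from a trace now carries each input's recorded shape and the name given via `arg_names`. A hedged end-to-end sketch (the function body and file name are illustrative):

```python
# Illustrative end-to-end flow; the model body and file name are made up.
import numpy as np
from megengine import tensor
from megengine.jit import trace

@trace(symbolic=True, capture_as_const=True)
def pred_fun(data):
    return data * 2  # placeholder body

# Tracing records info.shape for the input handle...
pred_fun(tensor(np.ones((64, 2), dtype="float32")))
# ...so the dumped graph's "data" h2d node is created with shape (64, 2).
pred_fun.dump("model.mge", arg_names=["data"])
```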
@@ -479,11 +482,12 @@ class trace:
                 info.external = False
                 info.device = x.device
                 info.dtype = x.dtype
+                info.shape = x.shape
                 TraceMixin._TraceMixin__inject(x, h)
                 self._inputs_to_restore.append(x)
                 return h

-            self._args_bindings = []
+            self._arg_bindings = []
             for i, x in enumerate(args):
                 x = find_raw_tensor(x)
                 if x is None:
@@ -491,20 +495,20 @@ class trace:
                         "positional arguments should all be tensor "
                         "but args[%d] cannot be recognized as one" % i
                     )
-                self._args_bindings.append(record_input(x))
+                self._arg_bindings.append(record_input(x))

-            self._kwargs_bindings = {}
+            self._kwarg_bindings = {}
             for k, x in kwargs.items():
                 x = find_raw_tensor(x)
                 if x is not None:
-                    self._kwargs_bindings[k] = record_input(x)
+                    self._kwarg_bindings[k] = record_input(x)
         else:
-            if len(args) != len(self._args_bindings):
+            if len(args) != len(self._arg_bindings):
                 raise TraceMismatchError("positional argument length mismatch")
             self._tensor_remaps = {}
-            for i, (h, x) in enumerate(zip(self._args_bindings, args)):
+            for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
                 x = find_raw_tensor(x)
                 if x is None:
                     raise TypeError(
@@ -524,9 +528,9 @@ class trace:
                 x = find_raw_tensor(x)
                 if x is not None:
                     kwargs_tensors[k] = x
-            if set(kwargs_tensors) != set(self._kwargs_bindings):
-                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
-                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+            if set(kwargs_tensors) != set(self._kwarg_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
+                too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
                 if too_many:
                     raise TraceMismatchError(
                         "keyword arguments found to be tensor this time "
@@ -537,7 +541,7 @@ class trace:
                         "keyword arguments found to be non-tensor this time "
                         "but were tensor previously: %s" % " ".join(too_few)
                     )
-            for k, h in self._kwargs_bindings.items():
+            for k, h in self._kwarg_bindings.items():
                 x = kwargs_tensors[k]
                 info = self._tinfo[h]
                 if x.dtype != info.dtype:
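These checks enforce that a keyword argument's tensor-ness stays stable across calls of a traced function. A hedged sketch of a call pattern the renamed `_kwarg_bindings` bookkeeping would reject (`f` and `scale` are made-up names):

```python
# Made-up example of the mismatch the checks above guard against.
import numpy as np
from megengine import tensor
from megengine.jit import trace

@trace
def f(x, scale=1.0):
    return x * scale

x = tensor(np.ones((3,), dtype="float32"))
f(x, scale=2.0)          # traced call: scale recorded as a non-tensor
f(x, scale=tensor(2.0))  # replay: scale is now a tensor -> TraceMismatchError
```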
@@ -237,7 +237,7 @@ void init_graph_rt(py::module m) {
         return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
     });

-    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
+    m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, TensorShape shape, std::optional<std::string> name) {
         if (!cn.valid()) {
             throw py::type_error("device must be valid");
         }
@@ -248,8 +248,8 @@ void init_graph_rt(py::module m) {
         if (name) {
             config.name(*name);
         }
-        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
-    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none());
+        return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, shape, dtype), config).node();
+    }, py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::none());

     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
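On the binding side, the two trailing `py::arg() = py::none()` defaults keep the low-level entry point backward compatible. Seen from Python (the import path is an assumption, and `graph`, `device`, `dtype` stand for the values the `Graph.make_h2d` wrapper normally supplies):

```python
# Sketch of the rebound entry point; import path assumed, and graph/device/dtype
# are placeholders for what Graph.make_h2d normally passes in.
from megengine.core._imperative_rt import make_h2d

var = make_h2d(graph, device, dtype)                    # legacy 3-arg call, still valid
var = make_h2d(graph, device, dtype, (64, 2), "data")   # shape and name now accepted
```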
@@ -0,0 +1,136 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import contextlib
+import os
+import tempfile
+
+import numpy as np
+
+import megengine as mge
+import megengine.functional as F
+import megengine.module as M
+import megengine.optimizer as optim
+from megengine import tensor
+from megengine.jit import trace
+
+
+@contextlib.contextmanager
+def mkstemp():
+    fd, path = tempfile.mkstemp()
+    try:
+        os.close(fd)
+        yield path
+    finally:
+        os.remove(path)
+
+
+def minibatch_generator(batch_size):
+    while True:
+        inp_data = np.zeros((batch_size, 2))
+        label = np.zeros(batch_size, dtype=np.int32)
+        for i in range(batch_size):
+            inp_data[i, :] = np.random.rand(2) * 2 - 1
+            label[i] = 1 if np.prod(inp_data[i]) < 0 else 0
+        yield {"data": inp_data.astype(np.float32), "label": label.astype(np.int32)}
+
+
+class XORNet(M.Module):
+    def __init__(self):
+        self.mid_dim = 14
+        self.num_class = 2
+        super().__init__()
+        self.fc0 = M.Linear(self.num_class, self.mid_dim, bias=True)
+        self.fc1 = M.Linear(self.mid_dim, self.mid_dim, bias=True)
+        self.fc2 = M.Linear(self.mid_dim, self.num_class, bias=True)
+
+    def forward(self, x):
+        x = self.fc0(x)
+        x = F.tanh(x)
+        x = self.fc1(x)
+        x = F.tanh(x)
+        x = self.fc2(x)
+        return x
+
+
+def test_xornet_trace_dump():
+    net = XORNet()
+    opt = optim.SGD(net.parameters(requires_grad=True), lr=0.01, momentum=0.9)
+    batch_size = 64
+    train_dataset = minibatch_generator(batch_size)
+    val_dataset = minibatch_generator(batch_size)
+
+    @trace
+    def train_fun(data, label):
+        with opt.record():
+            net.train()
+            pred = net(data)
+            loss = F.cross_entropy_with_softmax(pred, label)
+            opt.backward(loss)
+        return pred, loss
+
+    @trace
+    def val_fun(data, label):
+        net.eval()
+        pred = net(data)
+        loss = F.cross_entropy_with_softmax(pred, label)
+        return pred, loss
+
+    @trace(symbolic=True, capture_as_const=True)
+    def pred_fun(data):
+        net.eval()
+        pred = net(data)
+        pred_normalized = F.softmax(pred)
+        return pred_normalized
+
+    train_loss = []
+    val_loss = []
+    for step, minibatch in enumerate(train_dataset):
+        if step > 100:
+            break
+        data = tensor(minibatch["data"])
+        label = tensor(minibatch["label"])
+        opt.zero_grad()
+        _, loss = train_fun(data, label)
+        train_loss.append((step, loss.numpy()))
+        if step % 50 == 0:
+            # evaluate on the freshly drawn validation minibatch, not on the
+            # training tensors left over from train_fun
+            minibatch = next(val_dataset)
+            data = tensor(minibatch["data"])
+            label = tensor(minibatch["label"])
+            _, loss = val_fun(data, label)
+            loss = loss.numpy()[0]
+            val_loss.append((step, loss))
+            print("Step: {} loss={}".format(step, loss))
+        opt.step()
+
+    test_data = np.array(
+        [
+            (0.5, 0.5),
+            (0.3, 0.7),
+            (0.1, 0.9),
+            (-0.5, -0.5),
+            (-0.3, -0.7),
+            (-0.9, -0.1),
+            (0.5, -0.5),
+            (0.3, -0.7),
+            (0.9, -0.1),
+            (-0.5, 0.5),
+            (-0.3, 0.7),
+            (-0.1, 0.9),
+        ]
+    )
+
+    data = tensor(test_data.astype(np.float32))
+    out = pred_fun(data)
+    pred_output = out.numpy()
+    pred_label = np.argmax(pred_output, 1)
+
+    with np.printoptions(precision=4, suppress=True):
+        print("Predicted probability:")
+        print(pred_output)
+
+    with mkstemp() as out:
+        pred_fun.dump(out, arg_names=["data"], output_names=["label"])
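A possible follow-up assertion for this test (not part of the diff): the dump call should leave a non-empty serialized graph behind.

```python
# Hypothetical extra check, reusing the test's own mkstemp helper and pred_fun.
import os

with mkstemp() as out:
    pred_fun.dump(out, arg_names=["data"], output_names=["label"])
    assert os.path.getsize(out) > 0  # serialized graph was actually written
```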