| @@ -1,6 +1,7 @@ | |||||
| import numpy as np | import numpy as np | ||||
| import megengine as mge | import megengine as mge | ||||
| import megengine.autodiff as ad | |||||
| import megengine.functional as F | import megengine.functional as F | ||||
| import megengine.module as M | import megengine.module as M | ||||
| import megengine.optimizer as optim | import megengine.optimizer as optim | ||||
| @@ -35,57 +36,54 @@ class XORNet(M.Module): | |||||
| return x | return x | ||||
| @trace(symbolic=True) | |||||
| def train_fun(data, label, net=None, opt=None): | |||||
| net.train() | |||||
| pred = net(data) | |||||
| loss = F.cross_entropy_with_softmax(pred, label) | |||||
| opt.backward(loss) | |||||
| return pred, loss | |||||
| @trace(symbolic=True) | |||||
| def val_fun(data, label, net=None): | |||||
| net.eval() | |||||
| pred = net(data) | |||||
| loss = F.cross_entropy_with_softmax(pred, label) | |||||
| return pred, loss | |||||
| @trace(symbolic=True) | |||||
| def pred_fun(data, net=None): | |||||
| net.eval() | |||||
| pred = net(data) | |||||
| pred_normalized = F.softmax(pred) | |||||
| return pred_normalized | |||||
| def main(): | def main(): | ||||
| if not mge.is_cuda_available(): | if not mge.is_cuda_available(): | ||||
| mge.set_default_device("cpux") | mge.set_default_device("cpux") | ||||
| net = XORNet() | net = XORNet() | ||||
| gm = ad.GradManager().attach(net.parameters()) | |||||
| opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) | opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9) | ||||
| batch_size = 64 | batch_size = 64 | ||||
| train_dataset = minibatch_generator(batch_size) | train_dataset = minibatch_generator(batch_size) | ||||
| val_dataset = minibatch_generator(batch_size) | val_dataset = minibatch_generator(batch_size) | ||||
| data = mge.tensor() | |||||
| label = mge.tensor(np.zeros((batch_size,)), dtype=np.int32) | |||||
| def train_fun(data, label): | |||||
| opt.clear_grad() | |||||
| with gm: | |||||
| pred = net(data) | |||||
| loss = F.cross_entropy_with_softmax(pred, label) | |||||
| gm.backward(loss) | |||||
| opt.step() | |||||
| return pred, loss | |||||
| def val_fun(data, label): | |||||
| pred = net(data) | |||||
| loss = F.cross_entropy_with_softmax(pred, label) | |||||
| return pred, loss | |||||
| @trace(symbolic=True, capture_as_const=True) | |||||
| def pred_fun(data): | |||||
| pred = net(data) | |||||
| pred_normalized = F.softmax(pred) | |||||
| return pred_normalized | |||||
| data = np.random.random((batch_size, 2)).astype(np.float32) | |||||
| label = np.zeros((batch_size,)).astype(np.int32) | |||||
| train_loss = [] | train_loss = [] | ||||
| val_loss = [] | val_loss = [] | ||||
| for step, minibatch in enumerate(train_dataset): | for step, minibatch in enumerate(train_dataset): | ||||
| if step > 1000: | if step > 1000: | ||||
| break | break | ||||
| data.set_value(minibatch["data"]) | |||||
| label.set_value(minibatch["label"]) | |||||
| opt.zero_grad() | |||||
| _, loss = train_fun(data, label, net=net, opt=opt) | |||||
| data = minibatch["data"] | |||||
| label = minibatch["label"] | |||||
| net.train() | |||||
| _, loss = train_fun(data, label) | |||||
| train_loss.append((step, loss.numpy())) | train_loss.append((step, loss.numpy())) | ||||
| if step % 50 == 0: | if step % 50 == 0: | ||||
| minibatch = next(val_dataset) | minibatch = next(val_dataset) | ||||
| _, loss = val_fun(data, label, net=net) | |||||
| net.eval() | |||||
| _, loss = val_fun(data, label) | |||||
| loss = loss.numpy()[0] | loss = loss.numpy()[0] | ||||
| val_loss.append((step, loss)) | val_loss.append((step, loss)) | ||||
| print("Step: {} loss={}".format(step, loss)) | print("Step: {} loss={}".format(step, loss)) | ||||
| @@ -108,8 +106,10 @@ def main(): | |||||
| ] | ] | ||||
| ) | ) | ||||
| data.set_value(test_data) | |||||
| out = pred_fun(data, net=net) | |||||
| # tracing only accepts tensor as input | |||||
| data = mge.tensor(test_data, dtype=np.float32) | |||||
| net.eval() | |||||
| out = pred_fun(data) | |||||
| pred_output = out.numpy() | pred_output = out.numpy() | ||||
| pred_label = np.argmax(pred_output, 1) | pred_label = np.argmax(pred_output, 1) | ||||
| @@ -125,11 +125,8 @@ def main(): | |||||
| model_name = "xornet_deploy.mge" | model_name = "xornet_deploy.mge" | ||||
| if pred_fun.enabled: | |||||
| print("Dump model as {}".format(model_name)) | |||||
| pred_fun.dump(model_name, arg_names=["data"]) | |||||
| else: | |||||
| print("pred_fun must be run with trace enabled in order to dump model") | |||||
| print("Dump model as {}".format(model_name)) | |||||
| pred_fun.dump(model_name, arg_names=["data"]) | |||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||