| @@ -63,7 +63,10 @@ def train(data, label, net, opt): | |||||
| def update_model(model_path): | def update_model(model_path): | ||||
| """ | """ | ||||
| Update the dumped model with test cases for new reference values | |||||
| Update the dumped model with test cases for new reference values. | |||||
| The model with pre-trained weights is trained for one iter with the test data attached. | |||||
| The loss and updated net state dict is dumped. | |||||
| """ | """ | ||||
| net = MnistNet(has_bn=True) | net = MnistNet(has_bn=True) | ||||
| checkpoint = mge.load(model_path) | checkpoint = mge.load(model_path) | ||||
| @@ -89,9 +92,6 @@ def run_test(model_path, use_jit, use_symbolic): | |||||
| """ | """ | ||||
| Load the model with test cases and run the training for one iter. | Load the model with test cases and run the training for one iter. | ||||
| The loss and updated weights are compared with reference value to verify the correctness. | The loss and updated weights are compared with reference value to verify the correctness. | ||||
| The model with pre-trained weights is trained for one iter and the net state dict is dumped. | |||||
| The test cases is appended to the model file. The reference result is obtained | |||||
| by running the train for one iter. | |||||
| Dump a new file with updated result by calling update_model | Dump a new file with updated result by calling update_model | ||||
| if you think the test fails due to numerical rounding errors instead of bugs. | if you think the test fails due to numerical rounding errors instead of bugs. | ||||
| @@ -109,7 +109,7 @@ def run_test(model_path, use_jit, use_symbolic): | |||||
| data.set_value(checkpoint["data"]) | data.set_value(checkpoint["data"]) | ||||
| label.set_value(checkpoint["label"]) | label.set_value(checkpoint["label"]) | ||||
| max_err = 0.0 | |||||
| max_err = 1e-1 | |||||
| train_func = train | train_func = train | ||||
| if use_jit: | if use_jit: | ||||