|
|
|
@@ -92,7 +92,7 @@ class BNNLeNet5(nn.Cell): |
|
|
|
def train_model(train_net, net, dataset): |
|
|
|
accs = [] |
|
|
|
loss_sum = 0 |
|
|
|
for _, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): |
|
|
|
for _, data in enumerate(dataset.create_dict_iterator(output_numpy=True, num_epochs=1)): |
|
|
|
train_x = Tensor(data['image'].astype(np.float32)) |
|
|
|
label = Tensor(data['label'].astype(np.int32)) |
|
|
|
loss = train_net(train_x, label) |
|
|
|
@@ -109,7 +109,7 @@ def train_model(train_net, net, dataset): |
|
|
|
|
|
|
|
def validate_model(net, dataset): |
|
|
|
accs = [] |
|
|
|
for _, data in enumerate(dataset.create_dict_iterator(output_numpy=True)): |
|
|
|
for _, data in enumerate(dataset.create_dict_iterator(output_numpy=True, num_epochs=1)): |
|
|
|
train_x = Tensor(data['image'].astype(np.float32)) |
|
|
|
label = Tensor(data['label'].astype(np.int32)) |
|
|
|
output = net(train_x) |
|
|
|
|