Browse Source

fix lite train codecheck

tags/v1.6.0
zhengjun10 4 years ago
parent
commit
301d1101ae
4 changed files with 8 additions and 11 deletions
  1. +6
    -9
      mindspore/lite/examples/export_models/models/tinybert_train_export.py
  2. +1
    -0
      mindspore/lite/examples/train_lenet/model/train_utils.py
  3. +0
    -1
      mindspore/lite/src/train/opt_allocator.cc
  4. +1
    -1
      mindspore/lite/src/train/train_session.cc

+ 6
- 9
mindspore/lite/examples/export_models/models/tinybert_train_export.py View File

@@ -32,15 +32,12 @@ else:
path = ''
sys.path.append(os.environ['CLOUD_MODEL_ZOO'] + 'official/nlp/tinybert')

from official.nlp.tinybert.src.tinybert_model import TinyBertModel
from official.nlp.tinybert.src.model_utils.config import bert_student_net_cfg
from train_utils import save_t
from official.nlp.tinybert.src.tinybert_model import TinyBertModel # noqa: 402
from official.nlp.tinybert.src.model_utils.config import bert_student_net_cfg # noqa: 402
from train_utils import save_t # noqa: 402





class BertNetworkWithLoss_gd(M.nn.Cell):
class BertNetworkWithLossGenDistill(M.nn.Cell):
"""
Provide bert pre-training loss through network.
Args:
@@ -53,7 +50,7 @@ class BertNetworkWithLoss_gd(M.nn.Cell):

def __init__(self, student_config, is_training, use_one_hot_embeddings=False,
is_att_fit=False, is_rep_fit=True):
super(BertNetworkWithLoss_gd, self).__init__()
super(BertNetworkWithLossGenDistill, self).__init__()
# load teacher model
self.bert = TinyBertModel(
student_config, is_training, use_one_hot_embeddings)
@@ -169,7 +166,7 @@ bert_student_net_cfg.attention_probs_dropout_prob = 0.0
bert_student_net_cfg.compute_type = mstype.float32

#==============Training===============
nloss = BertNetworkWithLoss_gd(
nloss = BertNetworkWithLossGenDistill(
bert_student_net_cfg, is_training=True, use_one_hot_embeddings=False)
optimizer = M.nn.Adam(nloss.bert.trainable_params(), learning_rate=1e-3, beta1=0.5, beta2=0.7,
eps=1e-2, use_locking=True, use_nesterov=False, weight_decay=0.1, loss_scale=0.3)


+ 1
- 0
mindspore/lite/examples/train_lenet/model/train_utils.py View File

@@ -18,6 +18,7 @@ import mindspore.nn as nn
from mindspore.common.parameter import ParameterTuple
from mindspore import amp


def train_wrap(net, loss_fn=None, optimizer=None, weights=None, loss_scale_manager=None):
"""
train_wrap


+ 0
- 1
mindspore/lite/src/train/opt_allocator.cc View File

@@ -18,7 +18,6 @@
#include "nnacl/op_base.h"

namespace mindspore {

size_t OptAllocator::FindFree(size_t size) {
size_t min_size = std::numeric_limits<size_t>::max();
size_t min_addr = std::numeric_limits<size_t>::max();


+ 1
- 1
mindspore/lite/src/train/train_session.cc View File

@@ -1093,7 +1093,7 @@ session::LiteSession *session::TrainSession::CreateTrainSession(const std::strin
return nullptr;
}
if (context->allocator == nullptr) {
const_cast<lite::Context *>(context)->allocator = std::shared_ptr<Allocator>(new (std::nothrow) StaticAllocator());
const_cast<lite::Context *>(context)->allocator = std::make_shared<StaticAllocator>();
if (context->allocator == nullptr) {
MS_LOG(ERROR) << " cannot convert to static allocation";
}


Loading…
Cancel
Save