
[to #42322933] Fix palm concurrency

Fix the error raised when the deployed palm model handles concurrent requests.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9649010

    * fix palm concurrency
master
hemu.zp yingda.chen 3 years ago
parent commit 4fb4947397

1 changed file with 93 additions and 115 deletions:
  modelscope/models/nlp/palm_v2/modeling_palm.py (+93, -115)
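
Review note: the root cause is that `TransformerDecoder` kept per-request decoding state on the module itself (`self.state`, written by `init_state()` and read inside `forward()`), so two requests served concurrently by one shared model instance could overwrite each other's state mid-decode. The sketch below is a minimal illustration of that hazard, not ModelScope code; the toy class, shapes, and batch sizes are invented for the example.

```python
import threading

import torch
from torch import nn


class StatefulDecoder(nn.Module):
    """Toy decoder that, like the old code, parks request state on the module."""

    def __init__(self):
        super().__init__()
        self.state = None  # shared by every caller of this instance

    def init_state(self, src):
        self.state = {'src': src}

    def forward(self, tgt):
        # Reads whatever the *last* init_state() wrote; under concurrency that
        # can belong to another request (mismatched shapes -> runtime errors).
        return self.state['src'].size(0), tgt.size(0)


decoder = StatefulDecoder()


def serve(batch_size):
    decoder.init_state(torch.zeros(batch_size, 8))
    src_bs, tgt_bs = decoder(torch.zeros(batch_size, 8))
    if src_bs != tgt_bs:
        print(f'request {batch_size}: state clobbered ({src_bs} vs {tgt_bs})')


threads = [threading.Thread(target=serve, args=(bs, )) for bs in (1, 2, 3, 4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Depending on thread scheduling this may or may not print; the point is that correctness hinges on timing. The change below threads an explicit `TransformerDecoderState` through `forward`, so each request owns its own state.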

@@ -4,20 +4,19 @@ import math
 import os
 import subprocess
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import json
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import Tensor, nn
 from torch.nn.init import xavier_uniform_
 from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig,
                           RobertaModel, RobertaTokenizer)
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 
-from modelscope.outputs import OutputKeys
 from modelscope.utils import logger as logging
 from .configuration_palm import PalmConfig
 from .dureader_eval import compute_bleu_rouge, normalize
@@ -142,35 +141,27 @@ class MultiHeadedAttention(nn.Module):  # SelfAttention
             key = shape(key)
             value = shape(value)
 
-            if layer_cache is not None:
-                device = key.device
-                if layer_cache['self_keys'] is not None:
-                    key = torch.cat(
-                        (layer_cache['self_keys'].to(device), key), dim=2)
-                if layer_cache['self_values'] is not None:
-                    value = torch.cat(
-                        (layer_cache['self_values'].to(device), value),
-                        dim=2)
-                layer_cache['self_keys'] = key
-                layer_cache['self_values'] = value
+            device = key.device
+            if layer_cache['self_keys'] is not None:
+                key = torch.cat((layer_cache['self_keys'].to(device), key),
+                                dim=2)
+            if layer_cache['self_values'] is not None:
+                value = torch.cat(
+                    (layer_cache['self_values'].to(device), value), dim=2)
+            layer_cache['self_keys'] = key
+            layer_cache['self_values'] = value
         elif type == 'context':
             query = self.linear_query(query)
-            if layer_cache is not None:
-                if layer_cache['memory_keys'] is None:
-                    key, value = self.linear_keys(key), self.linear_values(
-                        value)
-                    key = shape(key)
-                    value = shape(value)
-                else:
-                    key, value = layer_cache['memory_keys'], layer_cache[
-                        'memory_values']
-                layer_cache['memory_keys'] = key
-                layer_cache['memory_values'] = value
-            else:
+            if layer_cache['memory_keys'] is None:
                 key, value = self.linear_keys(key), self.linear_values(
                     value)
                 key = shape(key)
                 value = shape(value)
+            else:
+                key, value = layer_cache['memory_keys'], layer_cache[
+                    'memory_values']
+            layer_cache['memory_keys'] = key
+            layer_cache['memory_values'] = value
         else:
             key = self.linear_keys(key)
             value = self.linear_values(value)
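
Review note: the `if layer_cache is not None:` guards disappear here; as far as the hunk shows, the refactored flow pre-builds every layer's cache dict (all slots `None`, see `_init_cache` below), so the branch only needs to check individual slots. A minimal, self-contained sketch of the KV-cache pattern this branch implements, with invented shapes `(batch, heads, seq_len, dim_per_head)`:

```python
import torch


def cached_self_attention_kv(layer_cache, key, value):
    # Append this step's key/value to whatever the cache already holds,
    # mirroring the 'self' branch above.
    device = key.device
    if layer_cache['self_keys'] is not None:
        key = torch.cat((layer_cache['self_keys'].to(device), key), dim=2)
    if layer_cache['self_values'] is not None:
        value = torch.cat((layer_cache['self_values'].to(device), value),
                          dim=2)
    layer_cache['self_keys'] = key
    layer_cache['self_values'] = value
    return key, value


cache = {'self_keys': None, 'self_values': None}
for step in range(3):
    k = torch.randn(1, 8, 1, 64)  # one new timestep per decode step
    v = torch.randn(1, 8, 1, 64)
    k_all, v_all = cached_self_attention_kv(cache, k, v)
print(cache['self_keys'].shape)  # torch.Size([1, 8, 3, 64])
```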
@@ -372,6 +363,44 @@ class PositionalEncoding(nn.Module):
         return self.pe[:, :emb.size(1)]
 
 
+class TransformerDecoderState:
+
+    def __init__(self, src: Tensor, cache_num_layers: int = -1):
+        self.src: Tensor = src
+        self.previous_input: Tensor = None
+        self.previous_layer_inputs: Tensor = None
+        self.cache: Optional[Dict[str, Any]] = None
+        if cache_num_layers != -1:
+            self._init_cache(cache_num_layers)
+
+    def update_state(self, new_input, previous_layer_inputs):
+        self.previous_input = new_input
+        self.previous_layer_inputs = previous_layer_inputs
+        self.cache = None
+
+    def _init_cache(self, num_layers):
+        self.cache = {}
+        for num in range(num_layers):
+            layer_cache = {'memory_keys': None, 'memory_values': None}
+            layer_cache['self_keys'] = None
+            layer_cache['self_values'] = None
+            self.cache['layer_{}'.format(num)] = layer_cache
+
+    def map_batch_fn(self, fn):
+
+        def _recursive_map(struct, batch_dim=0):
+            for k, v in struct.items():
+                if v is not None:
+                    if isinstance(v, dict):
+                        _recursive_map(v)
+                    else:
+                        struct[k] = fn(v, batch_dim)
+
+        self.src = fn(self.src, 0)
+        if self.cache is not None:
+            _recursive_map(self.cache)
+
+
 class TransformerDecoder(nn.Module):  # Decoder
     """
     The Transformer decoder from "Attention is All You Need".
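
Review note: the state class moves from a nested class on `TransformerDecoder` to module level and is now instantiated once per request instead of being stored on the shared decoder. A short usage sketch, assuming the class is importable from this module after the change; the tensor shapes are invented:

```python
import torch
from modelscope.models.nlp.palm_v2.modeling_palm import \
    TransformerDecoderState

src = torch.zeros(2, 16, dtype=torch.long)  # invented (batch, src_len)
state = TransformerDecoderState(src, cache_num_layers=2)
print(sorted(state.cache.keys()))  # ['layer_0', 'layer_1']
print(state.cache['layer_0'])  # all four cache slots start out as None

# map_batch_fn applies fn to src and every cached tensor; beam search uses it
# to tile the batch dimension beam_size times (None slots are skipped).
state.map_batch_fn(lambda t, dim: t.repeat_interleave(3, dim=dim))
print(state.src.shape)  # torch.Size([6, 16])
```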
@@ -403,44 +432,6 @@ class TransformerDecoder(nn.Module):  # Decoder
     """
     decoder_type = 'transformer'
 
-    class TransformerDecoderState:
-
-        def __init__(self, src):
-            self.src = src
-            self.previous_input = None
-            self.previous_layer_inputs = None
-            self.cache = None
-
-        def update_state(self, new_input, previous_layer_inputs):
-            self.previous_input = new_input
-            self.previous_layer_inputs = previous_layer_inputs
-            self.cache = None
-
-        def _init_cache(self, num_layers):
-            self.cache = {}
-            for num in range(num_layers):
-                layer_cache = {
-                    'memory_keys': None,
-                    'memory_values': None,
-                    'self_keys': None,
-                    'self_values': None
-                }
-                self.cache['layer_{}'.format(num)] = layer_cache
-
-        def map_batch_fn(self, fn):
-
-            def _recursive_map(struct, batch_dim=0):
-                for k, v in struct.items():
-                    if v is not None:
-                        if isinstance(v, dict):
-                            _recursive_map(v)
-                        else:
-                            struct[k] = fn(v, batch_dim)
-
-            self.src = fn(self.src, 0)
-            if self.cache is not None:
-                _recursive_map(self.cache)
-
     def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
         super().__init__()
@@ -458,13 +449,13 @@ class TransformerDecoder(nn.Module):  # Decoder
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
         self.state = None
 
-    def init_state(self, src, with_cache=False):
-        self.state = self.TransformerDecoderState(src)
-        if with_cache:
-            self.state._init_cache(self.num_layers)
-
-    def forward(self, tgt, memory_bank, step=None, memory_masks=None):
-        src_words = self.state.src
+    def forward(self,
+                state: TransformerDecoderState,
+                tgt: Tensor,
+                memory_bank: Tensor,
+                step: int = None,
+                memory_masks: Tensor = None):
+        src_words = state.src
         tgt_words = tgt
         src_batch, src_len = src_words.size()
         tgt_batch, tgt_len = tgt_words.size()
@@ -487,33 +478,36 @@ class TransformerDecoder(nn.Module):  # Decoder
         src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
             .expand(src_batch, tgt_len, src_len)
 
-        if self.state.cache is None:
+        if state.cache is None:
             saved_inputs = []
         attns = []
         for i in range(self.num_layers):
             prev_layer_input = None
-            if self.state.cache is None:
-                if self.state.previous_input is not None:
-                    prev_layer_input = self.state.previous_layer_inputs[i]
+            if state.cache is None:
+                if state.previous_input is not None:
+                    prev_layer_input = state.previous_layer_inputs[i]
             output, attn, all_input \
-                = self.transformer_layers[i](output, src_memory_bank, src_pad_mask, tgt_pad_mask,
-                                             previous_input=prev_layer_input,
-                                             layer_cache=self.state.cache['layer_{}'.format(i)]
-                                             if self.state.cache is not None else None, step=step)
-            if self.state.cache is None:
+                = self.transformer_layers[i](
+                    output, src_memory_bank,
+                    src_pad_mask, tgt_pad_mask,
+                    previous_input=prev_layer_input,
+                    layer_cache=state.cache['layer_{}'.format(i)]
+                    if state.cache is not None else None,
+                    step=step)
+            if state.cache is None:
                 saved_inputs.append(all_input)
             attns.append(attn)
 
-        if self.state.cache is None:
+        if state.cache is None:
             saved_inputs = torch.stack(saved_inputs)
 
         output = self.layer_norm(output)
 
         # Process the result and update the attentions.
-        if self.state.cache is None:
-            self.state.update_state(tgt, saved_inputs)
+        if state.cache is None:
+            state.update_state(tgt, saved_inputs)
 
-        return output, attns
+        return output, attns, state
 
 
 class PalmPointerGenerator(nn.Module):
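
Review note: `forward` now takes the state as its first argument and returns it, so callers thread it through explicitly instead of reaching into `self.state`. A runnable toy showing the contract (the decoder body is a placeholder, not PALM's):

```python
import torch
from torch import nn


class EchoDecoder(nn.Module):
    """Toy decoder following the refactored contract: caller-owned state in,
    (output, attns, state) out; nothing request-specific lives on `self`."""

    def forward(self, state, tgt, memory_bank, step=None):
        state.setdefault('steps', []).append(step)  # mutate caller's state
        output = torch.zeros(tgt.size(0), tgt.size(1), memory_bank.size(-1))
        attns = [output.sum(-1)]
        return output, attns, state


decoder = EchoDecoder()
state_a, state_b = {}, {}  # two independent "requests" on one decoder
for step in range(2):
    _, _, state_a = decoder(state_a, torch.zeros(1, 1), torch.zeros(1, 4, 8),
                            step)
_, _, state_b = decoder(state_b, torch.zeros(1, 1), torch.zeros(1, 4, 8), 0)
print(state_a['steps'], state_b['steps'])  # [0, 1] [0] -- no cross-talk
```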
@@ -570,10 +564,11 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
         if (config.max_pos > 512):
             my_pos_embeddings = nn.Embedding(
                 config.max_pos, self.bert.model.config.hidden_size)
-            my_pos_embeddings.weight.data[:512] = \
-                self.bert.embeddings.position_embeddings.weight.data
-            my_pos_embeddings.weight.data[512:] = \
-                self.bert.embeddings.position_embeddings.weight.data[-1][None, :].repeat(config.max_pos - 512, 1)
+            my_pos_embeddings.weight.data[:
+                                          512] = self.bert.embeddings.position_embeddings.weight.data
+            my_pos_embeddings.weight.data[
+                512:] = self.bert.embeddings.position_embeddings.weight.data[
+                    -1][None, :].repeat(config.max_pos - 512, 1)
             self.bert.model.embeddings.position_embeddings = my_pos_embeddings
         self.vocab_size = self.bert.config.vocab_size
         tgt_embeddings = nn.Embedding(
@@ -633,8 +628,8 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
 
     def forward(self, src, tgt, mask_src):
         top_vec, _ = self.bert(src, mask_src, return_dict=False)
-        self.decoder.init_state(src)
-        decoder_outputs, attns = self.decoder(tgt[:, :-1], top_vec)
+        state = TransformerDecoderState(src)
+        decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec)
         return decoder_outputs, attns[-1], top_vec




@@ -776,9 +771,9 @@ class Translator(nn.Module):
                 translation_batch['predictions']))
         batch_size = batch.batch_size
 
-        preds, pred_score, _, tgt_str, src, src_str = \
-            translation_batch['predictions'], translation_batch['scores'], translation_batch['gold_score'], \
-            batch.tgt_str, batch.src, batch.src_str
+        preds, pred_score, tgt_str, src, src_str = translation_batch[
+            'predictions'], translation_batch[
+                'scores'], batch.tgt_str, batch.src, batch.src_str
         query_id = batch.query_id
         '''
         try:
@@ -903,17 +898,13 @@ class Translator(nn.Module):
                 '</s>', '').replace('<unk>', ' ').strip()
             if (self.args.recall_eval):
                 _pred_str = ''
-                # gap = 1e3
                 for sent in pred_str.split('<q>'):
                     can_pred_str = _pred_str + '<q>' + sent.strip()
-                    # can_gap = math.fabs(len(_pred_str.split()) - len(gold_str.split()))
-                    # if(can_gap>=gap):
                     if len(can_pred_str.split()) >= len(
                             gold_str.split()) + 10:
                         pred_str = _pred_str
                         break
                     else:
-                        # gap = can_gap
                        _pred_str = can_pred_str
 
             if self.args.dataset == 'marco' or self.args.dataset == 'squad' or self.args.dataset == 'qg_ranking':
@@ -967,10 +958,6 @@ class Translator(nn.Module):
                     pred_str = [pred_str]
                 pred_dict[query_id] = normalize([pred_str[0]])
                 ref_dict[query_id] = normalize([gold_str])
-                # pred_str_list = [src] + pred_str
-                # self.can_out_file.write("\t".join(pred_str_list)+"\n")
-                # self.can_out_file.write("\t".join(pred_str_list)+"\n")
-                # self.gold_out_file.write("\t".join([src, pred_str[0], gold_str])+"\n")
                 self.pred_json_score_out_file.write(
                     '\t'.join([str(query_id), src, gold_str, pred_str[0]])
                     + '\n')
@@ -1027,8 +1014,6 @@ class Translator(nn.Module):
                     preds[idx] = '。'
                 return preds, labels
 
-            # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
-            # self.logger.info('Dev eval result: {}'.format(bleu_rouge))
             pred_results, gold_results = postprocess_text(
                 pred_results, gold_results)
             pred_dict = {str(i): tmp for i, tmp in enumerate(pred_results)}
@@ -1037,8 +1022,6 @@ class Translator(nn.Module):
             print(bleu_rouge)
             # unreachable
         elif self.args.dataset == 'dureader' or self.args.dataset == 'paraphrase':
-            # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
-            # self.logger.info('Dev eval result: {}'.format(bleu_rouge))
             pred_results, gold_results = postprocess_text(
                 pred_results, gold_results)
             bleu_score = cal_bleu(pred_results, gold_results)
@@ -1134,11 +1117,11 @@ class Translator(nn.Module):
         mask_src = batch.mask_src
 
         src_features, _ = self.model.bert(src, mask_src, return_dict=False)
-        self.model.decoder.init_state(src, with_cache=True)
+        state = TransformerDecoderState(src, self.model.decoder.num_layers)
         device = src_features.device
 
         # Tile states and memory beam_size times.
-        self.model.decoder.state.map_batch_fn(
+        state.map_batch_fn(
             lambda state, dim: self._tile(state, beam_size, dim=dim))
         src_features = self._tile(src_features, beam_size, dim=0)
         batch_offset = torch.arange(
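
Review note: beam search now builds the cached state locally and tiles it itself, instead of tiling state hidden inside the decoder. `_tile` repeats each batch row `beam_size` times along a dimension so all hypotheses of one example stay adjacent; a hedged re-implementation for illustration (the real helper lives on `Translator`):

```python
import torch


def tile(x: torch.Tensor, count: int, dim: int = 0) -> torch.Tensor:
    # [a, b] with count=3 -> [a, a, a, b, b, b] along `dim`.
    return x.repeat_interleave(count, dim=dim)


src_features = torch.arange(8.0).view(2, 4)  # invented (batch, hidden)
tiled = tile(src_features, 3, dim=0)
print(tiled.shape)  # torch.Size([6, 4])
print(tiled[0].equal(tiled[2]))  # True: rows for one example are adjacent
```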
@@ -1174,8 +1157,8 @@ class Translator(nn.Module):
 
             # Decoder forward.
             decoder_input = decoder_input.transpose(0, 1)
-            dec_out, attns = self.model.decoder(
-                decoder_input, src_features, step=step)
+            dec_out, attns, state = self.model.decoder(
+                state, decoder_input, src_features, step=step)
 
             # Generator forward.
             log_probs = self.generator.forward(
@@ -1188,7 +1171,6 @@ class Translator(nn.Module):
             # Multiply probs by the beam probability.
 
             length_penalty = ((5.0 + (step + 1)) / 6.0)**self.alpha
-            # '''
             if self.args.sample_topk:
                 temperature = self.args.temperature
                 _scores = log_probs / temperature
@@ -1211,13 +1193,11 @@ class Translator(nn.Module):
                 _scores = _scores / length_penalty
                 topk_scores = torch.gather(
                     _scores, -1, topk_ids)  # (batch_size * num_beams, 2)
-                # log_probs += # (batch_size * num_beams, 2)
                 # Match shape of greedy beam search
                 topk_ids = topk_ids.view(
                     -1, beam_size)  # (batch_size, 2 * num_beams)
                 topk_scores = topk_scores.view(
                     -1, beam_size)  # (batch_size, 2 * num_beams)
-                # '''
             else:
                 log_probs += topk_log_probs.view(-1).unsqueeze(1)
                 curr_scores = log_probs / length_penalty
@@ -1231,7 +1211,6 @@ class Translator(nn.Module):
                     fail = False
                     words = [int(w) for w in alive_seq[i]]
                     if self.args.encoder == 'roberta':
-                        # words = [self.vocab.convert_ids_to_tokens[w] for w in words]
                         words = self.vocab.decode(words).strip().split()
                     else:
                         words = [
@@ -1252,7 +1231,6 @@ class Translator(nn.Module):
             topk_log_probs = topk_scores * length_penalty
 
             # Resolve beam origin and true word ids.
-            # topk_beam_index = topk_ids.div(vocab_size)
             topk_beam_index = topk_ids // vocab_size
             topk_ids = topk_ids.fmod(vocab_size)

@@ -1313,7 +1291,7 @@ class Translator(nn.Module):
# Reorder states. # Reorder states.
select_indices = batch_index.view(-1) select_indices = batch_index.view(-1)
src_features = src_features.index_select(0, select_indices) src_features = src_features.index_select(0, select_indices)
self.model.decoder.state.map_batch_fn(
state.map_batch_fn(
lambda state, dim: state.index_select(dim, select_indices)) lambda state, dim: state.index_select(dim, select_indices))


return results return results

