
[to #42322933] Fix palm concurrency

Fix the concurrency error raised when deploying the palm model.
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9649010

    * fix palm concurrency
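
Root cause: TransformerDecoder kept per-request generation state on the module itself (self.state, populated by init_state), so concurrent requests decoding on one model instance overwrote each other's key/value caches. The fix deletes the shared attribute and threads a request-local TransformerDecoderState through forward. A minimal sketch of the before/after pattern (illustrative classes only, not the palm code itself):

class StatefulDecoder:
    """Old pattern: decode state hangs off the shared module."""

    def init_state(self, src):
        # Every concurrent request rebinds the same attribute.
        self.state = {'src': src, 'self_keys': None}

    def step(self, tgt):
        # By now self.state may belong to a different request.
        return self.state['src'], tgt


class StatelessDecoder:
    """New pattern: the caller owns the state and passes it in."""

    def step(self, state, tgt):
        # state is request-local; nothing on the module mutates.
        return state['src'], tgt, state


def serve_request(decoder, src, tgt):
    state = {'src': src, 'self_keys': None}  # one state per request
    out, tgt, state = decoder.step(state, tgt)
    return out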
master
hemu.zp, yingda.chen · 3 years ago
commit 4fb4947397
1 changed file with 93 additions and 115 deletions:
  +93 -115 modelscope/models/nlp/palm_v2/modeling_palm.py

modelscope/models/nlp/palm_v2/modeling_palm.py

@@ -4,20 +4,19 @@ import math
 import os
 import subprocess
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import json
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import Tensor, nn
 from torch.nn.init import xavier_uniform_
 from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig,
                           RobertaModel, RobertaTokenizer)
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel

 from modelscope.outputs import OutputKeys
 from modelscope.utils import logger as logging
 from .configuration_palm import PalmConfig
 from .dureader_eval import compute_bleu_rouge, normalize
@@ -142,35 +141,27 @@ class MultiHeadedAttention(nn.Module):  # SelfAttention
             key = shape(key)
             value = shape(value)

-            if layer_cache is not None:
-                device = key.device
-                if layer_cache['self_keys'] is not None:
-                    key = torch.cat(
-                        (layer_cache['self_keys'].to(device), key), dim=2)
-                if layer_cache['self_values'] is not None:
-                    value = torch.cat(
-                        (layer_cache['self_values'].to(device), value),
-                        dim=2)
-                layer_cache['self_keys'] = key
-                layer_cache['self_values'] = value
+            device = key.device
+            if layer_cache['self_keys'] is not None:
+                key = torch.cat((layer_cache['self_keys'].to(device), key),
+                                dim=2)
+            if layer_cache['self_values'] is not None:
+                value = torch.cat(
+                    (layer_cache['self_values'].to(device), value), dim=2)
+            layer_cache['self_keys'] = key
+            layer_cache['self_values'] = value
         elif type == 'context':
             query = self.linear_query(query)
-            if layer_cache is not None:
-                if layer_cache['memory_keys'] is None:
-                    key, value = self.linear_keys(key), self.linear_values(
-                        value)
-                    key = shape(key)
-                    value = shape(value)
-                else:
-                    key, value = layer_cache['memory_keys'], layer_cache[
-                        'memory_values']
-                layer_cache['memory_keys'] = key
-                layer_cache['memory_values'] = value
-            else:
-                key, value = self.linear_keys(key), self.linear_values(
-                    value)
-                key = shape(key)
-                value = shape(value)
+            if layer_cache['memory_keys'] is None:
+                key, value = self.linear_keys(key), self.linear_values(
+                    value)
+                key = shape(key)
+                value = shape(value)
+            else:
+                key, value = layer_cache['memory_keys'], layer_cache[
+                    'memory_values']
+            layer_cache['memory_keys'] = key
+            layer_cache['memory_values'] = value
         else:
             key = self.linear_keys(key)
             value = self.linear_values(value)
@@ -372,6 +363,44 @@ class PositionalEncoding(nn.Module):
         return self.pe[:, :emb.size(1)]


+class TransformerDecoderState:
+
+    def __init__(self, src: Tensor, cache_num_layers: int = -1):
+        self.src: Tensor = src
+        self.previous_input: Tensor = None
+        self.previous_layer_inputs: Tensor = None
+        self.cache: Optional[Dict[str, Any]] = None
+        if cache_num_layers != -1:
+            self._init_cache(cache_num_layers)
+
+    def update_state(self, new_input, previous_layer_inputs):
+        self.previous_input = new_input
+        self.previous_layer_inputs = previous_layer_inputs
+        self.cache = None
+
+    def _init_cache(self, num_layers):
+        self.cache = {}
+        for num in range(num_layers):
+            layer_cache = {'memory_keys': None, 'memory_values': None}
+            layer_cache['self_keys'] = None
+            layer_cache['self_values'] = None
+            self.cache['layer_{}'.format(num)] = layer_cache
+
+    def map_batch_fn(self, fn):
+
+        def _recursive_map(struct, batch_dim=0):
+            for k, v in struct.items():
+                if v is not None:
+                    if isinstance(v, dict):
+                        _recursive_map(v)
+                    else:
+                        struct[k] = fn(v, batch_dim)
+
+        self.src = fn(self.src, 0)
+        if self.cache is not None:
+            _recursive_map(self.cache)
+
+
 class TransformerDecoder(nn.Module):  # Decoder
     """
     The Transformer decoder from "Attention is All You Need".
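
Hoisted to module level, the state object is now built once per generation request; passing cache_num_layers pre-allocates the per-layer self-attention and memory caches used during incremental decoding. A small usage sketch (tensor contents are placeholders):

import torch

src = torch.randint(0, 100, (2, 16))  # (batch, src_len) token ids
state = TransformerDecoderState(src, cache_num_layers=6)

assert set(state.cache) == {'layer_{}'.format(i) for i in range(6)}
assert state.cache['layer_0']['self_keys'] is None  # filled on the first step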
@@ -403,44 +432,6 @@ class TransformerDecoder(nn.Module):  # Decoder
     """
     decoder_type = 'transformer'

-    class TransformerDecoderState:
-
-        def __init__(self, src):
-            self.src = src
-            self.previous_input = None
-            self.previous_layer_inputs = None
-            self.cache = None
-
-        def update_state(self, new_input, previous_layer_inputs):
-            self.previous_input = new_input
-            self.previous_layer_inputs = previous_layer_inputs
-            self.cache = None
-
-        def _init_cache(self, num_layers):
-            self.cache = {}
-            for num in range(num_layers):
-                layer_cache = {
-                    'memory_keys': None,
-                    'memory_values': None,
-                    'self_keys': None,
-                    'self_values': None
-                }
-                self.cache['layer_{}'.format(num)] = layer_cache
-
-        def map_batch_fn(self, fn):
-
-            def _recursive_map(struct, batch_dim=0):
-                for k, v in struct.items():
-                    if v is not None:
-                        if isinstance(v, dict):
-                            _recursive_map(v)
-                        else:
-                            struct[k] = fn(v, batch_dim)
-
-            self.src = fn(self.src, 0)
-            if self.cache is not None:
-                _recursive_map(self.cache)
-
     def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings):
         super().__init__()

@@ -458,13 +449,13 @@ class TransformerDecoder(nn.Module):  # Decoder
         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
         self.state = None

-    def init_state(self, src, with_cache=False):
-        self.state = self.TransformerDecoderState(src)
-        if with_cache:
-            self.state._init_cache(self.num_layers)
-
-    def forward(self, tgt, memory_bank, step=None, memory_masks=None):
-        src_words = self.state.src
+    def forward(self,
+                state: TransformerDecoderState,
+                tgt: Tensor,
+                memory_bank: Tensor,
+                step: int = None,
+                memory_masks: Tensor = None):
+        src_words = state.src
         tgt_words = tgt
         src_batch, src_len = src_words.size()
         tgt_batch, tgt_len = tgt_words.size()
@@ -487,33 +478,36 @@ class TransformerDecoder(nn.Module):  # Decoder
         src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \
             .expand(src_batch, tgt_len, src_len)

-        if self.state.cache is None:
+        if state.cache is None:
             saved_inputs = []
         attns = []
         for i in range(self.num_layers):
             prev_layer_input = None
-            if self.state.cache is None:
-                if self.state.previous_input is not None:
-                    prev_layer_input = self.state.previous_layer_inputs[i]
+            if state.cache is None:
+                if state.previous_input is not None:
+                    prev_layer_input = state.previous_layer_inputs[i]
             output, attn, all_input \
-                = self.transformer_layers[i](output, src_memory_bank, src_pad_mask, tgt_pad_mask,
-                                             previous_input=prev_layer_input,
-                                             layer_cache=self.state.cache['layer_{}'.format(i)]
-                                             if self.state.cache is not None else None, step=step)
-            if self.state.cache is None:
+                = self.transformer_layers[i](
+                    output, src_memory_bank,
+                    src_pad_mask, tgt_pad_mask,
+                    previous_input=prev_layer_input,
+                    layer_cache=state.cache['layer_{}'.format(i)]
+                    if state.cache is not None else None,
+                    step=step)
+            if state.cache is None:
                 saved_inputs.append(all_input)
             attns.append(attn)

-        if self.state.cache is None:
+        if state.cache is None:
             saved_inputs = torch.stack(saved_inputs)

         output = self.layer_norm(output)

         # Process the result and update the attentions.
-        if self.state.cache is None:
-            self.state.update_state(tgt, saved_inputs)
+        if state.cache is None:
+            state.update_state(tgt, saved_inputs)

-        return output, attns
+        return output, attns, state


 class PalmPointerGenerator(nn.Module):
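
Since forward now takes and returns the state, step-wise generation threads it through each call. Roughly (the generator projection head, shapes, and greedy selection below are assumptions for illustration, not the palm Translator logic):

import torch


def greedy_decode(decoder, generator, src, memory_bank, bos_id, max_steps=8):
    # One request-local state with per-layer caches enabled.
    state = TransformerDecoderState(src, decoder.num_layers)
    tokens = torch.full((src.size(0), 1), bos_id, dtype=torch.long)
    for step in range(max_steps):
        # Feed only the newest token; cached keys/values cover the history.
        dec_out, _, state = decoder(
            state, tokens[:, -1:], memory_bank, step=step)
        log_probs = generator(dec_out[:, -1])  # project hidden -> vocab
        next_ids = log_probs.argmax(-1, keepdim=True)
        tokens = torch.cat([tokens, next_ids], dim=1)
    return tokens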
@@ -570,10 +564,11 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model
         if (config.max_pos > 512):
             my_pos_embeddings = nn.Embedding(
                 config.max_pos, self.bert.model.config.hidden_size)
-            my_pos_embeddings.weight.data[:512] = \
-                self.bert.embeddings.position_embeddings.weight.data
-            my_pos_embeddings.weight.data[512:] = \
-                self.bert.embeddings.position_embeddings.weight.data[-1][None, :].repeat(config.max_pos - 512, 1)
+            my_pos_embeddings.weight.data[:
+                                          512] = self.bert.embeddings.position_embeddings.weight.data
+            my_pos_embeddings.weight.data[
+                512:] = self.bert.embeddings.position_embeddings.weight.data[
+                    -1][None, :].repeat(config.max_pos - 512, 1)
             self.bert.model.embeddings.position_embeddings = my_pos_embeddings
         self.vocab_size = self.bert.config.vocab_size
         tgt_embeddings = nn.Embedding(
@@ -633,8 +628,8 @@ class AbsSummarizer(PalmPreTrainedModel):  # Model

     def forward(self, src, tgt, mask_src):
         top_vec, _ = self.bert(src, mask_src, return_dict=False)
-        self.decoder.init_state(src)
-        decoder_outputs, attns = self.decoder(tgt[:, :-1], top_vec)
+        state = TransformerDecoderState(src)
+        decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec)
         return decoder_outputs, attns[-1], top_vec

@@ -776,9 +771,9 @@ class Translator(nn.Module):
             translation_batch['predictions']))
         batch_size = batch.batch_size

-        preds, pred_score, _, tgt_str, src, src_str = \
-            translation_batch['predictions'], translation_batch['scores'], translation_batch['gold_score'], \
-            batch.tgt_str, batch.src, batch.src_str
+        preds, pred_score, tgt_str, src, src_str = translation_batch[
+            'predictions'], translation_batch[
+                'scores'], batch.tgt_str, batch.src, batch.src_str
         query_id = batch.query_id
         '''
         try:
@@ -903,17 +898,13 @@ class Translator(nn.Module):
                 '</s>', '').replace('<unk>', ' ').strip()
             if (self.args.recall_eval):
                 _pred_str = ''
-                # gap = 1e3
                 for sent in pred_str.split('<q>'):
                     can_pred_str = _pred_str + '<q>' + sent.strip()
-                    # can_gap = math.fabs(len(_pred_str.split()) - len(gold_str.split()))
-                    # if(can_gap>=gap):
                     if len(can_pred_str.split()) >= len(
                             gold_str.split()) + 10:
                         pred_str = _pred_str
                         break
                     else:
-                        # gap = can_gap
                         _pred_str = can_pred_str

             if self.args.dataset == 'marco' or self.args.dataset == 'squad' or self.args.dataset == 'qg_ranking':
@@ -967,10 +958,6 @@ class Translator(nn.Module):
                 pred_str = [pred_str]
             pred_dict[query_id] = normalize([pred_str[0]])
             ref_dict[query_id] = normalize([gold_str])
-            # pred_str_list = [src] + pred_str
-            # self.can_out_file.write("\t".join(pred_str_list)+"\n")
-            # self.can_out_file.write("\t".join(pred_str_list)+"\n")
-            # self.gold_out_file.write("\t".join([src, pred_str[0], gold_str])+"\n")
             self.pred_json_score_out_file.write(
                 '\t'.join([str(query_id), src, gold_str, pred_str[0]])
                 + '\n')
@@ -1027,8 +1014,6 @@ class Translator(nn.Module):
                     preds[idx] = '。'
                 return preds, labels

-            # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
-            # self.logger.info('Dev eval result: {}'.format(bleu_rouge))
             pred_results, gold_results = postprocess_text(
                 pred_results, gold_results)
             pred_dict = {str(i): tmp for i, tmp in enumerate(pred_results)}
@@ -1037,8 +1022,6 @@ class Translator(nn.Module):
             print(bleu_rouge)
             # unreachable
         elif self.args.dataset == 'dureader' or self.args.dataset == 'paraphrase':
-            # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict)
-            # self.logger.info('Dev eval result: {}'.format(bleu_rouge))
             pred_results, gold_results = postprocess_text(
                 pred_results, gold_results)
             bleu_score = cal_bleu(pred_results, gold_results)
@@ -1134,11 +1117,11 @@ class Translator(nn.Module):
         mask_src = batch.mask_src

         src_features, _ = self.model.bert(src, mask_src, return_dict=False)
-        self.model.decoder.init_state(src, with_cache=True)
+        state = TransformerDecoderState(src, self.model.decoder.num_layers)
         device = src_features.device

         # Tile states and memory beam_size times.
-        self.model.decoder.state.map_batch_fn(
+        state.map_batch_fn(
             lambda state, dim: self._tile(state, beam_size, dim=dim))
         src_features = self._tile(src_features, beam_size, dim=0)
         batch_offset = torch.arange(
@@ -1174,8 +1157,8 @@ class Translator(nn.Module):

             # Decoder forward.
             decoder_input = decoder_input.transpose(0, 1)
-            dec_out, attns = self.model.decoder(
-                decoder_input, src_features, step=step)
+            dec_out, attns, state = self.model.decoder(
+                state, decoder_input, src_features, step=step)

             # Generator forward.
             log_probs = self.generator.forward(
@@ -1188,7 +1171,6 @@ class Translator(nn.Module):
             # Multiply probs by the beam probability.

             length_penalty = ((5.0 + (step + 1)) / 6.0)**self.alpha
-            # '''
             if self.args.sample_topk:
                 temperature = self.args.temperature
                 _scores = log_probs / temperature
@@ -1211,13 +1193,11 @@ class Translator(nn.Module):
                 _scores = _scores / length_penalty
                 topk_scores = torch.gather(
                     _scores, -1, topk_ids)  # (batch_size * num_beams, 2)
-                # log_probs += # (batch_size * num_beams, 2)
                 # Match shape of greedy beam search
                 topk_ids = topk_ids.view(
                     -1, beam_size)  # (batch_size, 2 * num_beams)
                 topk_scores = topk_scores.view(
                     -1, beam_size)  # (batch_size, 2 * num_beams)
-                # '''
             else:
                 log_probs += topk_log_probs.view(-1).unsqueeze(1)
                 curr_scores = log_probs / length_penalty
@@ -1231,7 +1211,6 @@ class Translator(nn.Module):
                 fail = False
                 words = [int(w) for w in alive_seq[i]]
                 if self.args.encoder == 'roberta':
-                    # words = [self.vocab.convert_ids_to_tokens[w] for w in words]
                     words = self.vocab.decode(words).strip().split()
                 else:
                     words = [
@@ -1252,7 +1231,6 @@ class Translator(nn.Module):
             topk_log_probs = topk_scores * length_penalty

             # Resolve beam origin and true word ids.
-            # topk_beam_index = topk_ids.div(vocab_size)
             topk_beam_index = topk_ids // vocab_size
             topk_ids = topk_ids.fmod(vocab_size)

@@ -1313,7 +1291,7 @@ class Translator(nn.Module):
             # Reorder states.
             select_indices = batch_index.view(-1)
             src_features = src_features.index_select(0, select_indices)
-            self.model.decoder.state.map_batch_fn(
+            state.map_batch_fn(
                 lambda state, dim: state.index_select(dim, select_indices))

         return results
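
The beam search above relies on map_batch_fn to reshape every cached tensor in one pass: it first tiles the state beam_size times, then reorders it to the surviving hypotheses after pruning. A self-contained sketch of those two calls (tile here stands in for Translator._tile and is an assumption):

import torch


def tile(x, count, dim=0):
    # Repeat each batch entry `count` times along `dim`: [a, b] -> [a, a, b, b].
    return x.repeat_interleave(count, dim=dim)


src = torch.randint(0, 100, (2, 5))
state = TransformerDecoderState(src, cache_num_layers=2)
state.cache['layer_0']['self_keys'] = torch.randn(2, 4, 3, 8)

beam_size = 3
state.map_batch_fn(lambda t, dim: tile(t, beam_size, dim=dim))
assert state.src.size(0) == 2 * beam_size
assert state.cache['layer_0']['self_keys'].size(0) == 2 * beam_size

# After pruning, keep only the selected beams in src and every cache entry.
select_indices = torch.tensor([0, 2, 4])
state.map_batch_fn(lambda t, dim: t.index_select(dim, select_indices))
assert state.src.size(0) == 3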

