From 4fb4947397a097e024ab8cd3163b9c1d25c86842 Mon Sep 17 00:00:00 2001
From: "hemu.zp"
Date: Thu, 4 Aug 2022 23:06:44 +0800
Subject: [PATCH] [to #42322933] Fix palm concurrency
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix the concurrency error raised when the palm model is deployed.

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9649010

* fix palm concurrency
---
 .../models/nlp/palm_v2/modeling_palm.py | 208 ++++++++----------
 1 file changed, 93 insertions(+), 115 deletions(-)

diff --git a/modelscope/models/nlp/palm_v2/modeling_palm.py b/modelscope/models/nlp/palm_v2/modeling_palm.py
index 127b5440..1cbf4f58 100644
--- a/modelscope/models/nlp/palm_v2/modeling_palm.py
+++ b/modelscope/models/nlp/palm_v2/modeling_palm.py
@@ -4,20 +4,19 @@ import math
 import os
 import subprocess
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import json
 import numpy as np
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import Tensor, nn
 from torch.nn.init import xavier_uniform_
 from transformers import (BertConfig, BertModel, BertTokenizer, RobertaConfig,
                           RobertaModel, RobertaTokenizer)
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel

-from modelscope.outputs import OutputKeys
 from modelscope.utils import logger as logging
 from .configuration_palm import PalmConfig
 from .dureader_eval import compute_bleu_rouge, normalize
@@ -142,35 +141,27 @@ class MultiHeadedAttention(nn.Module):
                 # SelfAttention
                 key = shape(key)
                 value = shape(value)
-                if layer_cache is not None:
-                    device = key.device
-                    if layer_cache['self_keys'] is not None:
-                        key = torch.cat(
-                            (layer_cache['self_keys'].to(device), key), dim=2)
-                    if layer_cache['self_values'] is not None:
-                        value = torch.cat(
-                            (layer_cache['self_values'].to(device), value),
-                            dim=2)
-                    layer_cache['self_keys'] = key
-                    layer_cache['self_values'] = value
+                device = key.device
+                if layer_cache['self_keys'] is not None:
+                    key = torch.cat((layer_cache['self_keys'].to(device), key),
+                                    dim=2)
+                if layer_cache['self_values'] is not None:
+                    value = torch.cat(
+                        (layer_cache['self_values'].to(device), value), dim=2)
+                layer_cache['self_keys'] = key
+                layer_cache['self_values'] = value
             elif type == 'context':
                 query = self.linear_query(query)
-                if layer_cache is not None:
-                    if layer_cache['memory_keys'] is None:
-                        key, value = self.linear_keys(key), self.linear_values(
-                            value)
-                        key = shape(key)
-                        value = shape(value)
-                    else:
-                        key, value = layer_cache['memory_keys'], layer_cache[
-                            'memory_values']
-                    layer_cache['memory_keys'] = key
-                    layer_cache['memory_values'] = value
-                else:
+                if layer_cache['memory_keys'] is None:
                     key, value = self.linear_keys(key), self.linear_values(
                         value)
                     key = shape(key)
                     value = shape(value)
+                else:
+                    key, value = layer_cache['memory_keys'], layer_cache[
+                        'memory_values']
+                layer_cache['memory_keys'] = key
+                layer_cache['memory_values'] = value
             else:
                 key = self.linear_keys(key)
                 value = self.linear_values(value)
@@ -372,6 +363,44 @@ class PositionalEncoding(nn.Module):
         return self.pe[:, :emb.size(1)]


+class TransformerDecoderState:
+
+    def __init__(self, src: Tensor, cache_num_layers: int = -1):
+        self.src: Tensor = src
+        self.previous_input: Tensor = None
+        self.previous_layer_inputs: Tensor = None
+        self.cache: Optional[Dict[str, Any]] = None
+        if cache_num_layers != -1:
+            self._init_cache(cache_num_layers)
+
+    def update_state(self, new_input, previous_layer_inputs):
+ self.previous_input = new_input + self.previous_layer_inputs = previous_layer_inputs + self.cache = None + + def _init_cache(self, num_layers): + self.cache = {} + for num in range(num_layers): + layer_cache = {'memory_keys': None, 'memory_values': None} + layer_cache['self_keys'] = None + layer_cache['self_values'] = None + self.cache['layer_{}'.format(num)] = layer_cache + + def map_batch_fn(self, fn): + + def _recursive_map(struct, batch_dim=0): + for k, v in struct.items(): + if v is not None: + if isinstance(v, dict): + _recursive_map(v) + else: + struct[k] = fn(v, batch_dim) + + self.src = fn(self.src, 0) + if self.cache is not None: + _recursive_map(self.cache) + + class TransformerDecoder(nn.Module): # Decoder """ The Transformer decoder from "Attention is All You Need". @@ -403,44 +432,6 @@ class TransformerDecoder(nn.Module): # Decoder """ decoder_type = 'transformer' - class TransformerDecoderState: - - def __init__(self, src): - self.src = src - self.previous_input = None - self.previous_layer_inputs = None - self.cache = None - - def update_state(self, new_input, previous_layer_inputs): - self.previous_input = new_input - self.previous_layer_inputs = previous_layer_inputs - self.cache = None - - def _init_cache(self, num_layers): - self.cache = {} - for num in range(num_layers): - layer_cache = { - 'memory_keys': None, - 'memory_values': None, - 'self_keys': None, - 'self_values': None - } - self.cache['layer_{}'.format(num)] = layer_cache - - def map_batch_fn(self, fn): - - def _recursive_map(struct, batch_dim=0): - for k, v in struct.items(): - if v is not None: - if isinstance(v, dict): - _recursive_map(v) - else: - struct[k] = fn(v, batch_dim) - - self.src = fn(self.src, 0) - if self.cache is not None: - _recursive_map(self.cache) - def __init__(self, num_layers, d_model, heads, d_ff, dropout, embeddings): super().__init__() @@ -458,13 +449,13 @@ class TransformerDecoder(nn.Module): # Decoder self.layer_norm = nn.LayerNorm(d_model, eps=1e-6) self.state = None - def init_state(self, src, with_cache=False): - self.state = self.TransformerDecoderState(src) - if with_cache: - self.state._init_cache(self.num_layers) - - def forward(self, tgt, memory_bank, step=None, memory_masks=None): - src_words = self.state.src + def forward(self, + state: TransformerDecoderState, + tgt: Tensor, + memory_bank: Tensor, + step: int = None, + memory_masks: Tensor = None): + src_words = state.src tgt_words = tgt src_batch, src_len = src_words.size() tgt_batch, tgt_len = tgt_words.size() @@ -487,33 +478,36 @@ class TransformerDecoder(nn.Module): # Decoder src_pad_mask = src_words.data.eq(padding_idx).unsqueeze(1) \ .expand(src_batch, tgt_len, src_len) - if self.state.cache is None: + if state.cache is None: saved_inputs = [] attns = [] for i in range(self.num_layers): prev_layer_input = None - if self.state.cache is None: - if self.state.previous_input is not None: - prev_layer_input = self.state.previous_layer_inputs[i] + if state.cache is None: + if state.previous_input is not None: + prev_layer_input = state.previous_layer_inputs[i] output, attn, all_input \ - = self.transformer_layers[i](output, src_memory_bank, src_pad_mask, tgt_pad_mask, - previous_input=prev_layer_input, - layer_cache=self.state.cache['layer_{}'.format(i)] - if self.state.cache is not None else None, step=step) - if self.state.cache is None: + = self.transformer_layers[i]( + output, src_memory_bank, + src_pad_mask, tgt_pad_mask, + previous_input=prev_layer_input, + layer_cache=state.cache['layer_{}'.format(i)] + if 
state.cache is not None else None, + step=step) + if state.cache is None: saved_inputs.append(all_input) attns.append(attn) - if self.state.cache is None: + if state.cache is None: saved_inputs = torch.stack(saved_inputs) output = self.layer_norm(output) # Process the result and update the attentions. - if self.state.cache is None: - self.state.update_state(tgt, saved_inputs) + if state.cache is None: + state.update_state(tgt, saved_inputs) - return output, attns + return output, attns, state class PalmPointerGenerator(nn.Module): @@ -570,10 +564,11 @@ class AbsSummarizer(PalmPreTrainedModel): # Model if (config.max_pos > 512): my_pos_embeddings = nn.Embedding( config.max_pos, self.bert.model.config.hidden_size) - my_pos_embeddings.weight.data[:512] = \ - self.bert.embeddings.position_embeddings.weight.data - my_pos_embeddings.weight.data[512:] = \ - self.bert.embeddings.position_embeddings.weight.data[-1][None, :].repeat(config.max_pos - 512, 1) + my_pos_embeddings.weight.data[: + 512] = self.bert.embeddings.position_embeddings.weight.data + my_pos_embeddings.weight.data[ + 512:] = self.bert.embeddings.position_embeddings.weight.data[ + -1][None, :].repeat(config.max_pos - 512, 1) self.bert.model.embeddings.position_embeddings = my_pos_embeddings self.vocab_size = self.bert.config.vocab_size tgt_embeddings = nn.Embedding( @@ -633,8 +628,8 @@ class AbsSummarizer(PalmPreTrainedModel): # Model def forward(self, src, tgt, mask_src): top_vec, _ = self.bert(src, mask_src, return_dict=False) - self.decoder.init_state(src) - decoder_outputs, attns = self.decoder(tgt[:, :-1], top_vec) + state = TransformerDecoderState(src) + decoder_outputs, attns, _ = self.decoder(state, tgt[:, :-1], top_vec) return decoder_outputs, attns[-1], top_vec @@ -776,9 +771,9 @@ class Translator(nn.Module): translation_batch['predictions'])) batch_size = batch.batch_size - preds, pred_score, _, tgt_str, src, src_str = \ - translation_batch['predictions'], translation_batch['scores'], translation_batch['gold_score'], \ - batch.tgt_str, batch.src, batch.src_str + preds, pred_score, tgt_str, src, src_str = translation_batch[ + 'predictions'], translation_batch[ + 'scores'], batch.tgt_str, batch.src, batch.src_str query_id = batch.query_id ''' try: @@ -903,17 +898,13 @@ class Translator(nn.Module): '', '').replace('', ' ').strip() if (self.args.recall_eval): _pred_str = '' - # gap = 1e3 for sent in pred_str.split(''): can_pred_str = _pred_str + '' + sent.strip() - # can_gap = math.fabs(len(_pred_str.split()) - len(gold_str.split())) - # if(can_gap>=gap): if len(can_pred_str.split()) >= len( gold_str.split()) + 10: pred_str = _pred_str break else: - # gap = can_gap _pred_str = can_pred_str if self.args.dataset == 'marco' or self.args.dataset == 'squad' or self.args.dataset == 'qg_ranking': @@ -967,10 +958,6 @@ class Translator(nn.Module): pred_str = [pred_str] pred_dict[query_id] = normalize([pred_str[0]]) ref_dict[query_id] = normalize([gold_str]) - # pred_str_list = [src] + pred_str - # self.can_out_file.write("\t".join(pred_str_list)+"\n") - # self.can_out_file.write("\t".join(pred_str_list)+"\n") - # self.gold_out_file.write("\t".join([src, pred_str[0], gold_str])+"\n") self.pred_json_score_out_file.write( '\t'.join([str(query_id), src, gold_str, pred_str[0]]) + '\n') @@ -1027,8 +1014,6 @@ class Translator(nn.Module): preds[idx] = '。' return preds, labels - # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict) - # self.logger.info('Dev eval result: {}'.format(bleu_rouge)) pred_results, gold_results = postprocess_text( 
pred_results, gold_results) pred_dict = {str(i): tmp for i, tmp in enumerate(pred_results)} @@ -1037,8 +1022,6 @@ class Translator(nn.Module): print(bleu_rouge) # unreachable elif self.args.dataset == 'dureader' or self.args.dataset == 'paraphrase': - # bleu_rouge = compute_bleu_rouge(pred_dict, ref_dict) - # self.logger.info('Dev eval result: {}'.format(bleu_rouge)) pred_results, gold_results = postprocess_text( pred_results, gold_results) bleu_score = cal_bleu(pred_results, gold_results) @@ -1134,11 +1117,11 @@ class Translator(nn.Module): mask_src = batch.mask_src src_features, _ = self.model.bert(src, mask_src, return_dict=False) - self.model.decoder.init_state(src, with_cache=True) + state = TransformerDecoderState(src, self.model.decoder.num_layers) device = src_features.device # Tile states and memory beam_size times. - self.model.decoder.state.map_batch_fn( + state.map_batch_fn( lambda state, dim: self._tile(state, beam_size, dim=dim)) src_features = self._tile(src_features, beam_size, dim=0) batch_offset = torch.arange( @@ -1174,8 +1157,8 @@ class Translator(nn.Module): # Decoder forward. decoder_input = decoder_input.transpose(0, 1) - dec_out, attns = self.model.decoder( - decoder_input, src_features, step=step) + dec_out, attns, state = self.model.decoder( + state, decoder_input, src_features, step=step) # Generator forward. log_probs = self.generator.forward( @@ -1188,7 +1171,6 @@ class Translator(nn.Module): # Multiply probs by the beam probability. length_penalty = ((5.0 + (step + 1)) / 6.0)**self.alpha - # ''' if self.args.sample_topk: temperature = self.args.temperature _scores = log_probs / temperature @@ -1211,13 +1193,11 @@ class Translator(nn.Module): _scores = _scores / length_penalty topk_scores = torch.gather( _scores, -1, topk_ids) # (batch_size * num_beams, 2) - # log_probs += # (batch_size * num_beams, 2) # Match shape of greedy beam search topk_ids = topk_ids.view( -1, beam_size) # (batch_size, 2 * num_beams) topk_scores = topk_scores.view( -1, beam_size) # (batch_size, 2 * num_beams) - # ''' else: log_probs += topk_log_probs.view(-1).unsqueeze(1) curr_scores = log_probs / length_penalty @@ -1231,7 +1211,6 @@ class Translator(nn.Module): fail = False words = [int(w) for w in alive_seq[i]] if self.args.encoder == 'roberta': - # words = [self.vocab.convert_ids_to_tokens[w] for w in words] words = self.vocab.decode(words).strip().split() else: words = [ @@ -1252,7 +1231,6 @@ class Translator(nn.Module): topk_log_probs = topk_scores * length_penalty # Resolve beam origin and true word ids. - # topk_beam_index = topk_ids.div(vocab_size) topk_beam_index = topk_ids // vocab_size topk_ids = topk_ids.fmod(vocab_size) @@ -1313,7 +1291,7 @@ class Translator(nn.Module): # Reorder states. select_indices = batch_index.view(-1) src_features = src_features.index_select(0, select_indices) - self.model.decoder.state.map_batch_fn( + state.map_batch_fn( lambda state, dim: state.index_select(dim, select_indices)) return results
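
Note for reviewers (not part of the patch): the root cause of the concurrency failure was that TransformerDecoder kept per-request decoding state on a shared attribute (self.state, filled by init_state() and mutated inside forward()), so simultaneous requests against one deployed model overwrote each other's incremental cache. The patch moves that state into a TransformerDecoderState object that is created per call, passed explicitly through decoder(state, tgt, memory_bank, ...), and returned alongside the outputs. The snippet below is a minimal, self-contained sketch of that pattern only; ToyDecoder and _PerCallState are illustrative stand-ins, not the PALM classes touched above.

import threading

import torch
from torch import Tensor, nn


class _PerCallState:
    """Illustrative stand-in for TransformerDecoderState: per-request tensors."""

    def __init__(self, src: Tensor):
        self.src = src
        self.cache = {}


class ToyDecoder(nn.Module):
    """Stateless module: mutable state travels in the `state` argument,
    mirroring the patched `TransformerDecoder.forward(state, tgt, ...)`."""

    def __init__(self, d_model: int = 8):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, state: _PerCallState, tgt: Tensor):
        # Read/write the per-call state only; nothing is stored on `self`,
        # so concurrent callers cannot clobber each other's cache.
        state.cache['last_tgt_len'] = tgt.size(1)
        return self.proj(tgt) + state.src.mean(), state


def run_one_request(decoder: ToyDecoder):
    src = torch.randn(1, 4, 8)
    tgt = torch.randn(1, 3, 8)
    state = _PerCallState(src)        # fresh state per request
    out, state = decoder(state, tgt)  # state is threaded through the call
    assert state.cache['last_tgt_len'] == 3
    assert out.shape == (1, 3, 8)


if __name__ == '__main__':
    decoder = ToyDecoder()
    threads = [
        threading.Thread(target=run_one_request, args=(decoder, ))
        for _ in range(8)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

Because the decoder instance itself no longer holds request-specific tensors, a single model can serve several inference threads at once; each beam-search call builds its own state via TransformerDecoderState(src, self.model.decoder.num_layers) exactly as in the patched code above, and the layer-wise KV cache lives only in that per-call state.cache dict.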