|
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
-
- """
- Top-k sampling for autoregressive text generation.
- """
-
- import numpy as np
- import mindspore.common.dtype as mstype
- from mindspore.common.tensor import Tensor
-
def generate(model, origin_inputs, seq_length, end_token=50256, top_k=5):
    """
    Extend a prompt by top-k sampling, one token per model call.

    Only the first row of ``origin_inputs`` is extended. Any other rows are
    fed to the model unchanged and are not returned. Generation stops when
    ``end_token`` is sampled (it is not appended) or when all ``seq_length``
    positions are filled. Sampling uses numpy's global RNG
    (``np.random.choice``), so seed ``np.random`` for reproducible output.

    Inputs:
        model: callable taking an int32 ``Tensor`` of shape
            (batch, seq_length) and returning an object with ``asnumpy()``.
            Its values must be reshapeable to (batch, seq_length, vocab).
            The scores are used directly as sampling weights, so they must
            be non-negative. NOTE(review): presumably the model returns
            softmax probabilities -- confirm.
        origin_inputs: 2-D integer numpy array (batch, prompt_length) holding
            the prompt token ids.
        seq_length: fixed sequence length the model expects. The prompt is
            right-padded with 0 up to this length.
        end_token: token id that terminates generation.
        top_k: number of highest-scoring candidates to sample from at each
            step (default 5).

    Returns:
        outputs: 1-D numpy array holding the prompt ids of row 0 followed by
            the generated ids.

    Raises:
        ValueError: if ``top_k`` is not positive or the prompt is longer than
            ``seq_length``.
    """
    if top_k < 1:
        raise ValueError(f"top_k must be a positive integer, got {top_k}")
    bs, valid_length = origin_inputs.shape
    pad_length = seq_length - valid_length
    if pad_length < 0:
        raise ValueError(
            f"prompt length {valid_length} exceeds seq_length {seq_length}")
    # Right-pad with 0 so every model call sees a fixed (bs, seq_length) input.
    input_ids = np.pad(origin_inputs, ((0, 0), (0, pad_length)),
                       'constant', constant_values=(0, 0))
    while valid_length < seq_length:
        inputs = Tensor(input_ids, mstype.int32)
        logits = model(inputs).asnumpy()
        logits = logits.reshape(bs, seq_length, -1)
        # Scores at the last filled position of row 0 drive the next token.
        probs = logits[0, valid_length - 1, :]
        # Indices of the top_k highest scores, best first.
        p_args = probs.argsort()[::-1][:top_k]
        # Renormalise in float64 so np.random.choice's sum-to-one check is
        # not tripped by float32 rounding.
        p = probs[p_args].astype(np.float64)
        p = p / p.sum()
        next_token = p_args[np.random.choice(len(p), p=p)]
        if next_token == end_token:
            break
        input_ids[0, valid_length] = next_token
        valid_length += 1
    # valid_length counts exactly the filled positions of row 0. Counting
    # non-zero ids instead would truncate at a genuine token id 0 and would
    # also count tokens from the other batch rows.
    return input_ids[0, :valid_length]
|