| @@ -98,7 +98,9 @@ class CodeGeeXForCodeTranslation(TorchModel): | |||||
| generated_code = tokenizer.decode_code( | generated_code = tokenizer.decode_code( | ||||
| generated_tokens_[n_token_prompt:]) | generated_tokens_[n_token_prompt:]) | ||||
| generated_code = ''.join(generated_code) | generated_code = ''.join(generated_code) | ||||
| logger.info('================================= Generated code:') | |||||
| logger.info( | |||||
| '================================= Generated code:' | |||||
| ) | |||||
| logger.info(generated_code) | logger.info(generated_code) | ||||
| if all(is_finished): | if all(is_finished): | ||||
| break | break | ||||
| @@ -1,8 +1,9 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | # Copyright (c) 2022 Zhipu.AI | ||||
| from typing import List | |||||
| import torch | import torch | ||||
| import torch.nn.functional as F | import torch.nn.functional as F | ||||
| from typing import List | |||||
| def get_ltor_masks_and_position_ids( | def get_ltor_masks_and_position_ids( | ||||
| @@ -124,7 +125,7 @@ def pad_batch(batch, pad_id, seq_length): | |||||
| tokens.extend([pad_id] * (seq_length - context_length)) | tokens.extend([pad_id] * (seq_length - context_length)) | ||||
| context_lengths.append(context_length) | context_lengths.append(context_length) | ||||
| return batch, context_lengths | return batch, context_lengths | ||||
| def get_token_stream( | def get_token_stream( | ||||
| model, | model, | ||||
| @@ -1,8 +1,9 @@ | |||||
| # Copyright (c) 2022 Zhipu.AI | # Copyright (c) 2022 Zhipu.AI | ||||
| from typing import List, Union | |||||
| import torch | import torch | ||||
| from transformers import AutoTokenizer | from transformers import AutoTokenizer | ||||
| from transformers.models.gpt2 import GPT2TokenizerFast | from transformers.models.gpt2 import GPT2TokenizerFast | ||||
| from typing import List, Union | |||||
| def encode_whitespaces(text, start_extra_id: int, max_len: int): | def encode_whitespaces(text, start_extra_id: int, max_len: int): | ||||
| @@ -28,9 +28,9 @@ class CodeGeeXCodeTranslationPipeline(Pipeline): | |||||
| self.model.cuda() | self.model.cuda() | ||||
| super().__init__(model=model, **kwargs) | super().__init__(model=model, **kwargs) | ||||
| def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | ||||
| return inputs | |||||
| return inputs | |||||
| # define the forward pass | # define the forward pass | ||||
| def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]: | def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]: | ||||
| @@ -17,10 +17,7 @@ class CodeGeeXCodeTranslationTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') | ||||
| def test_run_with_CodeGeeX_with_name(self): | def test_run_with_CodeGeeX_with_name(self): | ||||
| model = 'ZhipuAI/CodeGeeX-Code-Translation-13B' | model = 'ZhipuAI/CodeGeeX-Code-Translation-13B' | ||||
| pipe = pipeline( | |||||
| task=Tasks.code_translation, | |||||
| model=model | |||||
| ) | |||||
| pipe = pipeline(task=Tasks.code_translation, model=model) | |||||
| inputs = { | inputs = { | ||||
| 'prompt': 'for i in range(10):\n\tprint(i)\n', | 'prompt': 'for i in range(10):\n\tprint(i)\n', | ||||
| 'source language': 'Python', | 'source language': 'Python', | ||||