Browse Source

updated

master^2
shuaigezhu 3 years ago
parent
commit
1ab8a1f764
5 changed files with 11 additions and 10 deletions
  1. +3
    -1
      modelscope/models/nlp/codegeex/codegeex_for_code_translation.py
  2. +3
    -2
      modelscope/models/nlp/codegeex/inference.py
  3. +2
    -1
      modelscope/models/nlp/codegeex/tokenizer.py
  4. +2
    -2
      modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py
  5. +1
    -4
      tests/pipelines/test_CodeGeeX_code_translation.py

+ 3
- 1
modelscope/models/nlp/codegeex/codegeex_for_code_translation.py View File

@@ -98,7 +98,9 @@ class CodeGeeXForCodeTranslation(TorchModel):
generated_code = tokenizer.decode_code(
generated_tokens_[n_token_prompt:])
generated_code = ''.join(generated_code)
logger.info('================================= Generated code:')
logger.info(
'================================= Generated code:'
)
logger.info(generated_code)
if all(is_finished):
break


+ 3
- 2
modelscope/models/nlp/codegeex/inference.py View File

@@ -1,8 +1,9 @@
# Copyright (c) 2022 Zhipu.AI

from typing import List

import torch
import torch.nn.functional as F
from typing import List


def get_ltor_masks_and_position_ids(
@@ -124,7 +125,7 @@ def pad_batch(batch, pad_id, seq_length):
tokens.extend([pad_id] * (seq_length - context_length))
context_lengths.append(context_length)
return batch, context_lengths

def get_token_stream(
model,


+ 2
- 1
modelscope/models/nlp/codegeex/tokenizer.py View File

@@ -1,8 +1,9 @@
# Copyright (c) 2022 Zhipu.AI
from typing import List, Union

import torch
from transformers import AutoTokenizer
from transformers.models.gpt2 import GPT2TokenizerFast
from typing import List, Union


def encode_whitespaces(text, start_extra_id: int, max_len: int):


+ 2
- 2
modelscope/pipelines/nlp/codegeex_code_translation_pipeline.py View File

@@ -28,9 +28,9 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
self.model.cuda()

super().__init__(model=model, **kwargs)
def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
return inputs
return inputs

# define the forward pass
def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:


+ 1
- 4
tests/pipelines/test_CodeGeeX_code_translation.py View File

@@ -17,10 +17,7 @@ class CodeGeeXCodeTranslationTest(unittest.TestCase, DemoCompatibilityCheck):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_CodeGeeX_with_name(self):
model = 'ZhipuAI/CodeGeeX-Code-Translation-13B'
pipe = pipeline(
task=Tasks.code_translation,
model=model
)
pipe = pipeline(task=Tasks.code_translation, model=model)
inputs = {
'prompt': 'for i in range(10):\n\tprint(i)\n',
'source language': 'Python',


Loading…
Cancel
Save