| @@ -257,6 +257,7 @@ class Pipelines(object): | |||
| feature_extraction = 'feature-extraction' | |||
| mglm_text_summarization = 'mglm-text-summarization' | |||
| codegeex_code_translation = 'codegeex-code-translation' | |||
| codegeex_code_generation = 'codegeex-code-generation' | |||
| translation_en_to_de = 'translation_en_to_de' # keep it underscore | |||
| translation_en_to_ro = 'translation_en_to_ro' # keep it underscore | |||
| translation_en_to_fr = 'translation_en_to_fr' # keep it underscore | |||
| @@ -384,7 +385,6 @@ class Preprocessors(object): | |||
| document_segmentation = 'document-segmentation' | |||
| feature_extraction = 'feature-extraction' | |||
| mglm_summarization = 'mglm-summarization' | |||
| codegeex = 'codegeex' | |||
| sentence_piece = 'sentence-piece' | |||
| # audio preprocessor | |||
| @@ -36,7 +36,7 @@ if TYPE_CHECKING: | |||
| ) | |||
| from .T5 import T5ForConditionalGeneration | |||
| from .mglm import MGLMForTextSummarization | |||
| from .codegeex import CodeGeeXForCodeTranslation | |||
| from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration | |||
| from .task_models import ( | |||
| FeatureExtractionModel, | |||
| InformationExtractionModel, | |||
| @@ -109,7 +109,7 @@ else: | |||
| 'sentence_embedding': ['SentenceEmbedding'], | |||
| 'T5': ['T5ForConditionalGeneration'], | |||
| 'mglm': ['MGLMForTextSummarization'], | |||
| 'codegeex': ['CodeGeeXForCodeTranslation'], | |||
| 'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | |||
| 'gpt_neo': ['GPTNeoModel'], | |||
| 'bloom': ['BloomModel'], | |||
| } | |||
| @@ -6,9 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule | |||
| if TYPE_CHECKING: | |||
| from .codegeex_for_code_translation import CodeGeeXForCodeTranslation | |||
| from .codegeex_for_code_generation import CodeGeeXForCodeGeneration | |||
| else: | |||
| _import_structure = { | |||
| 'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'], | |||
| 'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'], | |||
| } | |||
| import sys | |||
| @@ -0,0 +1,111 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| import copy | |||
| from typing import Any, Dict | |||
| import torch | |||
| from modelscope.metainfo import Models | |||
| from modelscope.models.base import TorchModel | |||
| from modelscope.models.builder import MODELS | |||
| from modelscope.outputs import OutputKeys | |||
| from modelscope.utils.constant import Tasks | |||
| from modelscope.utils.logger import get_logger | |||
| from .codegeex import CodeGeeXModel | |||
| from .inference import get_token_stream | |||
| from .tokenizer import CodeGeeXTokenizer | |||
| def model_provider(): | |||
| """Build the model.""" | |||
| hidden_size = 5120 | |||
| num_attention_heads = 40 | |||
| num_layers = 39 | |||
| padded_vocab_size = 52224 | |||
| max_position_embeddings = 2048 | |||
| model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads, | |||
| padded_vocab_size, max_position_embeddings) | |||
| return model | |||
| @MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex) | |||
| class CodeGeeXForCodeGeneration(TorchModel): | |||
| def __init__(self, model_dir: str, *args, **kwargs): | |||
| """initialize the fast poem model from the `model_dir` path. | |||
| Args: | |||
| model_dir (str): the model path. | |||
| """ | |||
| super().__init__(model_dir, *args, **kwargs) | |||
| logger = get_logger() | |||
| # loading tokenizer | |||
| logger.info('Loading tokenizer ...') | |||
| self.tokenizer = CodeGeeXTokenizer( | |||
| tokenizer_path=model_dir + '/tokenizer', mode='codegeex-13b') | |||
| # loading model | |||
| state_dict_path = model_dir + '/ckpt_ms_213000_fp32_52224.pt' | |||
| logger.info('Loading state dict ...') | |||
| state_dict = torch.load(state_dict_path, map_location='cpu') | |||
| state_dict = state_dict['module'] | |||
| logger.info('Building CodeGeeX model ...') | |||
| self.model = model_provider() | |||
| self.model.load_state_dict(state_dict) | |||
| self.model.eval() | |||
| self.model.half() | |||
| self.model.cuda() | |||
| def forward(self, input: Dict[str, str]) -> Dict[str, str]: | |||
| micro_batch_size = 1 | |||
| seq_length = 2048 | |||
| out_seq_length = 256 | |||
| bad_ids = None | |||
| lang = input['language'] | |||
| prompt = input['prompt'] | |||
| prompt = f"# language: {lang}\n{prompt}" | |||
| logger = get_logger() | |||
| tokenizer = self.tokenizer | |||
| model = self.model | |||
| for prompt in [prompt]: | |||
| tokens = tokenizer.encode_code(prompt) | |||
| n_token_prompt = len(tokens) | |||
| token_stream = get_token_stream( | |||
| model, | |||
| tokenizer, | |||
| seq_length, | |||
| out_seq_length, | |||
| [copy.deepcopy(tokens) for _ in range(micro_batch_size)], | |||
| micro_batch_size=micro_batch_size, | |||
| bad_ids=bad_ids, | |||
| topk=1, | |||
| topp=0.9, | |||
| temperature=0.9, | |||
| greedy=True | |||
| ) | |||
| is_finished = [False for _ in range(micro_batch_size)] | |||
| for i, generated in enumerate(token_stream): | |||
| generated_tokens = generated[0] | |||
| for j in range(micro_batch_size): | |||
| if is_finished[j]: | |||
| continue | |||
| if generated_tokens[j].cpu().numpy( | |||
| )[-1] == tokenizer.eos_token_id or len( | |||
| generated_tokens[j]) >= out_seq_length: | |||
| is_finished[j] = True | |||
| generated_tokens_ = generated_tokens[j].cpu().numpy( | |||
| ).tolist() | |||
| generated_code = tokenizer.decode_code( | |||
| generated_tokens_[n_token_prompt:]) | |||
| generated_code = ''.join(generated_code) | |||
| logger.info( | |||
| '================================= Generated code:' | |||
| ) | |||
| logger.info(generated_code) | |||
| if all(is_finished): | |||
| break | |||
| logger.info('Generation finished.') | |||
| return {OutputKeys.TEXT: generated_code} | |||
| @@ -33,6 +33,7 @@ if TYPE_CHECKING: | |||
| from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline | |||
| from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline | |||
| from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline | |||
| from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline | |||
| from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \ | |||
| WordSegmentationThaiPipeline | |||
| @@ -76,6 +77,8 @@ else: | |||
| 'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'], | |||
| 'codegeex_code_translation_pipeline': | |||
| ['CodeGeeXCodeTranslationPipeline'], | |||
| 'codegeex_code_generation_pipeline': | |||
| ['CodeGeeXCodeGenerationPipeline'], | |||
| 'multilingual_word_segmentation_pipeline': [ | |||
| 'MultilingualWordSegmentationPipeline', | |||
| 'WordSegmentationThaiPipeline' | |||
| @@ -0,0 +1,48 @@ | |||
| # Copyright (c) 2022 Zhipu.AI | |||
| from typing import Any, Dict, Union | |||
| from modelscope.metainfo import Pipelines | |||
| from modelscope.models.nlp import CodeGeeXForCodeGeneration | |||
| from modelscope.pipelines.base import Pipeline | |||
| from modelscope.pipelines.builder import PIPELINES | |||
| from modelscope.preprocessors import Preprocessor | |||
| from modelscope.utils.constant import Tasks | |||
| @PIPELINES.register_module( | |||
| group_key=Tasks.code_generation, | |||
| module_name=Pipelines.codegeex_code_generation) | |||
| class CodeGeeXCodeGenerationPipeline(Pipeline): | |||
| def __init__(self, | |||
| model: Union[CodeGeeXForCodeGeneration, str], | |||
| preprocessor: [Preprocessor] = None, | |||
| *args, | |||
| **kwargs): | |||
| model = CodeGeeXForCodeGeneration(model) if isinstance(model, | |||
| str) else model | |||
| self.model = model | |||
| self.model.eval() | |||
| self.model.half() | |||
| self.model.cuda() | |||
| super().__init__(model=model, **kwargs) | |||
| def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]: | |||
| return inputs | |||
| # define the forward pass | |||
| def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]: | |||
| # check input format | |||
| for para in ['prompt', 'language']: | |||
| if para not in inputs: | |||
| raise Exception('Please check your input format.') | |||
| if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa | |||
| raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa | |||
| return self.model(inputs) | |||
| # format the outputs from pipeline | |||
| def postprocess(self, input, **kwargs) -> Dict[str, Any]: | |||
| return input | |||
| @@ -38,6 +38,12 @@ class CodeGeeXCodeTranslationPipeline(Pipeline): | |||
| for para in ['prompt', 'source language', 'target language']: | |||
| if para not in inputs: | |||
| raise Exception('please check your input format.') | |||
| if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa | |||
| raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa | |||
| if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa | |||
| raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa | |||
| return self.model(inputs) | |||
| # format the outputs from pipeline | |||
| @@ -121,6 +121,7 @@ class NLPTasks(object): | |||
| text_summarization = 'text-summarization' | |||
| question_answering = 'question-answering' | |||
| code_translation = 'code-translation' | |||
| code_generation = 'code-generation' | |||
| zero_shot_classification = 'zero-shot-classification' | |||
| backbone = 'backbone' | |||
| text_error_correction = 'text-error-correction' | |||