@@ -257,6 +257,7 @@ class Pipelines(object):
     feature_extraction = 'feature-extraction'
     mglm_text_summarization = 'mglm-text-summarization'
     codegeex_code_translation = 'codegeex-code-translation'
+    codegeex_code_generation = 'codegeex-code-generation'
     translation_en_to_de = 'translation_en_to_de'  # keep it underscore
     translation_en_to_ro = 'translation_en_to_ro'  # keep it underscore
     translation_en_to_fr = 'translation_en_to_fr'  # keep it underscore
@@ -384,7 +385,6 @@ class Preprocessors(object):
     document_segmentation = 'document-segmentation'
     feature_extraction = 'feature-extraction'
     mglm_summarization = 'mglm-summarization'
-    codegeex = 'codegeex'
     sentence_piece = 'sentence-piece'

     # audio preprocessor
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
     )
     from .T5 import T5ForConditionalGeneration
     from .mglm import MGLMForTextSummarization
-    from .codegeex import CodeGeeXForCodeTranslation
+    from .codegeex import CodeGeeXForCodeTranslation, CodeGeeXForCodeGeneration
     from .task_models import (
         FeatureExtractionModel,
         InformationExtractionModel,
@@ -109,7 +109,7 @@ else:
         'sentence_embedding': ['SentenceEmbedding'],
         'T5': ['T5ForConditionalGeneration'],
         'mglm': ['MGLMForTextSummarization'],
-        'codegeex': ['CodeGeeXForCodeTranslation'],
+        'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'],
         'gpt_neo': ['GPTNeoModel'],
         'bloom': ['BloomModel'],
     }
@@ -6,9 +6,11 @@ from modelscope.utils.import_utils import LazyImportModule
 if TYPE_CHECKING:
     from .codegeex_for_code_translation import CodeGeeXForCodeTranslation
+    from .codegeex_for_code_generation import CodeGeeXForCodeGeneration
 else:
     _import_structure = {
         'codegeex_for_code_translation': ['CodeGeeXForCodeTranslation'],
+        'codegeex_for_code_generation': ['CodeGeeXForCodeGeneration'],
     }

     import sys
@@ -0,0 +1,111 @@
+# Copyright (c) 2022 Zhipu.AI
+import copy
+import os
+from typing import Any, Dict
+
+import torch
+
+from modelscope.metainfo import Models
+from modelscope.models.base import TorchModel
+from modelscope.models.builder import MODELS
+from modelscope.outputs import OutputKeys
+from modelscope.utils.constant import Tasks
+from modelscope.utils.logger import get_logger
+from .codegeex import CodeGeeXModel
+from .inference import get_token_stream
+from .tokenizer import CodeGeeXTokenizer
+
+
+def model_provider():
+    """Build the 13B CodeGeeX model."""
+    hidden_size = 5120
+    num_attention_heads = 40
+    num_layers = 39
+    padded_vocab_size = 52224
+    max_position_embeddings = 2048
+    model = CodeGeeXModel(hidden_size, num_layers, num_attention_heads,
+                          padded_vocab_size, max_position_embeddings)
+    return model
+
+
+@MODELS.register_module(Tasks.code_generation, module_name=Models.codegeex)
+class CodeGeeXForCodeGeneration(TorchModel):
+
+    def __init__(self, model_dir: str, *args, **kwargs):
+        """Initialize the CodeGeeX code generation model from `model_dir`.
+
+        Args:
+            model_dir (str): the model path.
+        """
+        super().__init__(model_dir, *args, **kwargs)
+        logger = get_logger()
+        # loading tokenizer
+        logger.info('Loading tokenizer ...')
+        self.tokenizer = CodeGeeXTokenizer(
+            tokenizer_path=os.path.join(model_dir, 'tokenizer'),
+            mode='codegeex-13b')
+        # loading model
+        state_dict_path = os.path.join(model_dir,
+                                       'ckpt_ms_213000_fp32_52224.pt')
+        logger.info('Loading state dict ...')
+        state_dict = torch.load(state_dict_path, map_location='cpu')
+        state_dict = state_dict['module']
+        logger.info('Building CodeGeeX model ...')
+        self.model = model_provider()
+        self.model.load_state_dict(state_dict)
+        self.model.eval()
+        self.model.half()
+        self.model.cuda()
+
+    def forward(self, input: Dict[str, str]) -> Dict[str, str]:
+        micro_batch_size = 1
+        seq_length = 2048
+        out_seq_length = 256
+        bad_ids = None
+        lang = input['language']
+        prompt = input['prompt']
+        # prepend a language tag so the model knows the target language
+        prompt = f'# language: {lang}\n{prompt}'
+        logger = get_logger()
+        tokenizer = self.tokenizer
+        model = self.model
+        tokens = tokenizer.encode_code(prompt)
+        n_token_prompt = len(tokens)
+        token_stream = get_token_stream(
+            model,
+            tokenizer,
+            seq_length,
+            out_seq_length,
+            [copy.deepcopy(tokens) for _ in range(micro_batch_size)],
+            micro_batch_size=micro_batch_size,
+            bad_ids=bad_ids,
+            topk=1,
+            topp=0.9,
+            temperature=0.9,
+            greedy=True)
+        is_finished = [False for _ in range(micro_batch_size)]
+        for generated in token_stream:
+            generated_tokens = generated[0]
+            for j in range(micro_batch_size):
+                if is_finished[j]:
+                    continue
+                # stop when EOS is emitted or the length budget is reached
+                if (generated_tokens[j].cpu().numpy()[-1]
+                        == tokenizer.eos_token_id
+                        or len(generated_tokens[j]) >= out_seq_length):
+                    is_finished[j] = True
+                    generated_tokens_ = generated_tokens[j].cpu().numpy(
+                    ).tolist()
+                    # strip the prompt tokens, keep only the completion
+                    generated_code = tokenizer.decode_code(
+                        generated_tokens_[n_token_prompt:])
+                    generated_code = ''.join(generated_code)
+                    logger.info(
+                        '================================= Generated code:')
+                    logger.info(generated_code)
+            if all(is_finished):
+                break
+        logger.info('Generation finished.')
+        return {OutputKeys.TEXT: generated_code}
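Note: the new `CodeGeeXForCodeGeneration` class consumes a dict with `language` and `prompt` keys and returns the completion under `OutputKeys.TEXT`. A minimal direct-usage sketch, assuming a local checkpoint directory laid out as above (with `tokenizer/` and `ckpt_ms_213000_fp32_52224.pt`) and an available CUDA device; the path is a placeholder:

```python
# Direct-usage sketch for the model class added above.
# '/path/to/codegeex' is a placeholder, not a real checkpoint location.
from modelscope.models.nlp import CodeGeeXForCodeGeneration
from modelscope.outputs import OutputKeys

model = CodeGeeXForCodeGeneration('/path/to/codegeex')
result = model({
    'language': 'Python',
    'prompt': '# return the n-th Fibonacci number\ndef fib(n):',
})
print(result[OutputKeys.TEXT])
```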
@@ -33,6 +33,7 @@ if TYPE_CHECKING:
     from .zero_shot_classification_pipeline import ZeroShotClassificationPipeline
     from .mglm_text_summarization_pipeline import MGLMTextSummarizationPipeline
     from .codegeex_code_translation_pipeline import CodeGeeXCodeTranslationPipeline
+    from .codegeex_code_generation_pipeline import CodeGeeXCodeGenerationPipeline
     from .multilingual_word_segmentation_pipeline import MultilingualWordSegmentationPipeline, \
         WordSegmentationThaiPipeline
@@ -76,6 +77,8 @@ else:
         'mglm_text_summarization_pipeline': ['MGLMTextSummarizationPipeline'],
         'codegeex_code_translation_pipeline':
         ['CodeGeeXCodeTranslationPipeline'],
+        'codegeex_code_generation_pipeline':
+        ['CodeGeeXCodeGenerationPipeline'],
         'multilingual_word_segmentation_pipeline': [
             'MultilingualWordSegmentationPipeline',
             'WordSegmentationThaiPipeline'
@@ -0,0 +1,48 @@
+# Copyright (c) 2022 Zhipu.AI
+from typing import Any, Dict, Optional, Union
+
+from modelscope.metainfo import Pipelines
+from modelscope.models.nlp import CodeGeeXForCodeGeneration
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import Preprocessor
+from modelscope.utils.constant import Tasks
+
+SUPPORTED_LANGUAGES = [
+    'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', 'Python',
+    'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', 'TypeScript', 'Go',
+    'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', 'Pascal', 'R', 'Fortran', 'Lean'
+]
+
+
+@PIPELINES.register_module(
+    group_key=Tasks.code_generation,
+    module_name=Pipelines.codegeex_code_generation)
+class CodeGeeXCodeGenerationPipeline(Pipeline):
+
+    def __init__(self,
+                 model: Union[CodeGeeXForCodeGeneration, str],
+                 preprocessor: Optional[Preprocessor] = None,
+                 *args,
+                 **kwargs):
+        model = CodeGeeXForCodeGeneration(model) if isinstance(model,
+                                                               str) else model
+        self.model = model
+        self.model.eval()
+        self.model.half()
+        self.model.cuda()
+        super().__init__(model=model, **kwargs)
+
+    def preprocess(self, inputs, **preprocess_params) -> Dict[str, Any]:
+        return inputs
+
+    # define the forward pass
+    def forward(self, inputs: Union[Dict], **forward_params) -> Dict[str, Any]:
+        # check input format
+        for para in ['prompt', 'language']:
+            if para not in inputs:
+                raise ValueError(f'Input is missing the "{para}" field.')
+        if inputs['language'] not in SUPPORTED_LANGUAGES:
+            raise ValueError(
+                f'Make sure the language is one of {SUPPORTED_LANGUAGES}.')
+        return self.model(inputs)
+
+    # format the outputs from pipeline
+    def postprocess(self, input, **kwargs) -> Dict[str, Any]:
+        return input
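With the registration above, the new pipeline is reachable through the standard `pipeline()` factory. A usage sketch; the model id below is a placeholder, not a confirmed ModelScope hub id:

```python
# Usage sketch for CodeGeeXCodeGenerationPipeline via the pipeline factory.
# 'ZhipuAI/codegeex-code-generation' is a hypothetical model id.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pipe = pipeline(Tasks.code_generation, model='ZhipuAI/codegeex-code-generation')
result = pipe({'prompt': 'def quick_sort(arr):', 'language': 'Python'})
print(result)
```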
@@ -38,6 +38,12 @@ class CodeGeeXCodeTranslationPipeline(Pipeline):
         for para in ['prompt', 'source language', 'target language']:
             if para not in inputs:
                 raise Exception('please check your input format.')
+        if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
+            raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa
+        if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]:  # noqa
+            raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]')  # noqa
         return self.model(inputs)

     # format the outputs from pipeline
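Note that the translation pipeline's input keys contain spaces ('source language', 'target language'), matching the checks added above. A sketch with a placeholder model id:

```python
# Input-format sketch for CodeGeeXCodeTranslationPipeline.
# 'ZhipuAI/codegeex-code-translation' is a hypothetical model id.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

translator = pipeline(Tasks.code_translation, model='ZhipuAI/codegeex-code-translation')
result = translator({
    'prompt': 'def add(a, b):\n    return a + b',
    'source language': 'Python',
    'target language': 'C++',
})
print(result)
```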
@@ -121,6 +121,7 @@ class NLPTasks(object):
     text_summarization = 'text-summarization'
     question_answering = 'question-answering'
     code_translation = 'code-translation'
+    code_generation = 'code-generation'
     zero_shot_classification = 'zero-shot-classification'
     backbone = 'backbone'
     text_error_correction = 'text-error-correction'