| @@ -109,7 +109,8 @@ else: | |||||
| 'sentence_embedding': ['SentenceEmbedding'], | 'sentence_embedding': ['SentenceEmbedding'], | ||||
| 'T5': ['T5ForConditionalGeneration'], | 'T5': ['T5ForConditionalGeneration'], | ||||
| 'mglm': ['MGLMForTextSummarization'], | 'mglm': ['MGLMForTextSummarization'], | ||||
| 'codegeex': ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | |||||
| 'codegeex': | |||||
| ['CodeGeeXForCodeTranslation', 'CodeGeeXForCodeGeneration'], | |||||
| 'gpt_neo': ['GPTNeoModel'], | 'gpt_neo': ['GPTNeoModel'], | ||||
| 'bloom': ['BloomModel'], | 'bloom': ['BloomModel'], | ||||
| } | } | ||||
| @@ -65,7 +65,7 @@ class CodeGeeXForCodeGeneration(TorchModel): | |||||
| bad_ids = None | bad_ids = None | ||||
| lang = input['language'] | lang = input['language'] | ||||
| prompt = input['prompt'] | prompt = input['prompt'] | ||||
| prompt = f"# language: {lang}\n{prompt}" | |||||
| prompt = f'# language: {lang}\n{prompt}' | |||||
| logger = get_logger() | logger = get_logger() | ||||
| tokenizer = self.tokenizer | tokenizer = self.tokenizer | ||||
| model = self.model | model = self.model | ||||
| @@ -83,8 +83,7 @@ class CodeGeeXForCodeGeneration(TorchModel): | |||||
| topk=1, | topk=1, | ||||
| topp=0.9, | topp=0.9, | ||||
| temperature=0.9, | temperature=0.9, | ||||
| greedy=True | |||||
| ) | |||||
| greedy=True) | |||||
| is_finished = [False for _ in range(micro_batch_size)] | is_finished = [False for _ in range(micro_batch_size)] | ||||
| for i, generated in enumerate(token_stream): | for i, generated in enumerate(token_stream): | ||||
| generated_tokens = generated[0] | generated_tokens = generated[0] | ||||
| @@ -21,7 +21,7 @@ class CodeGeeXCodeGenerationPipeline(Pipeline): | |||||
| *args, | *args, | ||||
| **kwargs): | **kwargs): | ||||
| model = CodeGeeXForCodeGeneration(model) if isinstance(model, | model = CodeGeeXForCodeGeneration(model) if isinstance(model, | ||||
| str) else model | |||||
| str) else model | |||||
| self.model = model | self.model = model | ||||
| self.model.eval() | self.model.eval() | ||||
| self.model.half() | self.model.half() | ||||
| @@ -38,8 +38,15 @@ class CodeGeeXCodeGenerationPipeline(Pipeline): | |||||
| for para in ['prompt', 'language']: | for para in ['prompt', 'language']: | ||||
| if para not in inputs: | if para not in inputs: | ||||
| raise Exception('Please check your input format.') | raise Exception('Please check your input format.') | ||||
| if inputs['language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa | |||||
| raise Exception('Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa | |||||
| if inputs['language'] not in [ | |||||
| 'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||||
| 'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||||
| 'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||||
| 'Pascal', 'R', 'Fortran', 'Lean' | |||||
| ]: # noqa | |||||
| raise Exception( | |||||
| 'Make sure the language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||||
| ) # noqa | |||||
| return self.model(inputs) | return self.model(inputs) | ||||
| @@ -38,11 +38,25 @@ class CodeGeeXCodeTranslationPipeline(Pipeline): | |||||
| for para in ['prompt', 'source language', 'target language']: | for para in ['prompt', 'source language', 'target language']: | ||||
| if para not in inputs: | if para not in inputs: | ||||
| raise Exception('please check your input format.') | raise Exception('please check your input format.') | ||||
| if inputs['source language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa | |||||
| raise Exception('Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa | |||||
| if inputs['source language'] not in [ | |||||
| 'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||||
| 'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||||
| 'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||||
| 'Pascal', 'R', 'Fortran', 'Lean' | |||||
| ]: | |||||
| raise Exception( | |||||
| 'Make sure the source language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||||
| ) # noqa | |||||
| if inputs['target language'] not in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]: # noqa | |||||
| raise Exception('Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]') # noqa | |||||
| if inputs['target language'] not in [ | |||||
| 'C++', 'C', 'C#', 'Cuda', 'Objective-C', 'Objective-C++', | |||||
| 'Python', 'Java', 'Scala', 'TeX', 'HTML', 'PHP', 'JavaScript', | |||||
| 'TypeScript', 'Go', 'Shell', 'Rust', 'CSS', 'SQL', 'Kotlin', | |||||
| 'Pascal', 'R', 'Fortran', 'Lean' | |||||
| ]: | |||||
| raise Exception( | |||||
| 'Make sure the target language is in ["C++","C","C#","Cuda","Objective-C","Objective-C++","Python","Java","Scala","TeX","HTML","PHP","JavaScript","TypeScript","Go","Shell","Rust","CSS","SQL","Kotlin","Pascal","R","Fortran","Lean"]' # noqa | |||||
| ) # noqa | |||||
| return self.model(inputs) | return self.model(inputs) | ||||