Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10481410master
@@ -99,6 +99,10 @@ class Models(object):
    team = 'team-multi-modal-similarity'
    video_clip = 'video-clip-multi-modal-embedding'

    # science models
    unifold = 'unifold'
    unifold_symmetry = 'unifold-symmetry'


class TaskModels(object):
    # nlp task

@@ -266,6 +270,9 @@ class Pipelines(object):
    image_text_retrieval = 'image-text-retrieval'
    ofa_ocr_recognition = 'ofa-ocr-recognition'

    # science tasks
    protein_structure = 'unifold-protein-structure'


class Trainers(object):
    """ Names for different trainer.

@@ -368,6 +375,9 @@ class Preprocessors(object):
    ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
    mplug_tasks_preprocessor = 'mplug-tasks-preprocessor'

    # science preprocessor
    unifold_preprocessor = 'unifold-preprocessor'


class Metrics(object):
    """ Names for different metrics.
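For orientation (not part of the diff): these registry constants are what the ModelScope factory looks up when a pipeline is built. A minimal usage sketch, assuming the task string mirrors Pipelines.protein_structure and using a placeholder model id rather than a real hub entry:

# Sketch only: the task string is assumed and the model id is a placeholder.
from modelscope.pipelines import pipeline

protein_folding = pipeline(task='protein-structure', model='<unifold-model-id>')
prediction = protein_folding('MGLSDGEWQLVLNVWGKVEADIPGHGQEVLIRLFK')  # amino-acid sequence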
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .unifold import UnifoldForProteinStructrue

else:
    _import_structure = {'unifold': ['UnifoldForProteinStructrue']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
@@ -0,0 +1 @@
from .model import UnifoldForProteinStructrue
@@ -0,0 +1,636 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
import copy
from typing import Any

import ml_collections as mlc

N_RES = 'number of residues'
N_MSA = 'number of MSA sequences'
N_EXTRA_MSA = 'number of extra MSA sequences'
N_TPL = 'number of templates'

d_pair = mlc.FieldReference(128, field_type=int)
d_msa = mlc.FieldReference(256, field_type=int)
d_template = mlc.FieldReference(64, field_type=int)
d_extra_msa = mlc.FieldReference(64, field_type=int)
d_single = mlc.FieldReference(384, field_type=int)
max_recycling_iters = mlc.FieldReference(3, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(3e4, field_type=float)
use_templates = mlc.FieldReference(True, field_type=bool)
is_multimer = mlc.FieldReference(False, field_type=bool)


def base_config():
    return mlc.ConfigDict({
        'data': {
            'common': {
                'features': {
                    'aatype': [N_RES],
                    'all_atom_mask': [N_RES, None],
                    'all_atom_positions': [N_RES, None, None],
                    'alt_chi_angles': [N_RES, None],
                    'atom14_alt_gt_exists': [N_RES, None],
                    'atom14_alt_gt_positions': [N_RES, None, None],
                    'atom14_atom_exists': [N_RES, None],
                    'atom14_atom_is_ambiguous': [N_RES, None],
                    'atom14_gt_exists': [N_RES, None],
                    'atom14_gt_positions': [N_RES, None, None],
                    'atom37_atom_exists': [N_RES, None],
                    'frame_mask': [N_RES],
                    'true_frame_tensor': [N_RES, None, None],
                    'bert_mask': [N_MSA, N_RES],
                    'chi_angles_sin_cos': [N_RES, None, None],
                    'chi_mask': [N_RES, None],
                    'extra_msa_deletion_value': [N_EXTRA_MSA, N_RES],
                    'extra_msa_has_deletion': [N_EXTRA_MSA, N_RES],
                    'extra_msa': [N_EXTRA_MSA, N_RES],
                    'extra_msa_mask': [N_EXTRA_MSA, N_RES],
                    'extra_msa_row_mask': [N_EXTRA_MSA],
                    'is_distillation': [],
                    'msa_feat': [N_MSA, N_RES, None],
                    'msa_mask': [N_MSA, N_RES],
                    'msa_chains': [N_MSA, None],
                    'msa_row_mask': [N_MSA],
                    'num_recycling_iters': [],
                    'pseudo_beta': [N_RES, None],
                    'pseudo_beta_mask': [N_RES],
                    'residue_index': [N_RES],
                    'residx_atom14_to_atom37': [N_RES, None],
                    'residx_atom37_to_atom14': [N_RES, None],
                    'resolution': [],
                    'rigidgroups_alt_gt_frames': [N_RES, None, None, None],
                    'rigidgroups_group_exists': [N_RES, None],
                    'rigidgroups_group_is_ambiguous': [N_RES, None],
                    'rigidgroups_gt_exists': [N_RES, None],
                    'rigidgroups_gt_frames': [N_RES, None, None, None],
                    'seq_length': [],
                    'seq_mask': [N_RES],
                    'target_feat': [N_RES, None],
                    'template_aatype': [N_TPL, N_RES],
                    'template_all_atom_mask': [N_TPL, N_RES, None],
                    'template_all_atom_positions': [N_TPL, N_RES, None, None],
                    'template_alt_torsion_angles_sin_cos': [
                        N_TPL,
                        N_RES,
                        None,
                        None,
                    ],
                    'template_frame_mask': [N_TPL, N_RES],
                    'template_frame_tensor': [N_TPL, N_RES, None, None],
                    'template_mask': [N_TPL],
                    'template_pseudo_beta': [N_TPL, N_RES, None],
                    'template_pseudo_beta_mask': [N_TPL, N_RES],
                    'template_sum_probs': [N_TPL, None],
                    'template_torsion_angles_mask': [N_TPL, N_RES, None],
                    'template_torsion_angles_sin_cos':
                        [N_TPL, N_RES, None, None],
                    'true_msa': [N_MSA, N_RES],
                    'use_clamped_fape': [],
                    'assembly_num_chains': [1],
                    'asym_id': [N_RES],
                    'sym_id': [N_RES],
                    'entity_id': [N_RES],
                    'num_sym': [N_RES],
                    'asym_len': [None],
                    'cluster_bias_mask': [N_MSA],
                },
                'masked_msa': {
                    'profile_prob': 0.1,
                    'same_prob': 0.1,
                    'uniform_prob': 0.1,
                },
                'block_delete_msa': {
                    'msa_fraction_per_block': 0.3,
                    'randomize_num_blocks': False,
                    'num_blocks': 5,
                    'min_num_msa': 16,
                },
                'random_delete_msa': {
                    'max_msa_entry': 1 << 25,  # := 33554432
                },
                'v2_feature': False,
                'gumbel_sample': False,
                'max_extra_msa': 1024,
                'msa_cluster_features': True,
                'reduce_msa_clusters_by_max_templates': True,
                'resample_msa_in_recycling': True,
                'template_features': [
                    'template_all_atom_positions',
                    'template_sum_probs',
                    'template_aatype',
                    'template_all_atom_mask',
                ],
                'unsupervised_features': [
                    'aatype',
                    'residue_index',
                    'msa',
                    'msa_chains',
                    'num_alignments',
                    'seq_length',
                    'between_segment_residues',
                    'deletion_matrix',
                    'num_recycling_iters',
                    'crop_and_fix_size_seed',
                ],
                'recycling_features': [
                    'msa_chains',
                    'msa_mask',
                    'msa_row_mask',
                    'bert_mask',
                    'true_msa',
                    'msa_feat',
                    'extra_msa_deletion_value',
                    'extra_msa_has_deletion',
                    'extra_msa',
                    'extra_msa_mask',
                    'extra_msa_row_mask',
                    'is_distillation',
                ],
                'multimer_features': [
                    'assembly_num_chains',
                    'asym_id',
                    'sym_id',
                    'num_sym',
                    'entity_id',
                    'asym_len',
                    'cluster_bias_mask',
                ],
                'use_templates': use_templates,
                'is_multimer': is_multimer,
                'use_template_torsion_angles': use_templates,
                'max_recycling_iters': max_recycling_iters,
            },
            'supervised': {
                'use_clamped_fape_prob': 1.0,
                'supervised_features': [
                    'all_atom_mask',
                    'all_atom_positions',
                    'resolution',
                    'use_clamped_fape',
                    'is_distillation',
                ],
            },
            'predict': {
                'fixed_size': True,
                'subsample_templates': False,
                'block_delete_msa': False,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 2,
                'crop': False,
                'crop_size': None,
                'supervised': False,
                'biased_msa_by_chain': False,
                'share_mask': False,
            },
            'eval': {
                'fixed_size': True,
                'subsample_templates': False,
                'block_delete_msa': False,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 1,
                'crop': False,
                'crop_size': None,
                'spatial_crop_prob': 0.5,
                'ca_ca_threshold': 10.0,
                'supervised': True,
                'biased_msa_by_chain': False,
                'share_mask': False,
            },
            'train': {
                'fixed_size': True,
                'subsample_templates': True,
                'block_delete_msa': True,
                'random_delete_msa': True,
                'masked_msa_replace_fraction': 0.15,
                'max_msa_clusters': 128,
                'max_templates': 4,
                'num_ensembles': 1,
                'crop': True,
                'crop_size': 256,
                'spatial_crop_prob': 0.5,
                'ca_ca_threshold': 10.0,
                'supervised': True,
                'use_clamped_fape_prob': 1.0,
                'max_distillation_msa_clusters': 1000,
                'biased_msa_by_chain': True,
                'share_mask': True,
            },
        },
        'globals': {
            'chunk_size': chunk_size,
            'block_size': None,
            'd_pair': d_pair,
            'd_msa': d_msa,
            'd_template': d_template,
            'd_extra_msa': d_extra_msa,
            'd_single': d_single,
            'eps': eps,
            'inf': inf,
            'max_recycling_iters': max_recycling_iters,
            'alphafold_original_mode': False,
        },
        'model': {
            'is_multimer': is_multimer,
            'input_embedder': {
                'tf_dim': 22,
                'msa_dim': 49,
                'd_pair': d_pair,
                'd_msa': d_msa,
                'relpos_k': 32,
                'max_relative_chain': 2,
            },
            'recycling_embedder': {
                'd_pair': d_pair,
                'd_msa': d_msa,
                'min_bin': 3.25,
                'max_bin': 20.75,
                'num_bins': 15,
                'inf': 1e8,
            },
            'template': {
                'distogram': {
                    'min_bin': 3.25,
                    'max_bin': 50.75,
                    'num_bins': 39,
                },
                'template_angle_embedder': {
                    'd_in': 57,
                    'd_out': d_msa,
                },
                'template_pair_embedder': {
                    'd_in': 88,
                    'v2_d_in': [39, 1, 22, 22, 1, 1, 1, 1],
                    'd_pair': d_pair,
                    'd_out': d_template,
                    'v2_feature': False,
                },
                'template_pair_stack': {
                    'd_template': d_template,
                    'd_hid_tri_att': 16,
                    'd_hid_tri_mul': 64,
                    'num_blocks': 2,
                    'num_heads': 4,
                    'pair_transition_n': 2,
                    'dropout_rate': 0.25,
                    'inf': 1e9,
                    'tri_attn_first': True,
                },
                'template_pointwise_attention': {
                    'enabled': True,
                    'd_template': d_template,
                    'd_pair': d_pair,
                    'd_hid': 16,
                    'num_heads': 4,
                    'inf': 1e5,
                },
                'inf': 1e5,
                'eps': 1e-6,
                'enabled': use_templates,
                'embed_angles': use_templates,
            },
            'extra_msa': {
                'extra_msa_embedder': {
                    'd_in': 25,
                    'd_out': d_extra_msa,
                },
                'extra_msa_stack': {
                    'd_msa': d_extra_msa,
                    'd_pair': d_pair,
                    'd_hid_msa_att': 8,
                    'd_hid_opm': 32,
                    'd_hid_mul': 128,
                    'd_hid_pair_att': 32,
                    'num_heads_msa': 8,
                    'num_heads_pair': 4,
                    'num_blocks': 4,
                    'transition_n': 4,
                    'msa_dropout': 0.15,
                    'pair_dropout': 0.25,
                    'inf': 1e9,
                    'eps': 1e-10,
                    'outer_product_mean_first': False,
                },
                'enabled': True,
            },
            'evoformer_stack': {
                'd_msa': d_msa,
                'd_pair': d_pair,
                'd_hid_msa_att': 32,
                'd_hid_opm': 32,
                'd_hid_mul': 128,
                'd_hid_pair_att': 32,
                'd_single': d_single,
                'num_heads_msa': 8,
                'num_heads_pair': 4,
                'num_blocks': 48,
                'transition_n': 4,
                'msa_dropout': 0.15,
                'pair_dropout': 0.25,
                'inf': 1e9,
                'eps': 1e-10,
                'outer_product_mean_first': False,
            },
            'structure_module': {
                'd_single': d_single,
                'd_pair': d_pair,
                'd_ipa': 16,
                'd_angle': 128,
                'num_heads_ipa': 12,
                'num_qk_points': 4,
                'num_v_points': 8,
                'dropout_rate': 0.1,
                'num_blocks': 8,
                'no_transition_layers': 1,
                'num_resnet_blocks': 2,
                'num_angles': 7,
                'trans_scale_factor': 10,
                'epsilon': 1e-12,
                'inf': 1e5,
                'separate_kv': False,
                'ipa_bias': True,
            },
            'heads': {
                'plddt': {
                    'num_bins': 50,
                    'd_in': d_single,
                    'd_hid': 128,
                },
                'distogram': {
                    'd_pair': d_pair,
                    'num_bins': aux_distogram_bins,
                    'disable_enhance_head': False,
                },
                'pae': {
                    'd_pair': d_pair,
                    'num_bins': aux_distogram_bins,
                    'enabled': False,
                    'iptm_weight': 0.8,
                    'disable_enhance_head': False,
                },
                'masked_msa': {
                    'd_msa': d_msa,
                    'd_out': 23,
                    'disable_enhance_head': False,
                },
                'experimentally_resolved': {
                    'd_single': d_single,
                    'd_out': 37,
                    'enabled': False,
                    'disable_enhance_head': False,
                },
            },
        },
        'loss': {
            'distogram': {
                'min_bin': 2.3125,
                'max_bin': 21.6875,
                'num_bins': 64,
                'eps': 1e-6,
                'weight': 0.3,
            },
            'experimentally_resolved': {
                'eps': 1e-8,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'weight': 0.0,
            },
            'fape': {
                'backbone': {
                    'clamp_distance': 10.0,
                    'clamp_distance_between_chains': 30.0,
                    'loss_unit_distance': 10.0,
                    'loss_unit_distance_between_chains': 20.0,
                    'weight': 0.5,
                    'eps': 1e-4,
                },
                'sidechain': {
                    'clamp_distance': 10.0,
                    'length_scale': 10.0,
                    'weight': 0.5,
                    'eps': 1e-4,
                },
                'weight': 1.0,
            },
            'plddt': {
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'cutoff': 15.0,
                'num_bins': 50,
                'eps': 1e-10,
                'weight': 0.01,
            },
            'masked_msa': {
                'eps': 1e-8,
                'weight': 2.0,
            },
            'supervised_chi': {
                'chi_weight': 0.5,
                'angle_norm_weight': 0.01,
                'eps': 1e-6,
                'weight': 1.0,
            },
            'violation': {
                'violation_tolerance_factor': 12.0,
                'clash_overlap_tolerance': 1.5,
                'bond_angle_loss_weight': 0.3,
                'eps': 1e-6,
                'weight': 0.0,
            },
            'pae': {
                'max_bin': 31,
                'num_bins': 64,
                'min_resolution': 0.1,
                'max_resolution': 3.0,
                'eps': 1e-8,
                'weight': 0.0,
            },
            'repr_norm': {
                'weight': 0.01,
                'tolerance': 1.0,
            },
            'chain_centre_mass': {
                'weight': 0.0,
                'eps': 1e-8,
            },
        },
    })
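Since base_config() returns an ml_collections.ConfigDict built around shared FieldReference objects, downstream code can read or override individual entries with attribute access before a model is instantiated. A small sketch; the import path is assumed from the package layout used elsewhere in this review, not stated in the diff:

# Sketch only: module path assumed to be modelscope.models.science.unifold.config.
from modelscope.models.science.unifold.config import base_config

cfg = base_config()
print(cfg.globals.d_msa)               # 256, resolved from the d_msa FieldReference
cfg.data.predict.max_templates = 2     # override a single field for inference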
def recursive_set(c: mlc.ConfigDict, key: str, value: Any, ignore: str = None):
    with c.unlocked():
        for k, v in c.items():
            if ignore is not None and k == ignore:
                continue
            if isinstance(v, mlc.ConfigDict):
                recursive_set(v, key, value)
            elif k == key:
                c[k] = value
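recursive_set overwrites every occurrence of a key throughout the nested config, optionally skipping one top-level subtree; the named presets in model_config below rely on it heavily. A minimal sketch:

import ml_collections as mlc

c = mlc.ConfigDict({'model': {'eps': 1e-8}, 'loss': {'eps': 1e-6}})
recursive_set(c, 'eps', 1e-5, ignore='loss')   # the 'loss' subtree is left untouched
assert c.model.eps == 1e-5 and c.loss.eps == 1e-6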
def model_config(name, train=False):
    c = copy.deepcopy(base_config())

    def model_2_v2(c):
        recursive_set(c, 'v2_feature', True)
        recursive_set(c, 'gumbel_sample', True)
        c.model.heads.masked_msa.d_out = 22
        c.model.structure_module.separate_kv = True
        c.model.structure_module.ipa_bias = False
        c.model.template.template_angle_embedder.d_in = 34
        return c

    def multimer(c):
        recursive_set(c, 'is_multimer', True)
        recursive_set(c, 'max_extra_msa', 1152)
        recursive_set(c, 'max_msa_clusters', 128)
        recursive_set(c, 'v2_feature', True)
        recursive_set(c, 'gumbel_sample', True)
        c.model.template.template_angle_embedder.d_in = 34
        c.model.template.template_pair_stack.tri_attn_first = False
        c.model.template.template_pointwise_attention.enabled = False
        c.model.heads.pae.enabled = True
        # we forgot to enable it in our training, so disable it here
        c.model.heads.pae.disable_enhance_head = True
        c.model.heads.masked_msa.d_out = 22
        c.model.structure_module.separate_kv = True
        c.model.structure_module.ipa_bias = False
        c.model.structure_module.trans_scale_factor = 20
        c.loss.pae.weight = 0.1
        c.model.input_embedder.tf_dim = 21
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.chain_centre_mass.weight = 1.0
        return c

    if name == 'model_1':
        pass
    elif name == 'model_1_ft':
        recursive_set(c, 'max_extra_msa', 5120)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
    elif name == 'model_1_af2':
        recursive_set(c, 'max_extra_msa', 5120)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
    elif name == 'model_2':
        pass
    elif name == 'model_init':
        pass
    elif name == 'model_init_af2':
        c.globals.alphafold_original_mode = True
    elif name == 'model_2_ft':
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
    elif name == 'model_2_af2':
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
    elif name == 'model_2_v2':
        c = model_2_v2(c)
    elif name == 'model_2_v2_ft':
        c = model_2_v2(c)
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
    elif name == 'model_3_af2' or name == 'model_4_af2':
        recursive_set(c, 'max_extra_msa', 5120)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
        c.model.template.enabled = False
        c.model.template.embed_angles = False
        recursive_set(c, 'use_templates', False)
        recursive_set(c, 'use_template_torsion_angles', False)
    elif name == 'model_5_af2':
        recursive_set(c, 'max_extra_msa', 1024)
        recursive_set(c, 'max_msa_clusters', 512)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.02
        c.loss.repr_norm.weight = 0
        c.model.heads.experimentally_resolved.enabled = True
        c.loss.experimentally_resolved.weight = 0.01
        c.globals.alphafold_original_mode = True
        c.model.template.enabled = False
        c.model.template.embed_angles = False
        recursive_set(c, 'use_templates', False)
        recursive_set(c, 'use_template_torsion_angles', False)
    elif name == 'multimer':
        c = multimer(c)
    elif name == 'multimer_ft':
        c = multimer(c)
        recursive_set(c, 'max_extra_msa', 1152)
        recursive_set(c, 'max_msa_clusters', 256)
        c.data.train.crop_size = 384
        c.loss.violation.weight = 0.5
    elif name == 'multimer_af2':
        recursive_set(c, 'max_extra_msa', 1152)
        recursive_set(c, 'max_msa_clusters', 256)
        recursive_set(c, 'is_multimer', True)
        recursive_set(c, 'v2_feature', True)
        recursive_set(c, 'gumbel_sample', True)
        c.model.template.template_angle_embedder.d_in = 34
        c.model.template.template_pair_stack.tri_attn_first = False
        c.model.template.template_pointwise_attention.enabled = False
        c.model.heads.pae.enabled = True
        c.model.heads.experimentally_resolved.enabled = True
        c.model.heads.masked_msa.d_out = 22
        c.model.structure_module.separate_kv = True
        c.model.structure_module.ipa_bias = False
        c.model.structure_module.trans_scale_factor = 20
        c.loss.pae.weight = 0.1
        c.loss.violation.weight = 0.5
        c.loss.experimentally_resolved.weight = 0.01
        c.model.input_embedder.tf_dim = 21
        c.globals.alphafold_original_mode = True
        c.data.train.crop_size = 384
        c.loss.repr_norm.weight = 0
        c.loss.chain_centre_mass.weight = 1.0
        recursive_set(c, 'outer_product_mean_first', True)
    else:
        raise ValueError(f'invalid --model-name: {name}.')
    if train:
        c.globals.chunk_size = None
        recursive_set(c, 'inf', 3e4)
        recursive_set(c, 'eps', 1e-5, 'loss')
    return c
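model_config is the entry point the rest of the port is expected to use: it deep-copies the base config and applies one of the named presets above. A short sketch; the import path is assumed, not stated in the diff:

# Sketch only: module path assumed to be modelscope.models.science.unifold.config.
from modelscope.models.science.unifold.config import model_config

cfg = model_config('multimer_ft')
print(cfg.model.is_multimer)      # True for the multimer presets
print(cfg.globals.chunk_size)     # 4 by default; set to None when train=True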
@@ -0,0 +1,14 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data pipeline for model features."""
@@ -0,0 +1,526 @@
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pairing logic for multimer data."""
import collections
from typing import Dict, Iterable, List, Sequence

import numpy as np
import pandas as pd
import scipy.linalg

from .data_ops import NumpyDict
from .residue_constants import restypes_with_x_and_gap

MSA_GAP_IDX = restypes_with_x_and_gap.index('-')
SEQUENCE_GAP_CUTOFF = 0.5
SEQUENCE_SIMILARITY_CUTOFF = 0.9

MSA_PAD_VALUES = {
    'msa_all_seq': MSA_GAP_IDX,
    'msa_mask_all_seq': 1,
    'deletion_matrix_all_seq': 0,
    'deletion_matrix_int_all_seq': 0,
    'msa': MSA_GAP_IDX,
    'msa_mask': 1,
    'deletion_matrix': 0,
    'deletion_matrix_int': 0,
}

MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix', 'deletion_matrix_int')
SEQ_FEATURES = (
    'residue_index',
    'aatype',
    'all_atom_positions',
    'all_atom_mask',
    'seq_mask',
    'between_segment_residues',
    'has_alt_locations',
    'has_hetatoms',
    'asym_id',
    'entity_id',
    'sym_id',
    'entity_mask',
    'deletion_mean',
    'prediction_atom_mask',
    'literature_positions',
    'atom_indices_to_group_indices',
    'rigid_group_default_frame',
    # zy
    'num_sym',
)
TEMPLATE_FEATURES = (
    'template_aatype',
    'template_all_atom_positions',
    'template_all_atom_mask',
)
CHAIN_FEATURES = ('num_alignments', 'seq_length')


def create_paired_features(chains: Iterable[NumpyDict], ) -> List[NumpyDict]:
    """Returns the original chains with paired NUM_SEQ features.

    Args:
      chains: A list of feature dictionaries for each chain.

    Returns:
      A list of feature dictionaries with sequence features including only
      rows to be paired.
    """
    chains = list(chains)
    chain_keys = chains[0].keys()

    if len(chains) < 2:
        return chains
    else:
        updated_chains = []
        paired_chains_to_paired_row_indices = pair_sequences(chains)
        paired_rows = reorder_paired_rows(paired_chains_to_paired_row_indices)

        for chain_num, chain in enumerate(chains):
            new_chain = {k: v for k, v in chain.items() if '_all_seq' not in k}
            for feature_name in chain_keys:
                if feature_name.endswith('_all_seq'):
                    feats_padded = pad_features(chain[feature_name],
                                                feature_name)
                    new_chain[feature_name] = feats_padded[
                        paired_rows[:, chain_num]]
            new_chain['num_alignments_all_seq'] = np.asarray(
                len(paired_rows[:, chain_num]))
            updated_chains.append(new_chain)
        return updated_chains


def pad_features(feature: np.ndarray, feature_name: str) -> np.ndarray:
    """Add a 'padding' row at the end of the features list.

    The padding row will be selected as a 'paired' row in the case of partial
    alignment - for the chain that doesn't have paired alignment.

    Args:
      feature: The feature to be padded.
      feature_name: The name of the feature to be padded.

    Returns:
      The feature with an additional padding row.
    """
    assert feature.dtype != np.dtype(np.string_)
    if feature_name in (
            'msa_all_seq',
            'msa_mask_all_seq',
            'deletion_matrix_all_seq',
            'deletion_matrix_int_all_seq',
    ):
        num_res = feature.shape[1]
        padding = MSA_PAD_VALUES[feature_name] * np.ones([1, num_res],
                                                         feature.dtype)
    elif feature_name == 'msa_species_identifiers_all_seq':
        padding = [b'']
    else:
        return feature
    feats_padded = np.concatenate([feature, padding], axis=0)
    return feats_padded
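pad_features only touches the *_all_seq MSA features (and the species identifiers); every other feature is returned unchanged. The appended gap row is what a chain with no hit for a species points at via index -1. A minimal sketch using the function defined above:

import numpy as np

msa = np.zeros((3, 5), dtype=np.int64)        # 3 aligned sequences, 5 residues
padded = pad_features(msa, 'msa_all_seq')
assert padded.shape == (4, 5)
assert (padded[-1] == MSA_GAP_IDX).all()      # the extra row is all gap tokens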
def _make_msa_df(chain_features: NumpyDict) -> pd.DataFrame:
    """Makes dataframe with msa features needed for msa pairing."""
    chain_msa = chain_features['msa_all_seq']
    query_seq = chain_msa[0]
    per_seq_similarity = np.sum(
        query_seq[None] == chain_msa, axis=-1) / float(len(query_seq))
    per_seq_gap = np.sum(chain_msa == 21, axis=-1) / float(len(query_seq))
    msa_df = pd.DataFrame({
        'msa_species_identifiers':
            chain_features['msa_species_identifiers_all_seq'],
        'msa_row':
            np.arange(len(chain_features['msa_species_identifiers_all_seq'])),
        'msa_similarity': per_seq_similarity,
        'gap': per_seq_gap,
    })
    return msa_df


def _create_species_dict(msa_df: pd.DataFrame) -> Dict[bytes, pd.DataFrame]:
    """Creates mapping from species to msa dataframe of that species."""
    species_lookup = {}
    for species, species_df in msa_df.groupby('msa_species_identifiers'):
        species_lookup[species] = species_df
    return species_lookup


def _match_rows_by_sequence_similarity(
        this_species_msa_dfs: List[pd.DataFrame], ) -> List[List[int]]:  # noqa
    """Finds MSA sequence pairings across chains based on sequence similarity.

    Each chain's MSA sequences are first sorted by their sequence similarity to
    their respective target sequence. The sequences are then paired, starting
    from the sequences most similar to their target sequence.

    Args:
      this_species_msa_dfs: a list of dataframes containing MSA features for
        sequences for a specific species.

    Returns:
      A list of lists, each containing M indices corresponding to paired MSA rows,
      where M is the number of chains.
    """
    all_paired_msa_rows = []
    num_seqs = [
        len(species_df) for species_df in this_species_msa_dfs
        if species_df is not None
    ]
    take_num_seqs = np.min(num_seqs)

    # sort_by_similarity = lambda x: x.sort_values(
    #     'msa_similarity', axis=0, ascending=False)
    def sort_by_similarity(x):
        return x.sort_values('msa_similarity', axis=0, ascending=False)

    for species_df in this_species_msa_dfs:
        if species_df is not None:
            species_df_sorted = sort_by_similarity(species_df)
            msa_rows = species_df_sorted.msa_row.iloc[:take_num_seqs].values
        else:
            msa_rows = [-1] * take_num_seqs  # take the last 'padding' row
        all_paired_msa_rows.append(msa_rows)
    all_paired_msa_rows = list(np.array(all_paired_msa_rows).transpose())
    return all_paired_msa_rows


def pair_sequences(examples: List[NumpyDict]) -> Dict[int, np.ndarray]:
    """Returns indices for paired MSA sequences across chains."""

    num_examples = len(examples)

    all_chain_species_dict = []
    common_species = set()
    for chain_features in examples:
        msa_df = _make_msa_df(chain_features)
        species_dict = _create_species_dict(msa_df)
        all_chain_species_dict.append(species_dict)
        common_species.update(set(species_dict))

    common_species = sorted(common_species)
    common_species.remove(b'')  # Remove target sequence species.

    all_paired_msa_rows = [np.zeros(len(examples), int)]
    all_paired_msa_rows_dict = {k: [] for k in range(num_examples)}
    all_paired_msa_rows_dict[num_examples] = [np.zeros(len(examples), int)]

    for species in common_species:
        if not species:
            continue
        this_species_msa_dfs = []
        species_dfs_present = 0
        for species_dict in all_chain_species_dict:
            if species in species_dict:
                this_species_msa_dfs.append(species_dict[species])
                species_dfs_present += 1
            else:
                this_species_msa_dfs.append(None)

        # Skip species that are present in only one chain.
        if species_dfs_present <= 1:
            continue

        if np.any(
                np.array([
                    len(species_df) for species_df in this_species_msa_dfs
                    if isinstance(species_df, pd.DataFrame)
                ]) > 600):
            continue

        paired_msa_rows = _match_rows_by_sequence_similarity(
            this_species_msa_dfs)
        all_paired_msa_rows.extend(paired_msa_rows)
        all_paired_msa_rows_dict[species_dfs_present].extend(paired_msa_rows)
    all_paired_msa_rows_dict = {
        num_examples: np.array(paired_msa_rows)
        for num_examples, paired_msa_rows in all_paired_msa_rows_dict.items()
    }
    return all_paired_msa_rows_dict


def reorder_paired_rows(
        all_paired_msa_rows_dict: Dict[int, np.ndarray]) -> np.ndarray:
    """Creates a list of indices of paired MSA rows across chains.

    Args:
      all_paired_msa_rows_dict: a mapping from the number of paired chains to the
        paired indices.

    Returns:
      a list of lists, each containing indices of paired MSA rows across chains.
      The paired-index lists are ordered by:
        1) the number of chains in the paired alignment, i.e, all-chain pairings
           will come first.
        2) e-values
    """
    all_paired_msa_rows = []
    for num_pairings in sorted(all_paired_msa_rows_dict, reverse=True):
        paired_rows = all_paired_msa_rows_dict[num_pairings]
        paired_rows_product = np.abs(
            np.array(
                [np.prod(rows.astype(np.float64)) for rows in paired_rows]))
        paired_rows_sort_index = np.argsort(paired_rows_product)
        all_paired_msa_rows.extend(paired_rows[paired_rows_sort_index])

    return np.array(all_paired_msa_rows)
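reorder_paired_rows flattens the per-count groups into one array: groups with more paired chains come first, and within a group rows are ordered by the absolute product of their indices (a cheap proxy for alignment rank). A minimal sketch with two chains, using the function defined above:

import numpy as np

paired = {
    1: np.array([[5, -1]]),          # species found in only one chain (padded row)
    2: np.array([[3, 4], [1, 2]]),   # species paired across both chains
}
rows = reorder_paired_rows(paired)
assert rows.tolist() == [[1, 2], [3, 4], [5, -1]]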
def block_diag(*arrs: np.ndarray, pad_value: float = 0.0) -> np.ndarray:
    """Like scipy.linalg.block_diag but with an optional padding value."""
    ones_arrs = [np.ones_like(x) for x in arrs]
    off_diag_mask = 1 - scipy.linalg.block_diag(*ones_arrs)
    diag = scipy.linalg.block_diag(*arrs)
    diag += (off_diag_mask * pad_value).astype(diag.dtype)
    return diag
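block_diag mirrors scipy.linalg.block_diag but lets the off-diagonal filler be a value other than zero (here, typically the MSA gap index). A minimal sketch using the function defined above:

import numpy as np

a = np.array([[1, 1]])
b = np.array([[2], [2]])
print(block_diag(a, b, pad_value=-1))
# [[ 1  1 -1]
#  [-1 -1  2]
#  [-1 -1  2]]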
def _correct_post_merged_feats(np_example: NumpyDict,
                               np_chains_list: Sequence[NumpyDict],
                               pair_msa_sequences: bool) -> NumpyDict:
    """Adds features that need to be computed/recomputed post merging."""

    np_example['seq_length'] = np.asarray(
        np_example['aatype'].shape[0], dtype=np.int32)
    np_example['num_alignments'] = np.asarray(
        np_example['msa'].shape[0], dtype=np.int32)

    if not pair_msa_sequences:
        # Generate a bias that is 1 for the first row of every block in the
        # block diagonal MSA - i.e. make sure the cluster stack always includes
        # the query sequences for each chain (since the first row is the query
        # sequence).
        cluster_bias_masks = []
        for chain in np_chains_list:
            mask = np.zeros(chain['msa'].shape[0])
            mask[0] = 1
            cluster_bias_masks.append(mask)
        np_example['cluster_bias_mask'] = np.concatenate(cluster_bias_masks)

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list
        ]
        np_example['bert_mask'] = block_diag(*msa_masks, pad_value=0)
    else:
        np_example['cluster_bias_mask'] = np.zeros(np_example['msa'].shape[0])
        np_example['cluster_bias_mask'][0] = 1

        # Initialize Bert mask with masked out off diagonals.
        msa_masks = [
            np.ones(x['msa'].shape, dtype=np.int8) for x in np_chains_list
        ]
        msa_masks_all_seq = [
            np.ones(x['msa_all_seq'].shape, dtype=np.int8)
            for x in np_chains_list
        ]
        msa_mask_block_diag = block_diag(*msa_masks, pad_value=0)
        msa_mask_all_seq = np.concatenate(msa_masks_all_seq, axis=1)
        np_example['bert_mask'] = np.concatenate(
            [msa_mask_all_seq, msa_mask_block_diag], axis=0)
    return np_example


def _pad_templates(chains: Sequence[NumpyDict],
                   max_templates: int) -> Sequence[NumpyDict]:
    """For each chain pad the number of templates to a fixed size.

    Args:
      chains: A list of protein chains.
      max_templates: Each chain will be padded to have this many templates.

    Returns:
      The list of chains, updated to have template features padded to
      max_templates.
    """
    for chain in chains:
        for k, v in chain.items():
            if k in TEMPLATE_FEATURES:
                padding = np.zeros_like(v.shape)
                padding[0] = max_templates - v.shape[0]
                padding = [(0, p) for p in padding]
                chain[k] = np.pad(v, padding, mode='constant')
    return chains


def _merge_features_from_multiple_chains(
        chains: Sequence[NumpyDict], pair_msa_sequences: bool) -> NumpyDict:
    """Merge features from multiple chains.

    Args:
      chains: A list of feature dictionaries that we want to merge.
      pair_msa_sequences: Whether to concatenate MSA features along the
        num_res dimension (if True), or to block diagonalize them (if False).

    Returns:
      A feature dictionary for the merged example.
    """
    merged_example = {}
    for feature_name in chains[0]:
        feats = [x[feature_name] for x in chains]
        feature_name_split = feature_name.split('_all_seq')[0]
        if feature_name_split in MSA_FEATURES:
            if pair_msa_sequences or '_all_seq' in feature_name:
                merged_example[feature_name] = np.concatenate(feats, axis=1)
                if feature_name_split == 'msa':
                    merged_example['msa_chains_all_seq'] = np.ones(
                        merged_example[feature_name].shape[0]).reshape(-1, 1)
            else:
                merged_example[feature_name] = block_diag(
                    *feats, pad_value=MSA_PAD_VALUES[feature_name])
                if feature_name_split == 'msa':
                    msa_chains = []
                    for i, feat in enumerate(feats):
                        cur_shape = feat.shape[0]
                        vals = np.ones(cur_shape) * (i + 2)
                        msa_chains.append(vals)
                    merged_example['msa_chains'] = np.concatenate(
                        msa_chains).reshape(-1, 1)
        elif feature_name_split in SEQ_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=0)
        elif feature_name_split in TEMPLATE_FEATURES:
            merged_example[feature_name] = np.concatenate(feats, axis=1)
        elif feature_name_split in CHAIN_FEATURES:
            merged_example[feature_name] = np.sum(feats).astype(np.int32)
        else:
            merged_example[feature_name] = feats[0]
    return merged_example


def _merge_homomers_dense_msa(
        chains: Iterable[NumpyDict]) -> Sequence[NumpyDict]:
    """Merge all identical chains, making the resulting MSA dense.

    Args:
      chains: An iterable of features for each chain.

    Returns:
      A list of feature dictionaries. All features with the same entity_id
      will be merged - MSA features will be concatenated along the num_res
      dimension - making them dense.
    """
    entity_chains = collections.defaultdict(list)
    for chain in chains:
        entity_id = chain['entity_id'][0]
        entity_chains[entity_id].append(chain)

    grouped_chains = []
    for entity_id in sorted(entity_chains):
        chains = entity_chains[entity_id]
        grouped_chains.append(chains)
    chains = [
        _merge_features_from_multiple_chains(chains, pair_msa_sequences=True)
        for chains in grouped_chains
    ]
    return chains


def _concatenate_paired_and_unpaired_features(example: NumpyDict) -> NumpyDict:
    """Merges paired and block-diagonalised features."""
    features = MSA_FEATURES + ('msa_chains', )
    for feature_name in features:
        if feature_name in example:
            feat = example[feature_name]
            feat_all_seq = example[feature_name + '_all_seq']
            try:
                merged_feat = np.concatenate([feat_all_seq, feat], axis=0)
            except Exception as ex:
                raise Exception(
                    'concat failed.',
                    feature_name,
                    feat_all_seq.shape,
                    feat.shape,
                    ex.__class__,
                    ex,
                )
            example[feature_name] = merged_feat
    example['num_alignments'] = np.array(
        example['msa'].shape[0], dtype=np.int32)
    return example


def merge_chain_features(np_chains_list: List[NumpyDict],
                         pair_msa_sequences: bool,
                         max_templates: int) -> NumpyDict:
    """Merges features for multiple chains to single FeatureDict.

    Args:
      np_chains_list: List of FeatureDicts for each chain.
      pair_msa_sequences: Whether to merge paired MSAs.
      max_templates: The maximum number of templates to include.

    Returns:
      Single FeatureDict for entire complex.
    """
    np_chains_list = _pad_templates(
        np_chains_list, max_templates=max_templates)
    np_chains_list = _merge_homomers_dense_msa(np_chains_list)
    # Unpaired MSA features will be always block-diagonalised; paired MSA
    # features will be concatenated.
    np_example = _merge_features_from_multiple_chains(
        np_chains_list, pair_msa_sequences=False)

    if pair_msa_sequences:
        np_example = _concatenate_paired_and_unpaired_features(np_example)

    np_example = _correct_post_merged_feats(
        np_example=np_example,
        np_chains_list=np_chains_list,
        pair_msa_sequences=pair_msa_sequences,
    )

    return np_example


def deduplicate_unpaired_sequences(
        np_chains: List[NumpyDict]) -> List[NumpyDict]:
    """Removes unpaired sequences which duplicate a paired sequence."""

    feature_names = np_chains[0].keys()
    msa_features = MSA_FEATURES

    cache_msa_features = {}
    for chain in np_chains:
        entity_id = int(chain['entity_id'][0])
        if entity_id not in cache_msa_features:
            sequence_set = set(s.tobytes() for s in chain['msa_all_seq'])
            keep_rows = []
            # Go through unpaired MSA seqs and remove any rows that correspond to the
            # sequences that are already present in the paired MSA.
            for row_num, seq in enumerate(chain['msa']):
                if seq.tobytes() not in sequence_set:
                    keep_rows.append(row_num)
            new_msa_features = {}
            for feature_name in feature_names:
                if feature_name in msa_features:
                    if keep_rows:
                        new_msa_features[feature_name] = chain[feature_name][
                            keep_rows]
                    else:
                        new_shape = list(chain[feature_name].shape)
                        new_shape[0] = 0
                        new_msa_features[feature_name] = np.zeros(
                            new_shape, dtype=chain[feature_name].dtype)
            cache_msa_features[entity_id] = new_msa_features
        for feature_name in cache_msa_features[entity_id]:
            chain[feature_name] = cache_msa_features[entity_id][feature_name]
        chain['num_alignments'] = np.array(
            chain['msa'].shape[0], dtype=np.int32)
    return np_chains
@@ -0,0 +1,264 @@
# The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license,
# and is publicly available at https://github.com/dptech-corp/Uni-Fold.
from typing import Optional

import numpy as np
import torch

from modelscope.models.science.unifold.data import data_ops


def nonensembled_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that are not ensembled."""
    v2_feature = common_cfg.v2_feature
    operators = []
    if mode_cfg.random_delete_msa:
        operators.append(
            data_ops.random_delete_msa(common_cfg.random_delete_msa))
    operators.extend([
        data_ops.cast_to_64bit_ints,
        data_ops.correct_msa_restypes,
        data_ops.squeeze_features,
        data_ops.randomly_replace_msa_with_unknown(0.0),
        data_ops.make_seq_mask,
        data_ops.make_msa_mask,
    ])
    operators.append(data_ops.make_hhblits_profile_v2
                     if v2_feature else data_ops.make_hhblits_profile)
    if common_cfg.use_templates:
        operators.extend([
            data_ops.make_template_mask,
            data_ops.make_pseudo_beta('template_'),
        ])
        operators.append(
            data_ops.crop_templates(
                max_templates=mode_cfg.max_templates,
                subsample_templates=mode_cfg.subsample_templates,
            ))
        if common_cfg.use_template_torsion_angles:
            operators.extend([
                data_ops.atom37_to_torsion_angles('template_'),
            ])

    operators.append(data_ops.make_atom14_masks)
    operators.append(data_ops.make_target_feat)

    return operators


def crop_and_fix_size_fns(common_cfg, mode_cfg, crop_and_fix_size_seed):
    operators = []
    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters
    crop_feats = dict(common_cfg.features)
    if mode_cfg.fixed_size:
        if mode_cfg.crop:
            if common_cfg.is_multimer:
                crop_fn = data_ops.crop_to_size_multimer(
                    crop_size=mode_cfg.crop_size,
                    shape_schema=crop_feats,
                    seed=crop_and_fix_size_seed,
                    spatial_crop_prob=mode_cfg.spatial_crop_prob,
                    ca_ca_threshold=mode_cfg.ca_ca_threshold,
                )
            else:
                crop_fn = data_ops.crop_to_size_single(
                    crop_size=mode_cfg.crop_size,
                    shape_schema=crop_feats,
                    seed=crop_and_fix_size_seed,
                )
            operators.append(crop_fn)
        operators.append(data_ops.select_feat(crop_feats))
        operators.append(
            data_ops.make_fixed_size(
                crop_feats,
                pad_msa_clusters,
                common_cfg.max_extra_msa,
                mode_cfg.crop_size,
                mode_cfg.max_templates,
            ))
    return operators


def ensembled_fns(common_cfg, mode_cfg):
    """Input pipeline data transformers that can be ensembled and averaged."""
    operators = []
    multimer_mode = common_cfg.is_multimer
    v2_feature = common_cfg.v2_feature
    # the multimer pipeline does not use block-delete MSA
    if mode_cfg.block_delete_msa and not multimer_mode:
        operators.append(
            data_ops.block_delete_msa(common_cfg.block_delete_msa))
    if 'max_distillation_msa_clusters' in mode_cfg:
        operators.append(
            data_ops.sample_msa_distillation(
                mode_cfg.max_distillation_msa_clusters))

    if common_cfg.reduce_msa_clusters_by_max_templates:
        pad_msa_clusters = mode_cfg.max_msa_clusters - mode_cfg.max_templates
    else:
        pad_msa_clusters = mode_cfg.max_msa_clusters
    max_msa_clusters = pad_msa_clusters
    max_extra_msa = common_cfg.max_extra_msa

    assert common_cfg.resample_msa_in_recycling
    gumbel_sample = common_cfg.gumbel_sample
    operators.append(
        data_ops.sample_msa(
            max_msa_clusters,
            keep_extra=True,
            gumbel_sample=gumbel_sample,
            biased_msa_by_chain=mode_cfg.biased_msa_by_chain,
        ))

    if 'masked_msa' in common_cfg:
        # Masked MSA should come *before* MSA clustering so that
        # the clustering and full MSA profile do not leak information about
        # the masked locations and secret corrupted locations.
        operators.append(
            data_ops.make_masked_msa(
                common_cfg.masked_msa,
                mode_cfg.masked_msa_replace_fraction,
                gumbel_sample=gumbel_sample,
                share_mask=mode_cfg.share_mask,
            ))

    if common_cfg.msa_cluster_features:
        if v2_feature:
            operators.append(data_ops.nearest_neighbor_clusters_v2())
        else:
            operators.append(data_ops.nearest_neighbor_clusters())
            operators.append(data_ops.summarize_clusters)

    if v2_feature:
        operators.append(data_ops.make_msa_feat_v2)
    else:
        operators.append(data_ops.make_msa_feat)

    # Crop after creating the cluster profiles.
    if max_extra_msa:
        if v2_feature:
            operators.append(data_ops.make_extra_msa_feat(max_extra_msa))
        else:
            operators.append(data_ops.crop_extra_msa(max_extra_msa))
    else:
        operators.append(data_ops.delete_extra_msa)
    # operators.append(data_operators.select_feat(common_cfg.recycling_features))
    return operators


def process_features(tensors, common_cfg, mode_cfg):
    """Based on the config, apply filters and transformations to the data."""
    is_distillation = bool(tensors.get('is_distillation', 0))
    multimer_mode = common_cfg.is_multimer
    crop_and_fix_size_seed = int(tensors['crop_and_fix_size_seed'])
    crop_fn = crop_and_fix_size_fns(
        common_cfg,
        mode_cfg,
        crop_and_fix_size_seed,
    )

    def wrap_ensemble_fn(data, i):
        """Function to be mapped over the ensemble dimension."""
        d = data.copy()
        fns = ensembled_fns(
            common_cfg,
            mode_cfg,
        )
        new_d = compose(fns)(d)
        if not multimer_mode or is_distillation:
            new_d = data_ops.select_feat(common_cfg.recycling_features)(new_d)
            return compose(crop_fn)(new_d)
        else:  # select after crop for spatial cropping
            d = compose(crop_fn)(d)
            d = data_ops.select_feat(common_cfg.recycling_features)(d)
            return d

    nonensembled = nonensembled_fns(common_cfg, mode_cfg)

    if mode_cfg.supervised and (not multimer_mode or is_distillation):
        nonensembled.extend(label_transform_fn())

    tensors = compose(nonensembled)(tensors)

    num_recycling = int(tensors['num_recycling_iters']) + 1
    num_ensembles = mode_cfg.num_ensembles
    ensemble_tensors = map_fn(
        lambda x: wrap_ensemble_fn(tensors, x),
        torch.arange(num_recycling * num_ensembles),
    )
    tensors = compose(crop_fn)(tensors)
    # add a dummy dim to align with recycling features
    tensors = {k: torch.stack([tensors[k]], dim=0) for k in tensors}
    tensors.update(ensemble_tensors)
    return tensors


@data_ops.curry1
def compose(x, fs):
    for f in fs:
        x = f(x)
    return x
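Because compose is wrapped with data_ops.curry1, compose(fns) returns a callable that threads a feature dict through each transform in order; that is how the nonensembled/ensembled operator lists are applied above. A minimal sketch with plain functions:

ops = [lambda d: {**d, 'a': d['a'] + 1},
       lambda d: {**d, 'b': d['a'] * 2}]
out = compose(ops)({'a': 1})
assert out == {'a': 2, 'b': 4}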
def pad_then_stack(values, ):
    if len(values[0].shape) >= 1:
        size = max(v.shape[0] for v in values)
        new_values = []
        for v in values:
            if v.shape[0] < size:
                res = values[0].new_zeros(size, *v.shape[1:])
                res[:v.shape[0], ...] = v
            else:
                res = v
            new_values.append(res)
    else:
        new_values = values
    return torch.stack(new_values, dim=0)
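pad_then_stack zero-pads each tensor along its first dimension to the largest size before stacking, which is needed because different recycling/ensemble iterations can yield different MSA depths. A minimal sketch using the function defined above:

import torch

values = [torch.ones(2, 3), torch.ones(4, 3)]
stacked = pad_then_stack(values)
assert stacked.shape == (2, 4, 3)
assert stacked[0, 2:].sum() == 0    # the shorter tensor was zero-padded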
def map_fn(fun, x):
    ensembles = [fun(elem) for elem in x]
    features = ensembles[0].keys()
    ensembled_dict = {}
    for feat in features:
        ensembled_dict[feat] = pad_then_stack(
            [dict_i[feat] for dict_i in ensembles])
    return ensembled_dict
def process_single_label(label: dict,
                         num_ensemble: Optional[int] = None) -> dict:
    assert 'aatype' in label
    assert 'all_atom_positions' in label
    assert 'all_atom_mask' in label
    label = compose(label_transform_fn())(label)
    if num_ensemble is not None:
        label = {
            k: torch.stack([v for _ in range(num_ensemble)])
            for k, v in label.items()
        }
    return label


def process_labels(labels_list, num_ensemble: Optional[int] = None):
    return [process_single_label(ll, num_ensemble) for ll in labels_list]


def label_transform_fn():
    return [
        data_ops.make_atom14_masks,
        data_ops.make_atom14_positions,
        data_ops.atom37_to_frames,
        data_ops.atom37_to_torsion_angles(''),
        data_ops.make_pseudo_beta(''),
        data_ops.get_backbone_frames,
        data_ops.get_chi_angles,
    ]
| @@ -0,0 +1,417 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Feature processing logic for multimer data """ | |||||
| import collections | |||||
| from typing import Iterable, List, MutableMapping | |||||
| import numpy as np | |||||
| from modelscope.models.science.unifold.data import (msa_pairing, | |||||
| residue_constants) | |||||
| from .utils import correct_template_restypes | |||||
| FeatureDict = MutableMapping[str, np.ndarray] | |||||
| REQUIRED_FEATURES = frozenset({ | |||||
| 'aatype', | |||||
| 'all_atom_mask', | |||||
| 'all_atom_positions', | |||||
| 'all_chains_entity_ids', | |||||
| 'all_crops_all_chains_mask', | |||||
| 'all_crops_all_chains_positions', | |||||
| 'all_crops_all_chains_residue_ids', | |||||
| 'assembly_num_chains', | |||||
| 'asym_id', | |||||
| 'bert_mask', | |||||
| 'cluster_bias_mask', | |||||
| 'deletion_matrix', | |||||
| 'deletion_mean', | |||||
| 'entity_id', | |||||
| 'entity_mask', | |||||
| 'mem_peak', | |||||
| 'msa', | |||||
| 'msa_mask', | |||||
| 'num_alignments', | |||||
| 'num_templates', | |||||
| 'queue_size', | |||||
| 'residue_index', | |||||
| 'resolution', | |||||
| 'seq_length', | |||||
| 'seq_mask', | |||||
| 'sym_id', | |||||
| 'template_aatype', | |||||
| 'template_all_atom_mask', | |||||
| 'template_all_atom_positions', | |||||
| # zy added: | |||||
| 'asym_len', | |||||
| 'template_sum_probs', | |||||
| 'num_sym', | |||||
| 'msa_chains', | |||||
| }) | |||||
| MAX_TEMPLATES = 4 | |||||
| MSA_CROP_SIZE = 2048 | |||||
| def _is_homomer_or_monomer(chains: Iterable[FeatureDict]) -> bool: | |||||
| """Checks if a list of chains represents a homomer/monomer example.""" | |||||
| # Note that an entity_id of 0 indicates padding. | |||||
| num_unique_chains = len( | |||||
| np.unique( | |||||
| np.concatenate([ | |||||
| np.unique(chain['entity_id'][chain['entity_id'] > 0]) | |||||
| for chain in chains | |||||
| ]))) | |||||
| return num_unique_chains == 1 | |||||
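| # A minimal sketch (toy chain dicts with only 'entity_id'): two chains that | |||||
| # share entity_id 1 count as a homomer, so MSA pairing is skipped for them. | |||||
| def _homomer_check_demo(): | |||||
|     chains = [{'entity_id': np.array([1, 1, 0])}, {'entity_id': np.array([1, 1])}] | |||||
|     assert _is_homomer_or_monomer(chains) | |||||
|     chains[1]['entity_id'] = np.array([2, 2]) | |||||
|     assert not _is_homomer_or_monomer(chains) | |||||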
| def pair_and_merge( | |||||
| all_chain_features: MutableMapping[str, FeatureDict]) -> FeatureDict: | |||||
| """Runs processing on features to augment, pair and merge. | |||||
| Args: | |||||
| all_chain_features: A MutableMapping of dictionaries of features for each chain. | |||||
| Returns: | |||||
| A dictionary of features. | |||||
| """ | |||||
| process_unmerged_features(all_chain_features) | |||||
| np_chains_list = all_chain_features | |||||
| pair_msa_sequences = not _is_homomer_or_monomer(np_chains_list) | |||||
| if pair_msa_sequences: | |||||
| np_chains_list = msa_pairing.create_paired_features( | |||||
| chains=np_chains_list) | |||||
| np_chains_list = msa_pairing.deduplicate_unpaired_sequences( | |||||
| np_chains_list) | |||||
| np_chains_list = crop_chains( | |||||
| np_chains_list, | |||||
| msa_crop_size=MSA_CROP_SIZE, | |||||
| pair_msa_sequences=pair_msa_sequences, | |||||
| max_templates=MAX_TEMPLATES, | |||||
| ) | |||||
| np_example = msa_pairing.merge_chain_features( | |||||
| np_chains_list=np_chains_list, | |||||
| pair_msa_sequences=pair_msa_sequences, | |||||
| max_templates=MAX_TEMPLATES, | |||||
| ) | |||||
| np_example = process_final(np_example) | |||||
| return np_example | |||||
| def crop_chains( | |||||
| chains_list: List[FeatureDict], | |||||
| msa_crop_size: int, | |||||
| pair_msa_sequences: bool, | |||||
| max_templates: int, | |||||
| ) -> List[FeatureDict]: | |||||
| """Crops the MSAs for a set of chains. | |||||
| Args: | |||||
| chains_list: A list of chains to be cropped. | |||||
| msa_crop_size: The total number of sequences to crop from the MSA. | |||||
| pair_msa_sequences: Whether we are operating in sequence-pairing mode. | |||||
| max_templates: The maximum templates to use per chain. | |||||
| Returns: | |||||
| The cropped chains. | |||||
| """ | |||||
| # Apply the cropping. | |||||
| cropped_chains = [] | |||||
| for chain in chains_list: | |||||
| cropped_chain = _crop_single_chain( | |||||
| chain, | |||||
| msa_crop_size=msa_crop_size, | |||||
| pair_msa_sequences=pair_msa_sequences, | |||||
| max_templates=max_templates, | |||||
| ) | |||||
| cropped_chains.append(cropped_chain) | |||||
| return cropped_chains | |||||
| def _crop_single_chain(chain: FeatureDict, msa_crop_size: int, | |||||
| pair_msa_sequences: bool, | |||||
| max_templates: int) -> FeatureDict: | |||||
| """Crops msa sequences to `msa_crop_size`.""" | |||||
| msa_size = chain['num_alignments'] | |||||
| if pair_msa_sequences: | |||||
| msa_size_all_seq = chain['num_alignments_all_seq'] | |||||
| msa_crop_size_all_seq = np.minimum(msa_size_all_seq, | |||||
| msa_crop_size // 2) | |||||
| # We reduce the number of unpaired sequences by the number of times a | |||||
| # sequence from this chain's MSA is included in the paired MSA. This keeps | |||||
| # the MSA size for each chain roughly constant. | |||||
| msa_all_seq = chain['msa_all_seq'][:msa_crop_size_all_seq, :] | |||||
| num_non_gapped_pairs = np.sum( | |||||
| np.any(msa_all_seq != msa_pairing.MSA_GAP_IDX, axis=1)) | |||||
| num_non_gapped_pairs = np.minimum(num_non_gapped_pairs, | |||||
| msa_crop_size_all_seq) | |||||
| # Restrict the unpaired crop size so that paired+unpaired sequences do not | |||||
| # exceed msa_crop_size for each chain. | |||||
| max_msa_crop_size = np.maximum(msa_crop_size - num_non_gapped_pairs, 0) | |||||
| msa_crop_size = np.minimum(msa_size, max_msa_crop_size) | |||||
| else: | |||||
| msa_crop_size = np.minimum(msa_size, msa_crop_size) | |||||
| include_templates = 'template_aatype' in chain and max_templates | |||||
| if include_templates: | |||||
| num_templates = chain['template_aatype'].shape[0] | |||||
| templates_crop_size = np.minimum(num_templates, max_templates) | |||||
| for k in chain: | |||||
| k_split = k.split('_all_seq')[0] | |||||
| if k_split in msa_pairing.TEMPLATE_FEATURES: | |||||
| chain[k] = chain[k][:templates_crop_size, :] | |||||
| elif k_split in msa_pairing.MSA_FEATURES: | |||||
| if '_all_seq' in k and pair_msa_sequences: | |||||
| chain[k] = chain[k][:msa_crop_size_all_seq, :] | |||||
| else: | |||||
| chain[k] = chain[k][:msa_crop_size, :] | |||||
| chain['num_alignments'] = np.asarray(msa_crop_size, dtype=np.int32) | |||||
| if include_templates: | |||||
| chain['num_templates'] = np.asarray( | |||||
| templates_crop_size, dtype=np.int32) | |||||
| if pair_msa_sequences: | |||||
| chain['num_alignments_all_seq'] = np.asarray( | |||||
| msa_crop_size_all_seq, dtype=np.int32) | |||||
| return chain | |||||
| def process_final(np_example: FeatureDict) -> FeatureDict: | |||||
| """Final processing steps in data pipeline, after merging and pairing.""" | |||||
| np_example = _make_seq_mask(np_example) | |||||
| np_example = _make_msa_mask(np_example) | |||||
| np_example = _filter_features(np_example) | |||||
| return np_example | |||||
| def _make_seq_mask(np_example): | |||||
| np_example['seq_mask'] = (np_example['entity_id'] > 0).astype(np.float32) | |||||
| return np_example | |||||
| def _make_msa_mask(np_example): | |||||
| """Mask features are all ones, but will later be zero-padded.""" | |||||
| np_example['msa_mask'] = np.ones_like(np_example['msa'], dtype=np.int8) | |||||
| seq_mask = (np_example['entity_id'] > 0).astype(np.int8) | |||||
| np_example['msa_mask'] *= seq_mask[None] | |||||
| return np_example | |||||
| def _filter_features(np_example: FeatureDict) -> FeatureDict: | |||||
| """Filters features of example to only those requested.""" | |||||
| return {k: v for (k, v) in np_example.items() if k in REQUIRED_FEATURES} | |||||
| def process_unmerged_features(all_chain_features: MutableMapping[str, | |||||
| FeatureDict]): | |||||
| """Postprocessing stage for per-chain features before merging.""" | |||||
| num_chains = len(all_chain_features) | |||||
| for chain_features in all_chain_features: | |||||
| # Convert deletion matrices to float. | |||||
| if 'deletion_matrix_int' in chain_features: | |||||
| chain_features['deletion_matrix'] = np.asarray( | |||||
| chain_features.pop('deletion_matrix_int'), dtype=np.float32) | |||||
| if 'deletion_matrix_int_all_seq' in chain_features: | |||||
| chain_features['deletion_matrix_all_seq'] = np.asarray( | |||||
| chain_features.pop('deletion_matrix_int_all_seq'), | |||||
| dtype=np.float32) | |||||
| chain_features['deletion_mean'] = np.mean( | |||||
| chain_features['deletion_matrix'], axis=0) | |||||
| if 'all_atom_positions' not in chain_features: | |||||
| # Add all_atom_mask and dummy all_atom_positions based on aatype. | |||||
| all_atom_mask = residue_constants.STANDARD_ATOM_MASK[ | |||||
| chain_features['aatype']] | |||||
| chain_features['all_atom_mask'] = all_atom_mask | |||||
| chain_features['all_atom_positions'] = np.zeros( | |||||
| list(all_atom_mask.shape) + [3]) | |||||
| # Add assembly_num_chains. | |||||
| chain_features['assembly_num_chains'] = np.asarray(num_chains) | |||||
| # Add entity_mask. | |||||
| for chain_features in all_chain_features: | |||||
| chain_features['entity_mask'] = ( | |||||
| chain_features['entity_id'] != # noqa W504 | |||||
| 0).astype(np.int32) | |||||
| def empty_template_feats(n_res): | |||||
| return { | |||||
| 'template_aatype': | |||||
| np.zeros((0, n_res)).astype(np.int64), | |||||
| 'template_all_atom_positions': | |||||
| np.zeros((0, n_res, 37, 3)).astype(np.float32), | |||||
| 'template_sum_probs': | |||||
| np.zeros((0, 1)).astype(np.float32), | |||||
| 'template_all_atom_mask': | |||||
| np.zeros((0, n_res, 37)).astype(np.float32), | |||||
| } | |||||
| def convert_monomer_features(monomer_features: FeatureDict) -> FeatureDict: | |||||
| """Reshapes and modifies monomer features for multimer models.""" | |||||
| if monomer_features['template_aatype'].shape[0] == 0: | |||||
| monomer_features.update( | |||||
| empty_template_feats(monomer_features['aatype'].shape[0])) | |||||
| converted = {} | |||||
| unnecessary_leading_dim_feats = { | |||||
| 'sequence', | |||||
| 'domain_name', | |||||
| 'num_alignments', | |||||
| 'seq_length', | |||||
| } | |||||
| for feature_name, feature in monomer_features.items(): | |||||
| if feature_name in unnecessary_leading_dim_feats: | |||||
| # asarray ensures it's a np.ndarray. | |||||
| feature = np.asarray(feature[0], dtype=feature.dtype) | |||||
| elif feature_name == 'aatype': | |||||
| # The multimer model performs the one-hot operation itself. | |||||
| feature = np.argmax(feature, axis=-1).astype(np.int32) | |||||
| elif feature_name == 'template_aatype': | |||||
| if feature.shape[0] > 0: | |||||
| feature = correct_template_restypes(feature) | |||||
| elif feature_name == 'template_all_atom_masks': | |||||
| feature_name = 'template_all_atom_mask' | |||||
| elif feature_name == 'msa': | |||||
| feature = feature.astype(np.uint8) | |||||
| if feature_name.endswith('_mask'): | |||||
| feature = feature.astype(np.float32) | |||||
| converted[feature_name] = feature | |||||
| if 'deletion_matrix_int' in monomer_features: | |||||
| monomer_features['deletion_matrix'] = monomer_features.pop( | |||||
| 'deletion_matrix_int').astype(np.float32) | |||||
| converted.pop( | |||||
| 'template_sum_probs' | |||||
| ) # zy: this input was found to have an inconsistent shape. TODO: figure out why and fix it. | |||||
| return converted | |||||
| def int_id_to_str_id(num: int) -> str: | |||||
| """Encodes a number as a string, using reverse spreadsheet style naming. | |||||
| Args: | |||||
| num: A positive integer. | |||||
| Returns: | |||||
| A string that encodes the positive integer using reverse spreadsheet style | |||||
| naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the | |||||
| usual way to encode chain IDs in mmCIF files. | |||||
| """ | |||||
| if num <= 0: | |||||
| raise ValueError(f'Only positive integers allowed, got {num}.') | |||||
| num = num - 1 # 1-based indexing. | |||||
| output = [] | |||||
| while num >= 0: | |||||
| output.append(chr(num % 26 + ord('A'))) | |||||
| num = num // 26 - 1 | |||||
| return ''.join(output) | |||||
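| # A minimal sketch of the reverse-spreadsheet naming above: the | |||||
| # least-significant letter comes first. | |||||
| def _chain_id_naming_demo(): | |||||
|     assert int_id_to_str_id(1) == 'A' | |||||
|     assert int_id_to_str_id(26) == 'Z' | |||||
|     assert int_id_to_str_id(27) == 'AA' | |||||
|     assert int_id_to_str_id(28) == 'BA' | |||||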
| def add_assembly_features(all_chain_features, ): | |||||
| """Add features to distinguish between chains. | |||||
| Args: | |||||
| all_chain_features: A dictionary which maps chain_id to a dictionary of | |||||
| features for each chain. | |||||
| Returns: | |||||
| all_chain_features: A list of per-chain feature dictionaries with added | |||||
| `asym_id`, `sym_id`, `entity_id` and `num_sym` features. E.g. two chains | |||||
| from a homodimer share entity_id 1 and get sym_id 1 and 2, whereas two | |||||
| chains from a heterodimer get entity_id 1 and 2. | |||||
| """ | |||||
| # Group the chains by sequence | |||||
| seq_to_entity_id = {} | |||||
| grouped_chains = collections.defaultdict(list) | |||||
| for chain_features in all_chain_features: | |||||
| assert 'sequence' in chain_features | |||||
| seq = str(chain_features['sequence']) | |||||
| if seq not in seq_to_entity_id: | |||||
| seq_to_entity_id[seq] = len(seq_to_entity_id) + 1 | |||||
| grouped_chains[seq_to_entity_id[seq]].append(chain_features) | |||||
| new_all_chain_features = [] | |||||
| chain_id = 1 | |||||
| for entity_id, group_chain_features in grouped_chains.items(): | |||||
| num_sym = len(group_chain_features) # zy | |||||
| for sym_id, chain_features in enumerate(group_chain_features, start=1): | |||||
| seq_length = chain_features['seq_length'] | |||||
| chain_features['asym_id'] = chain_id * np.ones(seq_length) | |||||
| chain_features['sym_id'] = sym_id * np.ones(seq_length) | |||||
| chain_features['entity_id'] = entity_id * np.ones(seq_length) | |||||
| chain_features['num_sym'] = num_sym * np.ones(seq_length) | |||||
| chain_id += 1 | |||||
| new_all_chain_features.append(chain_features) | |||||
| return new_all_chain_features | |||||
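| # A minimal sketch, assuming toy chain dicts carrying only the 'sequence' and | |||||
| # 'seq_length' fields read above: a homodimer shares one entity_id but gets | |||||
| # distinct asym_id / sym_id, with num_sym == 2 everywhere. | |||||
| def _assembly_features_demo(): | |||||
|     chains = [{'sequence': 'ACDE', 'seq_length': 4}, {'sequence': 'ACDE', 'seq_length': 4}] | |||||
|     chains = add_assembly_features(chains) | |||||
|     assert [int(c['asym_id'][0]) for c in chains] == [1, 2] | |||||
|     assert [int(c['entity_id'][0]) for c in chains] == [1, 1] | |||||
|     assert [int(c['sym_id'][0]) for c in chains] == [1, 2] | |||||
|     assert [int(c['num_sym'][0]) for c in chains] == [2, 2] | |||||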
| def pad_msa(np_example, min_num_seq): | |||||
| np_example = dict(np_example) | |||||
| num_seq = np_example['msa'].shape[0] | |||||
| if num_seq < min_num_seq: | |||||
| for feat in ('msa', 'deletion_matrix', 'bert_mask', 'msa_mask', | |||||
| 'msa_chains'): | |||||
| np_example[feat] = np.pad(np_example[feat], | |||||
| ((0, min_num_seq - num_seq), (0, 0))) | |||||
| np_example['cluster_bias_mask'] = np.pad( | |||||
| np_example['cluster_bias_mask'], ((0, min_num_seq - num_seq), )) | |||||
| return np_example | |||||
| def post_process(np_example): | |||||
| np_example = pad_msa(np_example, 512) | |||||
| no_dim_keys = [ | |||||
| 'num_alignments', | |||||
| 'assembly_num_chains', | |||||
| 'num_templates', | |||||
| 'seq_length', | |||||
| 'resolution', | |||||
| ] | |||||
| for k in no_dim_keys: | |||||
| if k in np_example: | |||||
| np_example[k] = np_example[k].reshape(-1) | |||||
| return np_example | |||||
| def merge_msas(msa, del_mat, new_msa, new_del_mat): | |||||
| cur_msa_set = set([tuple(m) for m in msa]) | |||||
| new_rows = [] | |||||
| for i, s in enumerate(new_msa): | |||||
| if tuple(s) not in cur_msa_set: | |||||
| new_rows.append(i) | |||||
| ret_msa = np.concatenate([msa, new_msa[new_rows]], axis=0) | |||||
| ret_del_mat = np.concatenate([del_mat, new_del_mat[new_rows]], axis=0) | |||||
| return ret_msa, ret_del_mat | |||||
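| # A minimal sketch with toy arrays: merge_msas appends only the rows of | |||||
| # new_msa that are not already present in msa, keeping the deletion matrix | |||||
| # aligned row-for-row. | |||||
| def _merge_msas_demo(): | |||||
|     msa = np.array([[0, 1], [2, 3]]) | |||||
|     del_mat = np.zeros_like(msa) | |||||
|     new_msa = np.array([[2, 3], [4, 5]]) | |||||
|     new_del = np.ones_like(new_msa) | |||||
|     merged_msa, merged_del = merge_msas(msa, del_mat, new_msa, new_del) | |||||
|     assert merged_msa.shape == (3, 2) and merged_del.shape == (3, 2) | |||||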
| @@ -0,0 +1,322 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Protein data type.""" | |||||
| import dataclasses | |||||
| import io | |||||
| from typing import Any, Mapping, Optional | |||||
| import numpy as np | |||||
| from Bio.PDB import PDBParser | |||||
| from modelscope.models.science.unifold.data import residue_constants | |||||
| FeatureDict = Mapping[str, np.ndarray] | |||||
| ModelOutput = Mapping[str, Any] # Is a nested dict. | |||||
| # Complete sequence of chain IDs supported by the PDB format. | |||||
| PDB_CHAIN_IDS = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789' | |||||
| PDB_MAX_CHAINS = len(PDB_CHAIN_IDS) # := 62. | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class Protein: | |||||
| """Protein structure representation.""" | |||||
| # Cartesian coordinates of atoms in angstroms. The atom types correspond to | |||||
| # residue_constants.atom_types, i.e. the first three are N, CA, CB. | |||||
| atom_positions: np.ndarray # [num_res, num_atom_type, 3] | |||||
| # Amino-acid type for each residue represented as an integer between 0 and | |||||
| # 20, where 20 is 'X'. | |||||
| aatype: np.ndarray # [num_res] | |||||
| # Binary float mask to indicate presence of a particular atom. 1.0 if an atom | |||||
| # is present and 0.0 if not. This should be used for loss masking. | |||||
| atom_mask: np.ndarray # [num_res, num_atom_type] | |||||
| # Residue index as used in PDB. It is not necessarily continuous or 0-indexed. | |||||
| residue_index: np.ndarray # [num_res] | |||||
| # 0-indexed number corresponding to the chain in the protein that this residue | |||||
| # belongs to. | |||||
| chain_index: np.ndarray # [num_res] | |||||
| # B-factors, or temperature factors, of each residue (in sq. angstroms units), | |||||
| # representing the displacement of the residue from its ground truth mean | |||||
| # value. | |||||
| b_factors: np.ndarray # [num_res, num_atom_type] | |||||
| def __post_init__(self): | |||||
| if len(np.unique(self.chain_index)) > PDB_MAX_CHAINS: | |||||
| raise ValueError( | |||||
| f'Cannot build an instance with more than {PDB_MAX_CHAINS} chains ' | |||||
| 'because these cannot be written to PDB format.') | |||||
| def from_pdb_string(pdb_str: str, chain_id: Optional[str] = None) -> Protein: | |||||
| """Takes a PDB string and constructs a Protein object. | |||||
| WARNING: All non-standard residue types will be converted into UNK. All | |||||
| non-standard atoms will be ignored. | |||||
| Args: | |||||
| pdb_str: The contents of the pdb file | |||||
| chain_id: If chain_id is specified (e.g. A), then only that chain | |||||
| is parsed. Otherwise all chains are parsed. | |||||
| Returns: | |||||
| A new `Protein` parsed from the pdb contents. | |||||
| """ | |||||
| pdb_fh = io.StringIO(pdb_str) | |||||
| parser = PDBParser(QUIET=True) | |||||
| structure = parser.get_structure('none', pdb_fh) | |||||
| models = list(structure.get_models()) | |||||
| if len(models) != 1: | |||||
| raise ValueError( | |||||
| f'Only single model PDBs are supported. Found {len(models)} models.' | |||||
| ) | |||||
| model = models[0] | |||||
| atom_positions = [] | |||||
| aatype = [] | |||||
| atom_mask = [] | |||||
| residue_index = [] | |||||
| chain_ids = [] | |||||
| b_factors = [] | |||||
| for chain in model: | |||||
| if chain_id is not None and chain.id != chain_id: | |||||
| continue | |||||
| for res in chain: | |||||
| if res.id[2] != ' ': | |||||
| raise ValueError( | |||||
| f'PDB contains an insertion code at chain {chain.id} and residue ' | |||||
| f'index {res.id[1]}. These are not supported.') | |||||
| res_shortname = residue_constants.restype_3to1.get( | |||||
| res.resname, 'X') | |||||
| restype_idx = residue_constants.restype_order.get( | |||||
| res_shortname, residue_constants.restype_num) | |||||
| pos = np.zeros((residue_constants.atom_type_num, 3)) | |||||
| mask = np.zeros((residue_constants.atom_type_num, )) | |||||
| res_b_factors = np.zeros((residue_constants.atom_type_num, )) | |||||
| for atom in res: | |||||
| if atom.name not in residue_constants.atom_types: | |||||
| continue | |||||
| pos[residue_constants.atom_order[atom.name]] = atom.coord | |||||
| mask[residue_constants.atom_order[atom.name]] = 1.0 | |||||
| res_b_factors[residue_constants.atom_order[ | |||||
| atom.name]] = atom.bfactor | |||||
| if np.sum(mask) < 0.5: | |||||
| # If no known atom positions are reported for the residue then skip it. | |||||
| continue | |||||
| aatype.append(restype_idx) | |||||
| atom_positions.append(pos) | |||||
| atom_mask.append(mask) | |||||
| residue_index.append(res.id[1]) | |||||
| chain_ids.append(chain.id) | |||||
| b_factors.append(res_b_factors) | |||||
| # Chain IDs are usually characters so map these to ints. | |||||
| unique_chain_ids = np.unique(chain_ids) | |||||
| chain_id_mapping = {cid: n for n, cid in enumerate(unique_chain_ids)} | |||||
| chain_index = np.array([chain_id_mapping[cid] for cid in chain_ids]) | |||||
| return Protein( | |||||
| atom_positions=np.array(atom_positions), | |||||
| atom_mask=np.array(atom_mask), | |||||
| aatype=np.array(aatype), | |||||
| residue_index=np.array(residue_index), | |||||
| chain_index=chain_index, | |||||
| b_factors=np.array(b_factors), | |||||
| ) | |||||
| def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str: | |||||
| chain_end = 'TER' | |||||
| return (f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' | |||||
| f'{chain_name:>1}{residue_index:>4}') | |||||
| def to_pdb(prot: Protein) -> str: | |||||
| """Converts a `Protein` instance to a PDB string. | |||||
| Args: | |||||
| prot: The protein to convert to PDB. | |||||
| Returns: | |||||
| PDB string. | |||||
| """ | |||||
| restypes = residue_constants.restypes + ['X'] | |||||
| # res_1to3 = lambda r: residue_constants.restype_1to3.get(restypes[r], 'UNK') | |||||
| def res_1to3(r): | |||||
| return residue_constants.restype_1to3.get(restypes[r], 'UNK') | |||||
| atom_types = residue_constants.atom_types | |||||
| pdb_lines = [] | |||||
| atom_mask = prot.atom_mask | |||||
| aatype = prot.aatype | |||||
| atom_positions = prot.atom_positions | |||||
| residue_index = prot.residue_index.astype(np.int32) | |||||
| chain_index = prot.chain_index.astype(np.int32) | |||||
| b_factors = prot.b_factors | |||||
| if np.any(aatype > residue_constants.restype_num): | |||||
| raise ValueError('Invalid aatypes.') | |||||
| # Construct a mapping from chain integer indices to chain ID strings. | |||||
| chain_ids = {} | |||||
| for i in np.unique(chain_index): # np.unique gives sorted output. | |||||
| if i >= PDB_MAX_CHAINS: | |||||
| raise ValueError( | |||||
| f'The PDB format supports at most {PDB_MAX_CHAINS} chains.') | |||||
| chain_ids[i] = PDB_CHAIN_IDS[i] | |||||
| pdb_lines.append('MODEL 1') | |||||
| atom_index = 1 | |||||
| last_chain_index = chain_index[0] | |||||
| # Add all atom sites. | |||||
| for i in range(aatype.shape[0]): | |||||
| # Close the previous chain if in a multichain PDB. | |||||
| if last_chain_index != chain_index[i]: | |||||
| pdb_lines.append( | |||||
| _chain_end( | |||||
| atom_index, | |||||
| res_1to3(aatype[i - 1]), | |||||
| chain_ids[chain_index[i - 1]], | |||||
| residue_index[i - 1], | |||||
| )) | |||||
| last_chain_index = chain_index[i] | |||||
| atom_index += 1 # Atom index increases at the TER symbol. | |||||
| res_name_3 = res_1to3(aatype[i]) | |||||
| for atom_name, pos, mask, b_factor in zip(atom_types, | |||||
| atom_positions[i], | |||||
| atom_mask[i], b_factors[i]): | |||||
| if mask < 0.5: | |||||
| continue | |||||
| record_type = 'ATOM' | |||||
| name = atom_name if len(atom_name) == 4 else f' {atom_name}' | |||||
| alt_loc = '' | |||||
| insertion_code = '' | |||||
| occupancy = 1.00 | |||||
| element = atom_name[ | |||||
| 0] # Proteins contain only C, N, O and S atoms, so the first character works as the element. | |||||
| charge = '' | |||||
| # PDB is a columnar format, every space matters here! | |||||
| atom_line = ( | |||||
| f'{record_type:<6}{atom_index:>5} {name:<4}{alt_loc:>1}' | |||||
| f'{res_name_3:>3} {chain_ids[chain_index[i]]:>1}' | |||||
| f'{residue_index[i]:>4}{insertion_code:>1} ' | |||||
| f'{pos[0]:>8.3f}{pos[1]:>8.3f}{pos[2]:>8.3f}' | |||||
| f'{occupancy:>6.2f}{b_factor:>6.2f} ' | |||||
| f'{element:>2}{charge:>2}') | |||||
| pdb_lines.append(atom_line) | |||||
| atom_index += 1 | |||||
| # Close the final chain. | |||||
| pdb_lines.append( | |||||
| _chain_end( | |||||
| atom_index, | |||||
| res_1to3(aatype[-1]), | |||||
| chain_ids[chain_index[-1]], | |||||
| residue_index[-1], | |||||
| )) | |||||
| pdb_lines.append('ENDMDL') | |||||
| pdb_lines.append('END') | |||||
| # Pad all lines to 80 characters. | |||||
| pdb_lines = [line.ljust(80) for line in pdb_lines] | |||||
| return '\n'.join(pdb_lines) + '\n' # Add terminating newline. | |||||
| def ideal_atom_mask(prot: Protein) -> np.ndarray: | |||||
| """Computes an ideal atom mask. | |||||
| `Protein.atom_mask` typically is defined according to the atoms that are | |||||
| reported in the PDB. This function computes a mask according to heavy atoms | |||||
| that should be present in the given sequence of amino acids. | |||||
| Args: | |||||
| prot: `Protein` whose fields are `numpy.ndarray` objects. | |||||
| Returns: | |||||
| An ideal atom mask. | |||||
| """ | |||||
| return residue_constants.STANDARD_ATOM_MASK[prot.aatype] | |||||
| def from_prediction(features: FeatureDict, | |||||
| result: ModelOutput, | |||||
| b_factors: Optional[np.ndarray] = None) -> Protein: | |||||
| """Assembles a protein from a prediction. | |||||
| Args: | |||||
| features: Dictionary holding model inputs. | |||||
| result: Dictionary holding model outputs. | |||||
| b_factors: (Optional) B-factors to use for the protein. | |||||
| Returns: | |||||
| A protein instance. | |||||
| """ | |||||
| if 'asym_id' in features: | |||||
| chain_index = features['asym_id'] - 1 | |||||
| else: | |||||
| chain_index = np.zeros_like((features['aatype'])) | |||||
| if b_factors is None: | |||||
| b_factors = np.zeros_like(result['final_atom_mask']) | |||||
| return Protein( | |||||
| aatype=features['aatype'], | |||||
| atom_positions=result['final_atom_positions'], | |||||
| atom_mask=result['final_atom_mask'], | |||||
| residue_index=features['residue_index'] + 1, | |||||
| chain_index=chain_index, | |||||
| b_factors=b_factors, | |||||
| ) | |||||
| def from_feature(features: FeatureDict, | |||||
| b_factors: Optional[np.ndarray] = None) -> Protein: | |||||
| """Assembles a standard pdb from input atom positions & mask. | |||||
| Args: | |||||
| features: Dictionary holding model inputs. | |||||
| b_factors: (Optional) B-factors to use for the protein. | |||||
| Returns: | |||||
| A protein instance. | |||||
| """ | |||||
| if 'asym_id' in features: | |||||
| chain_index = features['asym_id'] - 1 | |||||
| else: | |||||
| chain_index = np.zeros_like((features['aatype'])) | |||||
| if b_factors is None: | |||||
| b_factors = np.zeros_like(features['all_atom_mask']) | |||||
| return Protein( | |||||
| aatype=features['aatype'], | |||||
| atom_positions=features['all_atom_positions'], | |||||
| atom_mask=features['all_atom_mask'], | |||||
| residue_index=features['residue_index'] + 1, | |||||
| chain_index=chain_index, | |||||
| b_factors=b_factors, | |||||
| ) | |||||
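| # A minimal sketch of dumping a prediction to disk with the helpers above; | |||||
| # `features` and `result` are assumed to be numpy dicts containing the keys | |||||
| # read by from_prediction, and the file name is arbitrary. | |||||
| def _write_prediction_demo(features, result, path='prediction.pdb'): | |||||
|     prot = from_prediction(features, result) | |||||
|     with open(path, 'w') as f: | |||||
|         f.write(to_pdb(prot)) | |||||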
| @@ -0,0 +1,345 @@ | |||||
| Bond Residue Mean StdDev | |||||
| CA-CB ALA 1.520 0.021 | |||||
| N-CA ALA 1.459 0.020 | |||||
| CA-C ALA 1.525 0.026 | |||||
| C-O ALA 1.229 0.019 | |||||
| CA-CB ARG 1.535 0.022 | |||||
| CB-CG ARG 1.521 0.027 | |||||
| CG-CD ARG 1.515 0.025 | |||||
| CD-NE ARG 1.460 0.017 | |||||
| NE-CZ ARG 1.326 0.013 | |||||
| CZ-NH1 ARG 1.326 0.013 | |||||
| CZ-NH2 ARG 1.326 0.013 | |||||
| N-CA ARG 1.459 0.020 | |||||
| CA-C ARG 1.525 0.026 | |||||
| C-O ARG 1.229 0.019 | |||||
| CA-CB ASN 1.527 0.026 | |||||
| CB-CG ASN 1.506 0.023 | |||||
| CG-OD1 ASN 1.235 0.022 | |||||
| CG-ND2 ASN 1.324 0.025 | |||||
| N-CA ASN 1.459 0.020 | |||||
| CA-C ASN 1.525 0.026 | |||||
| C-O ASN 1.229 0.019 | |||||
| CA-CB ASP 1.535 0.022 | |||||
| CB-CG ASP 1.513 0.021 | |||||
| CG-OD1 ASP 1.249 0.023 | |||||
| CG-OD2 ASP 1.249 0.023 | |||||
| N-CA ASP 1.459 0.020 | |||||
| CA-C ASP 1.525 0.026 | |||||
| C-O ASP 1.229 0.019 | |||||
| CA-CB CYS 1.526 0.013 | |||||
| CB-SG CYS 1.812 0.016 | |||||
| N-CA CYS 1.459 0.020 | |||||
| CA-C CYS 1.525 0.026 | |||||
| C-O CYS 1.229 0.019 | |||||
| CA-CB GLU 1.535 0.022 | |||||
| CB-CG GLU 1.517 0.019 | |||||
| CG-CD GLU 1.515 0.015 | |||||
| CD-OE1 GLU 1.252 0.011 | |||||
| CD-OE2 GLU 1.252 0.011 | |||||
| N-CA GLU 1.459 0.020 | |||||
| CA-C GLU 1.525 0.026 | |||||
| C-O GLU 1.229 0.019 | |||||
| CA-CB GLN 1.535 0.022 | |||||
| CB-CG GLN 1.521 0.027 | |||||
| CG-CD GLN 1.506 0.023 | |||||
| CD-OE1 GLN 1.235 0.022 | |||||
| CD-NE2 GLN 1.324 0.025 | |||||
| N-CA GLN 1.459 0.020 | |||||
| CA-C GLN 1.525 0.026 | |||||
| C-O GLN 1.229 0.019 | |||||
| N-CA GLY 1.456 0.015 | |||||
| CA-C GLY 1.514 0.016 | |||||
| C-O GLY 1.232 0.016 | |||||
| CA-CB HIS 1.535 0.022 | |||||
| CB-CG HIS 1.492 0.016 | |||||
| CG-ND1 HIS 1.369 0.015 | |||||
| CG-CD2 HIS 1.353 0.017 | |||||
| ND1-CE1 HIS 1.343 0.025 | |||||
| CD2-NE2 HIS 1.415 0.021 | |||||
| CE1-NE2 HIS 1.322 0.023 | |||||
| N-CA HIS 1.459 0.020 | |||||
| CA-C HIS 1.525 0.026 | |||||
| C-O HIS 1.229 0.019 | |||||
| CA-CB ILE 1.544 0.023 | |||||
| CB-CG1 ILE 1.536 0.028 | |||||
| CB-CG2 ILE 1.524 0.031 | |||||
| CG1-CD1 ILE 1.500 0.069 | |||||
| N-CA ILE 1.459 0.020 | |||||
| CA-C ILE 1.525 0.026 | |||||
| C-O ILE 1.229 0.019 | |||||
| CA-CB LEU 1.533 0.023 | |||||
| CB-CG LEU 1.521 0.029 | |||||
| CG-CD1 LEU 1.514 0.037 | |||||
| CG-CD2 LEU 1.514 0.037 | |||||
| N-CA LEU 1.459 0.020 | |||||
| CA-C LEU 1.525 0.026 | |||||
| C-O LEU 1.229 0.019 | |||||
| CA-CB LYS 1.535 0.022 | |||||
| CB-CG LYS 1.521 0.027 | |||||
| CG-CD LYS 1.520 0.034 | |||||
| CD-CE LYS 1.508 0.025 | |||||
| CE-NZ LYS 1.486 0.025 | |||||
| N-CA LYS 1.459 0.020 | |||||
| CA-C LYS 1.525 0.026 | |||||
| C-O LYS 1.229 0.019 | |||||
| CA-CB MET 1.535 0.022 | |||||
| CB-CG MET 1.509 0.032 | |||||
| CG-SD MET 1.807 0.026 | |||||
| SD-CE MET 1.774 0.056 | |||||
| N-CA MET 1.459 0.020 | |||||
| CA-C MET 1.525 0.026 | |||||
| C-O MET 1.229 0.019 | |||||
| CA-CB PHE 1.535 0.022 | |||||
| CB-CG PHE 1.509 0.017 | |||||
| CG-CD1 PHE 1.383 0.015 | |||||
| CG-CD2 PHE 1.383 0.015 | |||||
| CD1-CE1 PHE 1.388 0.020 | |||||
| CD2-CE2 PHE 1.388 0.020 | |||||
| CE1-CZ PHE 1.369 0.019 | |||||
| CE2-CZ PHE 1.369 0.019 | |||||
| N-CA PHE 1.459 0.020 | |||||
| CA-C PHE 1.525 0.026 | |||||
| C-O PHE 1.229 0.019 | |||||
| CA-CB PRO 1.531 0.020 | |||||
| CB-CG PRO 1.495 0.050 | |||||
| CG-CD PRO 1.502 0.033 | |||||
| CD-N PRO 1.474 0.014 | |||||
| N-CA PRO 1.468 0.017 | |||||
| CA-C PRO 1.524 0.020 | |||||
| C-O PRO 1.228 0.020 | |||||
| CA-CB SER 1.525 0.015 | |||||
| CB-OG SER 1.418 0.013 | |||||
| N-CA SER 1.459 0.020 | |||||
| CA-C SER 1.525 0.026 | |||||
| C-O SER 1.229 0.019 | |||||
| CA-CB THR 1.529 0.026 | |||||
| CB-OG1 THR 1.428 0.020 | |||||
| CB-CG2 THR 1.519 0.033 | |||||
| N-CA THR 1.459 0.020 | |||||
| CA-C THR 1.525 0.026 | |||||
| C-O THR 1.229 0.019 | |||||
| CA-CB TRP 1.535 0.022 | |||||
| CB-CG TRP 1.498 0.018 | |||||
| CG-CD1 TRP 1.363 0.014 | |||||
| CG-CD2 TRP 1.432 0.017 | |||||
| CD1-NE1 TRP 1.375 0.017 | |||||
| NE1-CE2 TRP 1.371 0.013 | |||||
| CD2-CE2 TRP 1.409 0.012 | |||||
| CD2-CE3 TRP 1.399 0.015 | |||||
| CE2-CZ2 TRP 1.393 0.017 | |||||
| CE3-CZ3 TRP 1.380 0.017 | |||||
| CZ2-CH2 TRP 1.369 0.019 | |||||
| CZ3-CH2 TRP 1.396 0.016 | |||||
| N-CA TRP 1.459 0.020 | |||||
| CA-C TRP 1.525 0.026 | |||||
| C-O TRP 1.229 0.019 | |||||
| CA-CB TYR 1.535 0.022 | |||||
| CB-CG TYR 1.512 0.015 | |||||
| CG-CD1 TYR 1.387 0.013 | |||||
| CG-CD2 TYR 1.387 0.013 | |||||
| CD1-CE1 TYR 1.389 0.015 | |||||
| CD2-CE2 TYR 1.389 0.015 | |||||
| CE1-CZ TYR 1.381 0.013 | |||||
| CE2-CZ TYR 1.381 0.013 | |||||
| CZ-OH TYR 1.374 0.017 | |||||
| N-CA TYR 1.459 0.020 | |||||
| CA-C TYR 1.525 0.026 | |||||
| C-O TYR 1.229 0.019 | |||||
| CA-CB VAL 1.543 0.021 | |||||
| CB-CG1 VAL 1.524 0.021 | |||||
| CB-CG2 VAL 1.524 0.021 | |||||
| N-CA VAL 1.459 0.020 | |||||
| CA-C VAL 1.525 0.026 | |||||
| C-O VAL 1.229 0.019 | |||||
| - | |||||
| Angle Residue Mean StdDev | |||||
| N-CA-CB ALA 110.1 1.4 | |||||
| CB-CA-C ALA 110.1 1.5 | |||||
| N-CA-C ALA 111.0 2.7 | |||||
| CA-C-O ALA 120.1 2.1 | |||||
| N-CA-CB ARG 110.6 1.8 | |||||
| CB-CA-C ARG 110.4 2.0 | |||||
| CA-CB-CG ARG 113.4 2.2 | |||||
| CB-CG-CD ARG 111.6 2.6 | |||||
| CG-CD-NE ARG 111.8 2.1 | |||||
| CD-NE-CZ ARG 123.6 1.4 | |||||
| NE-CZ-NH1 ARG 120.3 0.5 | |||||
| NE-CZ-NH2 ARG 120.3 0.5 | |||||
| NH1-CZ-NH2 ARG 119.4 1.1 | |||||
| N-CA-C ARG 111.0 2.7 | |||||
| CA-C-O ARG 120.1 2.1 | |||||
| N-CA-CB ASN 110.6 1.8 | |||||
| CB-CA-C ASN 110.4 2.0 | |||||
| CA-CB-CG ASN 113.4 2.2 | |||||
| CB-CG-ND2 ASN 116.7 2.4 | |||||
| CB-CG-OD1 ASN 121.6 2.0 | |||||
| ND2-CG-OD1 ASN 121.9 2.3 | |||||
| N-CA-C ASN 111.0 2.7 | |||||
| CA-C-O ASN 120.1 2.1 | |||||
| N-CA-CB ASP 110.6 1.8 | |||||
| CB-CA-C ASP 110.4 2.0 | |||||
| CA-CB-CG ASP 113.4 2.2 | |||||
| CB-CG-OD1 ASP 118.3 0.9 | |||||
| CB-CG-OD2 ASP 118.3 0.9 | |||||
| OD1-CG-OD2 ASP 123.3 1.9 | |||||
| N-CA-C ASP 111.0 2.7 | |||||
| CA-C-O ASP 120.1 2.1 | |||||
| N-CA-CB CYS 110.8 1.5 | |||||
| CB-CA-C CYS 111.5 1.2 | |||||
| CA-CB-SG CYS 114.2 1.1 | |||||
| N-CA-C CYS 111.0 2.7 | |||||
| CA-C-O CYS 120.1 2.1 | |||||
| N-CA-CB GLU 110.6 1.8 | |||||
| CB-CA-C GLU 110.4 2.0 | |||||
| CA-CB-CG GLU 113.4 2.2 | |||||
| CB-CG-CD GLU 114.2 2.7 | |||||
| CG-CD-OE1 GLU 118.3 2.0 | |||||
| CG-CD-OE2 GLU 118.3 2.0 | |||||
| OE1-CD-OE2 GLU 123.3 1.2 | |||||
| N-CA-C GLU 111.0 2.7 | |||||
| CA-C-O GLU 120.1 2.1 | |||||
| N-CA-CB GLN 110.6 1.8 | |||||
| CB-CA-C GLN 110.4 2.0 | |||||
| CA-CB-CG GLN 113.4 2.2 | |||||
| CB-CG-CD GLN 111.6 2.6 | |||||
| CG-CD-OE1 GLN 121.6 2.0 | |||||
| CG-CD-NE2 GLN 116.7 2.4 | |||||
| OE1-CD-NE2 GLN 121.9 2.3 | |||||
| N-CA-C GLN 111.0 2.7 | |||||
| CA-C-O GLN 120.1 2.1 | |||||
| N-CA-C GLY 113.1 2.5 | |||||
| CA-C-O GLY 120.6 1.8 | |||||
| N-CA-CB HIS 110.6 1.8 | |||||
| CB-CA-C HIS 110.4 2.0 | |||||
| CA-CB-CG HIS 113.6 1.7 | |||||
| CB-CG-ND1 HIS 123.2 2.5 | |||||
| CB-CG-CD2 HIS 130.8 3.1 | |||||
| CG-ND1-CE1 HIS 108.2 1.4 | |||||
| ND1-CE1-NE2 HIS 109.9 2.2 | |||||
| CE1-NE2-CD2 HIS 106.6 2.5 | |||||
| NE2-CD2-CG HIS 109.2 1.9 | |||||
| CD2-CG-ND1 HIS 106.0 1.4 | |||||
| N-CA-C HIS 111.0 2.7 | |||||
| CA-C-O HIS 120.1 2.1 | |||||
| N-CA-CB ILE 110.8 2.3 | |||||
| CB-CA-C ILE 111.6 2.0 | |||||
| CA-CB-CG1 ILE 111.0 1.9 | |||||
| CB-CG1-CD1 ILE 113.9 2.8 | |||||
| CA-CB-CG2 ILE 110.9 2.0 | |||||
| CG1-CB-CG2 ILE 111.4 2.2 | |||||
| N-CA-C ILE 111.0 2.7 | |||||
| CA-C-O ILE 120.1 2.1 | |||||
| N-CA-CB LEU 110.4 2.0 | |||||
| CB-CA-C LEU 110.2 1.9 | |||||
| CA-CB-CG LEU 115.3 2.3 | |||||
| CB-CG-CD1 LEU 111.0 1.7 | |||||
| CB-CG-CD2 LEU 111.0 1.7 | |||||
| CD1-CG-CD2 LEU 110.5 3.0 | |||||
| N-CA-C LEU 111.0 2.7 | |||||
| CA-C-O LEU 120.1 2.1 | |||||
| N-CA-CB LYS 110.6 1.8 | |||||
| CB-CA-C LYS 110.4 2.0 | |||||
| CA-CB-CG LYS 113.4 2.2 | |||||
| CB-CG-CD LYS 111.6 2.6 | |||||
| CG-CD-CE LYS 111.9 3.0 | |||||
| CD-CE-NZ LYS 111.7 2.3 | |||||
| N-CA-C LYS 111.0 2.7 | |||||
| CA-C-O LYS 120.1 2.1 | |||||
| N-CA-CB MET 110.6 1.8 | |||||
| CB-CA-C MET 110.4 2.0 | |||||
| CA-CB-CG MET 113.3 1.7 | |||||
| CB-CG-SD MET 112.4 3.0 | |||||
| CG-SD-CE MET 100.2 1.6 | |||||
| N-CA-C MET 111.0 2.7 | |||||
| CA-C-O MET 120.1 2.1 | |||||
| N-CA-CB PHE 110.6 1.8 | |||||
| CB-CA-C PHE 110.4 2.0 | |||||
| CA-CB-CG PHE 113.9 2.4 | |||||
| CB-CG-CD1 PHE 120.8 0.7 | |||||
| CB-CG-CD2 PHE 120.8 0.7 | |||||
| CD1-CG-CD2 PHE 118.3 1.3 | |||||
| CG-CD1-CE1 PHE 120.8 1.1 | |||||
| CG-CD2-CE2 PHE 120.8 1.1 | |||||
| CD1-CE1-CZ PHE 120.1 1.2 | |||||
| CD2-CE2-CZ PHE 120.1 1.2 | |||||
| CE1-CZ-CE2 PHE 120.0 1.8 | |||||
| N-CA-C PHE 111.0 2.7 | |||||
| CA-C-O PHE 120.1 2.1 | |||||
| N-CA-CB PRO 103.3 1.2 | |||||
| CB-CA-C PRO 111.7 2.1 | |||||
| CA-CB-CG PRO 104.8 1.9 | |||||
| CB-CG-CD PRO 106.5 3.9 | |||||
| CG-CD-N PRO 103.2 1.5 | |||||
| CA-N-CD PRO 111.7 1.4 | |||||
| N-CA-C PRO 112.1 2.6 | |||||
| CA-C-O PRO 120.2 2.4 | |||||
| N-CA-CB SER 110.5 1.5 | |||||
| CB-CA-C SER 110.1 1.9 | |||||
| CA-CB-OG SER 111.2 2.7 | |||||
| N-CA-C SER 111.0 2.7 | |||||
| CA-C-O SER 120.1 2.1 | |||||
| N-CA-CB THR 110.3 1.9 | |||||
| CB-CA-C THR 111.6 2.7 | |||||
| CA-CB-OG1 THR 109.0 2.1 | |||||
| CA-CB-CG2 THR 112.4 1.4 | |||||
| OG1-CB-CG2 THR 110.0 2.3 | |||||
| N-CA-C THR 111.0 2.7 | |||||
| CA-C-O THR 120.1 2.1 | |||||
| N-CA-CB TRP 110.6 1.8 | |||||
| CB-CA-C TRP 110.4 2.0 | |||||
| CA-CB-CG TRP 113.7 1.9 | |||||
| CB-CG-CD1 TRP 127.0 1.3 | |||||
| CB-CG-CD2 TRP 126.6 1.3 | |||||
| CD1-CG-CD2 TRP 106.3 0.8 | |||||
| CG-CD1-NE1 TRP 110.1 1.0 | |||||
| CD1-NE1-CE2 TRP 109.0 0.9 | |||||
| NE1-CE2-CD2 TRP 107.3 1.0 | |||||
| CE2-CD2-CG TRP 107.3 0.8 | |||||
| CG-CD2-CE3 TRP 133.9 0.9 | |||||
| NE1-CE2-CZ2 TRP 130.4 1.1 | |||||
| CE3-CD2-CE2 TRP 118.7 1.2 | |||||
| CD2-CE2-CZ2 TRP 122.3 1.2 | |||||
| CE2-CZ2-CH2 TRP 117.4 1.0 | |||||
| CZ2-CH2-CZ3 TRP 121.6 1.2 | |||||
| CH2-CZ3-CE3 TRP 121.2 1.1 | |||||
| CZ3-CE3-CD2 TRP 118.8 1.3 | |||||
| N-CA-C TRP 111.0 2.7 | |||||
| CA-C-O TRP 120.1 2.1 | |||||
| N-CA-CB TYR 110.6 1.8 | |||||
| CB-CA-C TYR 110.4 2.0 | |||||
| CA-CB-CG TYR 113.4 1.9 | |||||
| CB-CG-CD1 TYR 121.0 0.6 | |||||
| CB-CG-CD2 TYR 121.0 0.6 | |||||
| CD1-CG-CD2 TYR 117.9 1.1 | |||||
| CG-CD1-CE1 TYR 121.3 0.8 | |||||
| CG-CD2-CE2 TYR 121.3 0.8 | |||||
| CD1-CE1-CZ TYR 119.8 0.9 | |||||
| CD2-CE2-CZ TYR 119.8 0.9 | |||||
| CE1-CZ-CE2 TYR 119.8 1.6 | |||||
| CE1-CZ-OH TYR 120.1 2.7 | |||||
| CE2-CZ-OH TYR 120.1 2.7 | |||||
| N-CA-C TYR 111.0 2.7 | |||||
| CA-C-O TYR 120.1 2.1 | |||||
| N-CA-CB VAL 111.5 2.2 | |||||
| CB-CA-C VAL 111.4 1.9 | |||||
| CA-CB-CG1 VAL 110.9 1.5 | |||||
| CA-CB-CG2 VAL 110.9 1.5 | |||||
| CG1-CB-CG2 VAL 110.9 1.6 | |||||
| N-CA-C VAL 111.0 2.7 | |||||
| CA-C-O VAL 120.1 2.1 | |||||
| - | |||||
| Non-bonded distance Minimum Dist Tolerance | |||||
| C-C 3.4 1.5 | |||||
| C-N 3.25 1.5 | |||||
| C-S 3.5 1.5 | |||||
| C-O 3.22 1.5 | |||||
| N-N 3.1 1.5 | |||||
| N-S 3.35 1.5 | |||||
| N-O 3.07 1.5 | |||||
| O-S 3.32 1.5 | |||||
| O-O 3.04 1.5 | |||||
| S-S 2.03 1.0 | |||||
| - | |||||
| @@ -0,0 +1,161 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| import copy as copy_lib | |||||
| import functools | |||||
| import gzip | |||||
| import pickle | |||||
| from typing import Any, Dict | |||||
| import json | |||||
| import numpy as np | |||||
| from scipy import sparse as sp | |||||
| from . import residue_constants as rc | |||||
| from .data_ops import NumpyDict | |||||
| # from typing import * | |||||
| def lru_cache(maxsize=16, typed=False, copy=False, deepcopy=False): | |||||
| if deepcopy: | |||||
| def decorator(f): | |||||
| cached_func = functools.lru_cache(maxsize, typed)(f) | |||||
| @functools.wraps(f) | |||||
| def wrapper(*args, **kwargs): | |||||
| return copy_lib.deepcopy(cached_func(*args, **kwargs)) | |||||
| return wrapper | |||||
| elif copy: | |||||
| def decorator(f): | |||||
| cached_func = functools.lru_cache(maxsize, typed)(f) | |||||
| @functools.wraps(f) | |||||
| def wrapper(*args, **kwargs): | |||||
| return copy_lib.copy(cached_func(*args, **kwargs)) | |||||
| return wrapper | |||||
| else: | |||||
| decorator = functools.lru_cache(maxsize, typed) | |||||
| return decorator | |||||
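| # A minimal sketch of why copy/deepcopy matter here: callers mutate the | |||||
| # returned feature dicts, and without a copy those mutations would leak back | |||||
| # into the cache on the next hit. | |||||
| @lru_cache(maxsize=2, deepcopy=True) | |||||
| def _cached_dict_demo(key): | |||||
|     return {'values': [0, 0, 0]} | |||||
| def _mutation_is_safe_demo(): | |||||
|     first = _cached_dict_demo('a') | |||||
|     first['values'][0] = 99 | |||||
|     assert _cached_dict_demo('a')['values'][0] == 0 | |||||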
| @lru_cache(maxsize=8, deepcopy=True) | |||||
| def load_pickle_safe(path: str) -> Dict[str, Any]: | |||||
| def load(path): | |||||
| assert path.endswith('.pkl') or path.endswith( | |||||
| '.pkl.gz'), f'bad suffix in {path} as pickle file.' | |||||
| open_fn = gzip.open if path.endswith('.gz') else open | |||||
| with open_fn(path, 'rb') as f: | |||||
| return pickle.load(f) | |||||
| ret = load(path) | |||||
| ret = uncompress_features(ret) | |||||
| return ret | |||||
| @lru_cache(maxsize=8, copy=True) | |||||
| def load_pickle(path: str) -> Dict[str, Any]: | |||||
| def load(path): | |||||
| assert path.endswith('.pkl') or path.endswith( | |||||
| '.pkl.gz'), f'bad suffix in {path} as pickle file.' | |||||
| open_fn = gzip.open if path.endswith('.gz') else open | |||||
| with open_fn(path, 'rb') as f: | |||||
| return pickle.load(f) | |||||
| ret = load(path) | |||||
| ret = uncompress_features(ret) | |||||
| return ret | |||||
| def correct_template_restypes(feature): | |||||
| """Correct template restype to have the same order as residue_constants.""" | |||||
| feature = np.argmax(feature, axis=-1).astype(np.int32) | |||||
| new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE | |||||
| feature = np.take(new_order_list, feature.astype(np.int32), axis=0) | |||||
| return feature | |||||
| def convert_all_seq_feature(feature: NumpyDict) -> NumpyDict: | |||||
| feature['msa'] = feature['msa'].astype(np.uint8) | |||||
| if 'num_alignments' in feature: | |||||
| feature.pop('num_alignments') | |||||
| # make_all_seq_key = lambda k: f'{k}_all_seq' if not k.endswith('_all_seq') else k | |||||
| def make_all_seq_key(k): | |||||
| if not k.endswith('_all_seq'): | |||||
| return f'{k}_all_seq' | |||||
| return k | |||||
| return {make_all_seq_key(k): v for k, v in feature.items()} | |||||
| def to_dense_matrix(spmat_dict: NumpyDict): | |||||
| spmat = sp.coo_matrix( | |||||
| (spmat_dict['data'], (spmat_dict['row'], spmat_dict['col'])), | |||||
| shape=spmat_dict['shape'], | |||||
| dtype=np.float32, | |||||
| ) | |||||
| return spmat.toarray() | |||||
| FEATS_DTYPE = {'msa': np.int32} | |||||
| def uncompress_features(feats: NumpyDict) -> NumpyDict: | |||||
| if 'sparse_deletion_matrix_int' in feats: | |||||
| v = feats.pop('sparse_deletion_matrix_int') | |||||
| v = to_dense_matrix(v) | |||||
| feats['deletion_matrix'] = v | |||||
| return feats | |||||
| def filter(feature: NumpyDict, **kwargs) -> NumpyDict: | |||||
| assert len(kwargs) == 1, f'wrong usage of filter with kwargs: {kwargs}' | |||||
| if 'desired_keys' in kwargs: | |||||
| feature = { | |||||
| k: v | |||||
| for k, v in feature.items() if k in kwargs['desired_keys'] | |||||
| } | |||||
| elif 'required_keys' in kwargs: | |||||
| for k in kwargs['required_keys']: | |||||
| assert k in feature, f'cannot find required key {k}.' | |||||
| elif 'ignored_keys' in kwargs: | |||||
| feature = { | |||||
| k: v | |||||
| for k, v in feature.items() if k not in kwargs['ignored_keys'] | |||||
| } | |||||
| else: | |||||
| raise AssertionError(f'wrong usage of filter with kwargs: {kwargs}') | |||||
| return feature | |||||
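| # A minimal sketch with a toy feature dict: filter() accepts exactly one of | |||||
| # desired_keys / required_keys / ignored_keys. | |||||
| def _filter_demo(): | |||||
|     feat = {'msa': 1, 'aatype': 2, 'mem_peak': 3} | |||||
|     assert set(filter(feat, desired_keys={'msa', 'aatype'})) == {'msa', 'aatype'} | |||||
|     assert set(filter(feat, ignored_keys={'mem_peak'})) == {'msa', 'aatype'} | |||||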
| def compress_features(features: NumpyDict): | |||||
| change_dtype = { | |||||
| 'msa': np.uint8, | |||||
| } | |||||
| sparse_keys = ['deletion_matrix_int'] | |||||
| compressed_features = {} | |||||
| for k, v in features.items(): | |||||
| if k in change_dtype: | |||||
| v = v.astype(change_dtype[k]) | |||||
| if k in sparse_keys: | |||||
| v = sp.coo_matrix(v, dtype=v.dtype) | |||||
| sp_v = { | |||||
| 'shape': v.shape, | |||||
| 'row': v.row, | |||||
| 'col': v.col, | |||||
| 'data': v.data | |||||
| } | |||||
| k = f'sparse_{k}' | |||||
| v = sp_v | |||||
| compressed_features[k] = v | |||||
| return compressed_features | |||||
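| # A minimal round-trip sketch with a toy matrix: compress_features stores the | |||||
| # integer deletion matrix as a sparse COO dict, and uncompress_features | |||||
| # restores it as a dense float32 'deletion_matrix'. | |||||
| def _compression_roundtrip_demo(): | |||||
|     feats = {'deletion_matrix_int': np.array([[0, 2], [0, 0]], dtype=np.int32)} | |||||
|     packed = compress_features(feats) | |||||
|     assert 'sparse_deletion_matrix_int' in packed | |||||
|     restored = uncompress_features(packed) | |||||
|     assert restored['deletion_matrix'].dtype == np.float32 | |||||
|     assert restored['deletion_matrix'][0, 1] == 2.0 | |||||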
| @@ -0,0 +1,514 @@ | |||||
| import copy | |||||
| import logging | |||||
| import os | |||||
| # from typing import * | |||||
| from typing import Dict, Iterable, List, Optional, Tuple, Union | |||||
| import json | |||||
| import ml_collections as mlc | |||||
| import numpy as np | |||||
| import torch | |||||
| from unicore.data import UnicoreDataset, data_utils | |||||
| from unicore.distributed import utils as distributed_utils | |||||
| from .data import utils | |||||
| from .data.data_ops import NumpyDict, TorchDict | |||||
| from .data.process import process_features, process_labels | |||||
| from .data.process_multimer import (add_assembly_features, | |||||
| convert_monomer_features, merge_msas, | |||||
| pair_and_merge, post_process) | |||||
| Rotation = Iterable[Iterable] | |||||
| Translation = Iterable | |||||
| Operation = Union[str, Tuple[Rotation, Translation]] | |||||
| NumpyExample = Tuple[NumpyDict, Optional[List[NumpyDict]]] | |||||
| TorchExample = Tuple[TorchDict, Optional[List[TorchDict]]] | |||||
| logger = logging.getLogger(__name__) # pylint: disable=invalid-name | |||||
| def make_data_config( | |||||
| config: mlc.ConfigDict, | |||||
| mode: str, | |||||
| num_res: int, | |||||
| ) -> Tuple[mlc.ConfigDict, List[str]]: | |||||
| cfg = copy.deepcopy(config) | |||||
| mode_cfg = cfg[mode] | |||||
| with cfg.unlocked(): | |||||
| if mode_cfg.crop_size is None: | |||||
| mode_cfg.crop_size = num_res | |||||
| feature_names = cfg.common.unsupervised_features + cfg.common.recycling_features | |||||
| if cfg.common.use_templates: | |||||
| feature_names += cfg.common.template_features | |||||
| if cfg.common.is_multimer: | |||||
| feature_names += cfg.common.multimer_features | |||||
| if cfg[mode].supervised: | |||||
| feature_names += cfg.supervised.supervised_features | |||||
| return cfg, feature_names | |||||
| def process_label(all_atom_positions: np.ndarray, | |||||
| operation: Operation) -> np.ndarray: | |||||
| if operation == 'I': | |||||
| return all_atom_positions | |||||
| rot, trans = operation | |||||
| rot = np.array(rot).reshape(3, 3) | |||||
| trans = np.array(trans).reshape(3) | |||||
| return all_atom_positions @ rot.T + trans | |||||
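| # A minimal sketch with toy coordinates: a symmetry operation is either the | |||||
| # identity string 'I' or a (rotation, translation) pair applied to the atom | |||||
| # positions, e.g. a pure translation by (1, 0, 0). | |||||
| def _process_label_demo(): | |||||
|     pos = np.zeros((5, 37, 3)) | |||||
|     moved = process_label(pos, (np.eye(3).tolist(), [1.0, 0.0, 0.0])) | |||||
|     assert moved.shape == pos.shape and moved[0, 0, 0] == 1.0 | |||||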
| @utils.lru_cache(maxsize=8, copy=True) | |||||
| def load_single_feature( | |||||
| sequence_id: str, | |||||
| monomer_feature_dir: str, | |||||
| uniprot_msa_dir: Optional[str] = None, | |||||
| is_monomer: bool = False, | |||||
| ) -> NumpyDict: | |||||
| monomer_feature = utils.load_pickle( | |||||
| os.path.join(monomer_feature_dir, f'{sequence_id}.feature.pkl.gz')) | |||||
| monomer_feature = convert_monomer_features(monomer_feature) | |||||
| chain_feature = {**monomer_feature} | |||||
| if uniprot_msa_dir is not None: | |||||
| all_seq_feature = utils.load_pickle( | |||||
| os.path.join(uniprot_msa_dir, f'{sequence_id}.uniprot.pkl.gz')) | |||||
| if is_monomer: | |||||
| chain_feature['msa'], chain_feature[ | |||||
| 'deletion_matrix'] = merge_msas( | |||||
| chain_feature['msa'], | |||||
| chain_feature['deletion_matrix'], | |||||
| all_seq_feature['msa'], | |||||
| all_seq_feature['deletion_matrix'], | |||||
| ) # noqa | |||||
| else: | |||||
| all_seq_feature = utils.convert_all_seq_feature(all_seq_feature) | |||||
| for key in [ | |||||
| 'msa_all_seq', | |||||
| 'msa_species_identifiers_all_seq', | |||||
| 'deletion_matrix_all_seq', | |||||
| ]: | |||||
| chain_feature[key] = all_seq_feature[key] | |||||
| return chain_feature | |||||
| def load_single_label( | |||||
| label_id: str, | |||||
| label_dir: str, | |||||
| symmetry_operation: Optional[Operation] = None, | |||||
| ) -> NumpyDict: | |||||
| label = utils.load_pickle( | |||||
| os.path.join(label_dir, f'{label_id}.label.pkl.gz')) | |||||
| if symmetry_operation is not None: | |||||
| label['all_atom_positions'] = process_label( | |||||
| label['all_atom_positions'], symmetry_operation) | |||||
| label = { | |||||
| k: v | |||||
| for k, v in label.items() if k in | |||||
| ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] | |||||
| } | |||||
| return label | |||||
| def load( | |||||
| sequence_ids: List[str], | |||||
| monomer_feature_dir: str, | |||||
| uniprot_msa_dir: Optional[str] = None, | |||||
| label_ids: Optional[List[str]] = None, | |||||
| label_dir: Optional[str] = None, | |||||
| symmetry_operations: Optional[List[Operation]] = None, | |||||
| is_monomer: bool = False, | |||||
| ) -> NumpyExample: | |||||
| all_chain_features = [ | |||||
| load_single_feature(s, monomer_feature_dir, uniprot_msa_dir, | |||||
| is_monomer) for s in sequence_ids | |||||
| ] | |||||
| if label_ids is not None: | |||||
| # load labels | |||||
| assert len(label_ids) == len(sequence_ids) | |||||
| assert label_dir is not None | |||||
| if symmetry_operations is None: | |||||
| symmetry_operations = ['I' for _ in label_ids] | |||||
| all_chain_labels = [ | |||||
| load_single_label(ll, label_dir, o) | |||||
| for ll, o in zip(label_ids, symmetry_operations) | |||||
| ] | |||||
| # update labels into features to calculate spatial cropping etc. | |||||
| for f, ll in zip(all_chain_features, all_chain_labels): | |||||
|     f.update(ll) | |||||
| all_chain_features = add_assembly_features(all_chain_features) | |||||
| # get labels back from features, as add_assembly_features may alter the order of inputs. | |||||
| if label_ids is not None: | |||||
| all_chain_labels = [{ | |||||
| k: f[k] | |||||
| for k in | |||||
| ['aatype', 'all_atom_positions', 'all_atom_mask', 'resolution'] | |||||
| } for f in all_chain_features] | |||||
| else: | |||||
| all_chain_labels = None | |||||
| asym_len = np.array([c['seq_length'] for c in all_chain_features], | |||||
| dtype=np.int64) | |||||
| if is_monomer: | |||||
| all_chain_features = all_chain_features[0] | |||||
| else: | |||||
| all_chain_features = pair_and_merge(all_chain_features) | |||||
| all_chain_features = post_process(all_chain_features) | |||||
| all_chain_features['asym_len'] = asym_len | |||||
| return all_chain_features, all_chain_labels | |||||
| def process( | |||||
| config: mlc.ConfigDict, | |||||
| mode: str, | |||||
| features: NumpyDict, | |||||
| labels: Optional[List[NumpyDict]] = None, | |||||
| seed: int = 0, | |||||
| batch_idx: Optional[int] = None, | |||||
| data_idx: Optional[int] = None, | |||||
| is_distillation: bool = False, | |||||
| ) -> TorchExample: | |||||
| if mode == 'train': | |||||
| assert batch_idx is not None | |||||
| with data_utils.numpy_seed(seed, batch_idx, key='recycling'): | |||||
| num_iters = np.random.randint( | |||||
| 0, config.common.max_recycling_iters + 1) | |||||
| use_clamped_fape = np.random.rand( | |||||
| ) < config[mode].use_clamped_fape_prob | |||||
| else: | |||||
| num_iters = config.common.max_recycling_iters | |||||
| use_clamped_fape = 1 | |||||
| features['num_recycling_iters'] = int(num_iters) | |||||
| features['use_clamped_fape'] = int(use_clamped_fape) | |||||
| features['is_distillation'] = int(is_distillation) | |||||
| if is_distillation and 'msa_chains' in features: | |||||
| features.pop('msa_chains') | |||||
| num_res = int(features['seq_length']) | |||||
| cfg, feature_names = make_data_config(config, mode=mode, num_res=num_res) | |||||
| if labels is not None: | |||||
| features['resolution'] = labels[0]['resolution'].reshape(-1) | |||||
| with data_utils.numpy_seed(seed, data_idx, key='protein_feature'): | |||||
| features['crop_and_fix_size_seed'] = np.random.randint(0, 63355) | |||||
| features = utils.filter(features, desired_keys=feature_names) | |||||
| features = {k: torch.tensor(v) for k, v in features.items()} | |||||
| with torch.no_grad(): | |||||
| features = process_features(features, cfg.common, cfg[mode]) | |||||
| if labels is not None: | |||||
| labels = [{k: torch.tensor(v) for k, v in ll.items()} for ll in labels] | |||||
| with torch.no_grad(): | |||||
| labels = process_labels(labels) | |||||
| return features, labels | |||||
| def load_and_process( | |||||
| config: mlc.ConfigDict, | |||||
| mode: str, | |||||
| seed: int = 0, | |||||
| batch_idx: Optional[int] = None, | |||||
| data_idx: Optional[int] = None, | |||||
| is_distillation: bool = False, | |||||
| **load_kwargs, | |||||
| ): | |||||
| is_monomer = ( | |||||
| is_distillation | |||||
| if 'is_monomer' not in load_kwargs else load_kwargs.pop('is_monomer')) | |||||
| features, labels = load(**load_kwargs, is_monomer=is_monomer) | |||||
| features, labels = process(config, mode, features, labels, seed, batch_idx, | |||||
| data_idx, is_distillation) | |||||
| return features, labels | |||||
| class UnifoldDataset(UnicoreDataset): | |||||
| def __init__( | |||||
| self, | |||||
| args, | |||||
| seed, | |||||
| config, | |||||
| data_path, | |||||
| mode='train', | |||||
| max_step=None, | |||||
| disable_sd=False, | |||||
| json_prefix='', | |||||
| ): | |||||
| self.path = data_path | |||||
| def load_json(filename): | |||||
| return json.load(open(filename, 'r')) | |||||
| sample_weight = load_json( | |||||
| os.path.join(self.path, | |||||
| json_prefix + mode + '_sample_weight.json')) | |||||
| self.multi_label = load_json( | |||||
| os.path.join(self.path, json_prefix + mode + '_multi_label.json')) | |||||
| self.inverse_multi_label = self._inverse_map(self.multi_label) | |||||
| self.sample_weight = {} | |||||
| for chain in self.inverse_multi_label: | |||||
| entity = self.inverse_multi_label[chain] | |||||
| self.sample_weight[chain] = sample_weight[entity] | |||||
| self.seq_sample_weight = sample_weight | |||||
| logger.info('load {} chains (unique {} sequences)'.format( | |||||
| len(self.sample_weight), len(self.seq_sample_weight))) | |||||
| self.feature_path = os.path.join(self.path, 'pdb_features') | |||||
| self.label_path = os.path.join(self.path, 'pdb_labels') | |||||
| sd_sample_weight_path = os.path.join( | |||||
| self.path, json_prefix + 'sd_train_sample_weight.json') | |||||
| if mode == 'train' and os.path.isfile( | |||||
| sd_sample_weight_path) and not disable_sd: | |||||
| self.sd_sample_weight = load_json(sd_sample_weight_path) | |||||
| logger.info('load {} self-distillation samples.'.format( | |||||
| len(self.sd_sample_weight))) | |||||
| self.sd_feature_path = os.path.join(self.path, 'sd_features') | |||||
| self.sd_label_path = os.path.join(self.path, 'sd_labels') | |||||
| else: | |||||
| self.sd_sample_weight = None | |||||
| self.batch_size = ( | |||||
| args.batch_size * distributed_utils.get_data_parallel_world_size() | |||||
| * args.update_freq[0]) | |||||
| self.data_len = ( | |||||
| max_step * self.batch_size | |||||
| if max_step is not None else len(self.sample_weight)) | |||||
| self.mode = mode | |||||
| self.num_seq, self.seq_keys, self.seq_sample_prob = self.cal_sample_weight( | |||||
| self.seq_sample_weight) | |||||
| self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( | |||||
| self.sample_weight) | |||||
| if self.sd_sample_weight is not None: | |||||
| ( | |||||
| self.sd_num_chain, | |||||
| self.sd_chain_keys, | |||||
| self.sd_sample_prob, | |||||
| ) = self.cal_sample_weight(self.sd_sample_weight) | |||||
| self.config = config.data | |||||
| self.seed = seed | |||||
| self.sd_prob = args.sd_prob | |||||
| def cal_sample_weight(self, sample_weight): | |||||
| prot_keys = list(sample_weight.keys()) | |||||
| sum_weight = sum(sample_weight.values()) | |||||
| sample_prob = [sample_weight[k] / sum_weight for k in prot_keys] | |||||
| num_prot = len(prot_keys) | |||||
| return num_prot, prot_keys, sample_prob | |||||
| def sample_chain(self, idx, sample_by_seq=False): | |||||
| is_distillation = False | |||||
| if self.mode == 'train': | |||||
| with data_utils.numpy_seed(self.seed, idx, key='data_sample'): | |||||
| is_distillation = ((np.random.rand(1)[0] < self.sd_prob) | |||||
| if self.sd_sample_weight is not None else | |||||
| False) | |||||
| if is_distillation: | |||||
| prot_idx = np.random.choice( | |||||
| self.sd_num_chain, p=self.sd_sample_prob) | |||||
| label_name = self.sd_chain_keys[prot_idx] | |||||
| seq_name = label_name | |||||
| else: | |||||
| if not sample_by_seq: | |||||
| prot_idx = np.random.choice( | |||||
| self.num_chain, p=self.sample_prob) | |||||
| label_name = self.chain_keys[prot_idx] | |||||
| seq_name = self.inverse_multi_label[label_name] | |||||
| else: | |||||
| seq_idx = np.random.choice( | |||||
| self.num_seq, p=self.seq_sample_prob) | |||||
| seq_name = self.seq_keys[seq_idx] | |||||
| label_name = np.random.choice( | |||||
| self.multi_label[seq_name]) | |||||
| else: | |||||
| label_name = self.chain_keys[idx] | |||||
| seq_name = self.inverse_multi_label[label_name] | |||||
| return seq_name, label_name, is_distillation | |||||
| def __getitem__(self, idx): | |||||
| sequence_id, label_id, is_distillation = self.sample_chain( | |||||
| idx, sample_by_seq=True) | |||||
| feature_dir, label_dir = ((self.feature_path, | |||||
| self.label_path) if not is_distillation else | |||||
| (self.sd_feature_path, self.sd_label_path)) | |||||
| features, _ = load_and_process( | |||||
| self.config, | |||||
| self.mode, | |||||
| self.seed, | |||||
| batch_idx=(idx // self.batch_size), | |||||
| data_idx=idx, | |||||
| is_distillation=is_distillation, | |||||
| sequence_ids=[sequence_id], | |||||
| monomer_feature_dir=feature_dir, | |||||
| uniprot_msa_dir=None, | |||||
| label_ids=[label_id], | |||||
| label_dir=label_dir, | |||||
| symmetry_operations=None, | |||||
| is_monomer=True, | |||||
| ) | |||||
| return features | |||||
| def __len__(self): | |||||
| return self.data_len | |||||
| @staticmethod | |||||
| def collater(samples): | |||||
| # first dim is recycling; batch size (bsz) is at the 2nd dim | |||||
| return data_utils.collate_dict(samples, dim=1) | |||||
| @staticmethod | |||||
| def _inverse_map(mapping: Dict[str, List[str]]): | |||||
| inverse_mapping = {} | |||||
| for ent, refs in mapping.items(): | |||||
| for ref in refs: | |||||
| if ref in inverse_mapping: # duplicated ent for this ref. | |||||
| ent_2 = inverse_mapping[ref] | |||||
| assert ( | |||||
| ent == ent_2 | |||||
| ), f'multiple entities ({ent_2}, {ent}) exist for reference {ref}.' | |||||
| inverse_mapping[ref] = ent | |||||
| return inverse_mapping | |||||
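| # A minimal sketch with hypothetical entity/chain names: _inverse_map turns | |||||
| # the entity -> chains mapping loaded from *_multi_label.json into a | |||||
| # chain -> entity lookup. | |||||
| def _inverse_map_demo(): | |||||
|     multi_label = {'seqA': ['1abc_A', '1abc_B'], 'seqB': ['2xyz_A']} | |||||
|     inv = UnifoldDataset._inverse_map(multi_label) | |||||
|     assert inv == {'1abc_A': 'seqA', '1abc_B': 'seqA', '2xyz_A': 'seqB'} | |||||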
| class UnifoldMultimerDataset(UnifoldDataset): | |||||
| def __init__( | |||||
| self, | |||||
| args: mlc.ConfigDict, | |||||
| seed: int, | |||||
| config: mlc.ConfigDict, | |||||
| data_path: str, | |||||
| mode: str = 'train', | |||||
| max_step: Optional[int] = None, | |||||
| disable_sd: bool = False, | |||||
| json_prefix: str = '', | |||||
| **kwargs, | |||||
| ): | |||||
| super().__init__(args, seed, config, data_path, mode, max_step, | |||||
| disable_sd, json_prefix) | |||||
| self.data_path = data_path | |||||
| self.pdb_assembly = json.load( | |||||
| open( | |||||
| os.path.join(self.data_path, | |||||
| json_prefix + 'pdb_assembly.json'))) | |||||
| self.pdb_chains = self.get_chains(self.inverse_multi_label) | |||||
| self.monomer_feature_path = os.path.join(self.data_path, | |||||
| 'pdb_features') | |||||
| self.uniprot_msa_path = os.path.join(self.data_path, 'pdb_uniprots') | |||||
| self.label_path = os.path.join(self.data_path, 'pdb_labels') | |||||
| self.max_chains = args.max_chains | |||||
| if self.mode == 'train': | |||||
| self.pdb_chains, self.sample_weight = self.filter_pdb_by_max_chains( | |||||
| self.pdb_chains, self.pdb_assembly, self.sample_weight, | |||||
| self.max_chains) | |||||
| self.num_chain, self.chain_keys, self.sample_prob = self.cal_sample_weight( | |||||
| self.sample_weight) | |||||
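| # __getitem__ assembles all chains of the sampled PDB (or of its assembly record, | |||||
| # including symmetry operations) into a single multimer sample. | |||||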
| def __getitem__(self, idx): | |||||
| seq_id, label_id, is_distillation = self.sample_chain(idx) | |||||
| if is_distillation: | |||||
| label_ids = [label_id] | |||||
| sequence_ids = [seq_id] | |||||
| monomer_feature_path, uniprot_msa_path, label_path = ( | |||||
| self.sd_feature_path, | |||||
| None, | |||||
| self.sd_label_path, | |||||
| ) | |||||
| symmetry_operations = None | |||||
| else: | |||||
| pdb_id = self.get_pdb_name(label_id) | |||||
| if pdb_id in self.pdb_assembly and self.mode == 'train': | |||||
| label_ids = [ | |||||
| pdb_id + '_' + id | |||||
| for id in self.pdb_assembly[pdb_id]['chains'] | |||||
| ] | |||||
| symmetry_operations = [ | |||||
| t for t in self.pdb_assembly[pdb_id]['opers'] | |||||
| ] | |||||
| else: | |||||
| label_ids = self.pdb_chains[pdb_id] | |||||
| symmetry_operations = None | |||||
| sequence_ids = [ | |||||
| self.inverse_multi_label[chain_id] for chain_id in label_ids | |||||
| ] | |||||
| monomer_feature_path, uniprot_msa_path, label_path = ( | |||||
| self.monomer_feature_path, | |||||
| self.uniprot_msa_path, | |||||
| self.label_path, | |||||
| ) | |||||
| return load_and_process( | |||||
| self.config, | |||||
| self.mode, | |||||
| self.seed, | |||||
| batch_idx=(idx // self.batch_size), | |||||
| data_idx=idx, | |||||
| is_distillation=is_distillation, | |||||
| sequence_ids=sequence_ids, | |||||
| monomer_feature_dir=monomer_feature_path, | |||||
| uniprot_msa_dir=uniprot_msa_path, | |||||
| label_ids=label_ids, | |||||
| label_dir=label_path, | |||||
| symmetry_operations=symmetry_operations, | |||||
| is_monomer=False, | |||||
| ) | |||||
| @staticmethod | |||||
| def collater(samples): | |||||
| # the first dim is recycling; batch size (bsz) is at the 2nd dim | |||||
| if len(samples) <= 0: # handle an empty batch | |||||
| return None | |||||
| feats = [s[0] for s in samples] | |||||
| labs = [s[1] for s in samples if s[1] is not None] | |||||
| try: | |||||
| feats = data_utils.collate_dict(feats, dim=1) | |||||
| except BaseException: | |||||
| raise ValueError('cannot collate features', feats) | |||||
| if not labs: | |||||
| labs = None | |||||
| return feats, labs | |||||
| @staticmethod | |||||
| def get_pdb_name(chain): | |||||
| return chain.split('_')[0] | |||||
| @staticmethod | |||||
| def get_chains(canon_chain_map): | |||||
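| # Group chain ids by PDB id, e.g. ['1abc_A', '1abc_B', '2xyz_A'] -> | |||||
| # {'1abc': ['1abc_A', '1abc_B'], '2xyz': ['2xyz_A']}. | |||||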
| pdb_chains = {} | |||||
| for chain in canon_chain_map: | |||||
| pdb = UnifoldMultimerDataset.get_pdb_name(chain) | |||||
| if pdb not in pdb_chains: | |||||
| pdb_chains[pdb] = [] | |||||
| pdb_chains[pdb].append(chain) | |||||
| return pdb_chains | |||||
| @staticmethod | |||||
| def filter_pdb_by_max_chains(pdb_chains, pdb_assembly, sample_weight, | |||||
| max_chains): | |||||
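| # Keep PDBs whose assembly has at most `max_chains` chains; PDBs without an | |||||
| # assembly record are kept only if they contain a single chain. | |||||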
| new_pdb_chains = {} | |||||
| for chain in pdb_chains: | |||||
| if chain in pdb_assembly: | |||||
| size = len(pdb_assembly[chain]['chains']) | |||||
| if size <= max_chains: | |||||
| new_pdb_chains[chain] = pdb_chains[chain] | |||||
| else: | |||||
| size = len(pdb_chains[chain]) | |||||
| if size == 1: | |||||
| new_pdb_chains[chain] = pdb_chains[chain] | |||||
| new_sample_weight = { | |||||
| k: sample_weight[k] | |||||
| for k in sample_weight | |||||
| if UnifoldMultimerDataset.get_pdb_name(k) in new_pdb_chains | |||||
| } | |||||
| logger.info( | |||||
| f'filtered out {len(pdb_chains) - len(new_pdb_chains)} / {len(pdb_chains)} PDBs ' | |||||
| f'({len(sample_weight) - len(new_sample_weight)} / {len(sample_weight)} chains) ' | |||||
| f'by max_chains {max_chains}') | |||||
| return new_pdb_chains, new_sample_weight | |||||
| @@ -0,0 +1,75 @@ | |||||
| import argparse | |||||
| import os | |||||
| from typing import Any | |||||
| import torch | |||||
| from modelscope.metainfo import Models | |||||
| from modelscope.models import TorchModel | |||||
| from modelscope.models.builder import MODELS | |||||
| from modelscope.utils.constant import ModelFile, Tasks | |||||
| from .config import model_config | |||||
| from .modules.alphafold import AlphaFold | |||||
| __all__ = ['UnifoldForProteinStructrue'] | |||||
| @MODELS.register_module(Tasks.protein_structure, module_name=Models.unifold) | |||||
| class UnifoldForProteinStructrue(TorchModel): | |||||
| @staticmethod | |||||
| def add_args(parser): | |||||
| """Add model-specific arguments to the parser.""" | |||||
| parser.add_argument( | |||||
| '--model-name', | |||||
| help='choose the model config', | |||||
| ) | |||||
| def __init__(self, **kwargs): | |||||
| super().__init__() | |||||
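| # Round-trip the incoming kwargs through argparse so downstream Uni-Fold code | |||||
| # can consume them as an args Namespace (args.model_name, ...). | |||||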
| parser = argparse.ArgumentParser() | |||||
| parse_comm = [] | |||||
| for key in kwargs: | |||||
| parser.add_argument(f'--{key}') | |||||
| parse_comm.append(f'--{key}') | |||||
| parse_comm.append(kwargs[key]) | |||||
| args = parser.parse_args(parse_comm) | |||||
| base_architecture(args) | |||||
| self.args = args | |||||
| config = model_config( | |||||
| self.args.model_name, | |||||
| train=True, | |||||
| ) | |||||
| self.model = AlphaFold(config) | |||||
| self.config = config | |||||
| # load model state dict | |||||
| param_path = os.path.join(kwargs['model_dir'], | |||||
| ModelFile.TORCH_MODEL_BIN_FILE) | |||||
| state_dict = torch.load(param_path)['ema']['params'] | |||||
| state_dict = { | |||||
| '.'.join(k.split('.')[1:]): v | |||||
| for k, v in state_dict.items() | |||||
| } | |||||
| self.model.load_state_dict(state_dict) | |||||
| def half(self): | |||||
| self.model = self.model.half() | |||||
| return self | |||||
| def bfloat16(self): | |||||
| self.model = self.model.bfloat16() | |||||
| return self | |||||
| @classmethod | |||||
| def build_model(cls, args, task): | |||||
| """Build a new model instance.""" | |||||
| return cls(args) | |||||
| def forward(self, batch, **kwargs): | |||||
| outputs = self.model.forward(batch) | |||||
| return outputs, self.config.loss | |||||
| def base_architecture(args): | |||||
| args.model_name = getattr(args, 'model_name', 'model_2') | |||||
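| # Rough usage sketch (illustrative values, not part of this file): the pipeline is | |||||
| # expected to build the model roughly as | |||||
| #   model = UnifoldForProteinStructrue(model_dir='/path/to/ckpt', model_name='model_2') | |||||
| # which restores the EMA weights from ModelFile.TORCH_MODEL_BIN_FILE under model_dir. | |||||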
| @@ -0,0 +1,450 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.utils import tensor_tree_map | |||||
| from ..data import residue_constants | |||||
| from .attentions import gen_msa_attn_mask, gen_tri_attn_mask | |||||
| from .auxillary_heads import AuxiliaryHeads | |||||
| from .common import residual | |||||
| from .embedders import (ExtraMSAEmbedder, InputEmbedder, RecyclingEmbedder, | |||||
| TemplateAngleEmbedder, TemplatePairEmbedder) | |||||
| from .evoformer import EvoformerStack, ExtraMSAStack | |||||
| from .featurization import (atom14_to_atom37, build_extra_msa_feat, | |||||
| build_template_angle_feat, | |||||
| build_template_pair_feat, | |||||
| build_template_pair_feat_v2, pseudo_beta_fn) | |||||
| from .structure_module import StructureModule | |||||
| from .template import (TemplatePairStack, TemplatePointwiseAttention, | |||||
| TemplateProjection) | |||||
| class AlphaFold(nn.Module): | |||||
| def __init__(self, config): | |||||
| super(AlphaFold, self).__init__() | |||||
| self.globals = config.globals | |||||
| config = config.model | |||||
| template_config = config.template | |||||
| extra_msa_config = config.extra_msa | |||||
| self.input_embedder = InputEmbedder( | |||||
| **config['input_embedder'], | |||||
| use_chain_relative=config.is_multimer, | |||||
| ) | |||||
| self.recycling_embedder = RecyclingEmbedder( | |||||
| **config['recycling_embedder'], ) | |||||
| if config.template.enabled: | |||||
| self.template_angle_embedder = TemplateAngleEmbedder( | |||||
| **template_config['template_angle_embedder'], ) | |||||
| self.template_pair_embedder = TemplatePairEmbedder( | |||||
| **template_config['template_pair_embedder'], ) | |||||
| self.template_pair_stack = TemplatePairStack( | |||||
| **template_config['template_pair_stack'], ) | |||||
| else: | |||||
| self.template_pair_stack = None | |||||
| self.enable_template_pointwise_attention = template_config[ | |||||
| 'template_pointwise_attention'].enabled | |||||
| if self.enable_template_pointwise_attention: | |||||
| self.template_pointwise_att = TemplatePointwiseAttention( | |||||
| **template_config['template_pointwise_attention'], ) | |||||
| else: | |||||
| self.template_proj = TemplateProjection( | |||||
| **template_config['template_pointwise_attention'], ) | |||||
| self.extra_msa_embedder = ExtraMSAEmbedder( | |||||
| **extra_msa_config['extra_msa_embedder'], ) | |||||
| self.extra_msa_stack = ExtraMSAStack( | |||||
| **extra_msa_config['extra_msa_stack'], ) | |||||
| self.evoformer = EvoformerStack(**config['evoformer_stack'], ) | |||||
| self.structure_module = StructureModule(**config['structure_module'], ) | |||||
| self.aux_heads = AuxiliaryHeads(config['heads'], ) | |||||
| self.config = config | |||||
| self.dtype = torch.float | |||||
| self.inf = self.globals.inf | |||||
| if self.globals.alphafold_original_mode: | |||||
| self.alphafold_original_mode() | |||||
| def __make_input_float__(self): | |||||
| self.input_embedder = self.input_embedder.float() | |||||
| self.recycling_embedder = self.recycling_embedder.float() | |||||
| def half(self): | |||||
| super().half() | |||||
| if (not getattr(self, 'inference', False)): | |||||
| self.__make_input_float__() | |||||
| self.dtype = torch.half | |||||
| return self | |||||
| def bfloat16(self): | |||||
| super().bfloat16() | |||||
| if (not getattr(self, 'inference', False)): | |||||
| self.__make_input_float__() | |||||
| self.dtype = torch.bfloat16 | |||||
| return self | |||||
| def alphafold_original_mode(self): | |||||
| def set_alphafold_original_mode(module): | |||||
| if hasattr(module, 'apply_alphafold_original_mode'): | |||||
| module.apply_alphafold_original_mode() | |||||
| if hasattr(module, 'act'): | |||||
| module.act = nn.ReLU() | |||||
| self.apply(set_alphafold_original_mode) | |||||
| def inference_mode(self): | |||||
| def set_inference_mode(module): | |||||
| setattr(module, 'inference', True) | |||||
| self.apply(set_inference_mode) | |||||
| def __convert_input_dtype__(self, batch): | |||||
| for key in batch: | |||||
| # only cast mask features (keys containing 'mask') to the model dtype | |||||
| if batch[key].dtype != self.dtype and 'mask' in key: | |||||
| batch[key] = batch[key].type(self.dtype) | |||||
| return batch | |||||
| def embed_templates_pair_core(self, batch, z, pair_mask, | |||||
| tri_start_attn_mask, tri_end_attn_mask, | |||||
| templ_dim, multichain_mask_2d): | |||||
| if self.config.template.template_pair_embedder.v2_feature: | |||||
| t = build_template_pair_feat_v2( | |||||
| batch, | |||||
| inf=self.config.template.inf, | |||||
| eps=self.config.template.eps, | |||||
| multichain_mask_2d=multichain_mask_2d, | |||||
| **self.config.template.distogram, | |||||
| ) | |||||
| num_template = t[0].shape[-4] | |||||
| single_templates = [ | |||||
| self.template_pair_embedder([x[..., ti, :, :, :] | |||||
| for x in t], z) | |||||
| for ti in range(num_template) | |||||
| ] | |||||
| else: | |||||
| t = build_template_pair_feat( | |||||
| batch, | |||||
| inf=self.config.template.inf, | |||||
| eps=self.config.template.eps, | |||||
| **self.config.template.distogram, | |||||
| ) | |||||
| single_templates = [ | |||||
| self.template_pair_embedder(x, z) | |||||
| for x in torch.unbind(t, dim=templ_dim) | |||||
| ] | |||||
| t = self.template_pair_stack( | |||||
| single_templates, | |||||
| pair_mask, | |||||
| tri_start_attn_mask=tri_start_attn_mask, | |||||
| tri_end_attn_mask=tri_end_attn_mask, | |||||
| templ_dim=templ_dim, | |||||
| chunk_size=self.globals.chunk_size, | |||||
| block_size=self.globals.block_size, | |||||
| return_mean=not self.enable_template_pointwise_attention, | |||||
| ) | |||||
| return t | |||||
| def embed_templates_pair(self, batch, z, pair_mask, tri_start_attn_mask, | |||||
| tri_end_attn_mask, templ_dim): | |||||
| if self.config.template.template_pair_embedder.v2_feature and 'asym_id' in batch: | |||||
| multichain_mask_2d = ( | |||||
| batch['asym_id'][..., :, None] == batch['asym_id'][..., | |||||
| None, :]) | |||||
| multichain_mask_2d = multichain_mask_2d.unsqueeze(0) | |||||
| else: | |||||
| multichain_mask_2d = None | |||||
| if self.training or self.enable_template_pointwise_attention: | |||||
| t = self.embed_templates_pair_core(batch, z, pair_mask, | |||||
| tri_start_attn_mask, | |||||
| tri_end_attn_mask, templ_dim, | |||||
| multichain_mask_2d) | |||||
| if self.enable_template_pointwise_attention: | |||||
| t = self.template_pointwise_att( | |||||
| t, | |||||
| z, | |||||
| template_mask=batch['template_mask'], | |||||
| chunk_size=self.globals.chunk_size, | |||||
| ) | |||||
| t_mask = torch.sum( | |||||
| batch['template_mask'], dim=-1, keepdims=True) > 0 | |||||
| t_mask = t_mask[..., None, None].type(t.dtype) | |||||
| t *= t_mask | |||||
| else: | |||||
| t = self.template_proj(t, z) | |||||
| else: | |||||
| template_aatype_shape = batch['template_aatype'].shape | |||||
| # template_aatype is either [n_template, n_res] or [1, n_template, n_res] | |||||
| batch_templ_dim = 1 if len(template_aatype_shape) == 3 else 0 | |||||
| n_templ = batch['template_aatype'].shape[batch_templ_dim] | |||||
| if n_templ <= 0: | |||||
| t = None | |||||
| else: | |||||
| template_batch = { | |||||
| k: v | |||||
| for k, v in batch.items() if k.startswith('template_') | |||||
| } | |||||
| def embed_one_template(i): | |||||
| def slice_template_tensor(t): | |||||
| s = [slice(None) for _ in t.shape] | |||||
| s[batch_templ_dim] = slice(i, i + 1) | |||||
| return t[s] | |||||
| template_feats = tensor_tree_map( | |||||
| slice_template_tensor, | |||||
| template_batch, | |||||
| ) | |||||
| t = self.embed_templates_pair_core( | |||||
| template_feats, z, pair_mask, tri_start_attn_mask, | |||||
| tri_end_attn_mask, templ_dim, multichain_mask_2d) | |||||
| return t | |||||
| t = embed_one_template(0) | |||||
| # iterate templates one by one | |||||
| for i in range(1, n_templ): | |||||
| t += embed_one_template(i) | |||||
| t /= n_templ | |||||
| t = self.template_proj(t, z) | |||||
| return t | |||||
| def embed_templates_angle(self, batch): | |||||
| template_angle_feat, template_angle_mask = build_template_angle_feat( | |||||
| batch, | |||||
| v2_feature=self.config.template.template_pair_embedder.v2_feature) | |||||
| t = self.template_angle_embedder(template_angle_feat) | |||||
| return t, template_angle_mask | |||||
| def iteration_evoformer(self, feats, m_1_prev, z_prev, x_prev): | |||||
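| # One Evoformer iteration: embed target/MSA features, add recycling embeddings, | |||||
| # optionally inject template and extra-MSA information, then run the Evoformer stack. | |||||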
| batch_dims = feats['target_feat'].shape[:-2] | |||||
| n = feats['target_feat'].shape[-2] | |||||
| seq_mask = feats['seq_mask'] | |||||
| pair_mask = seq_mask[..., None] * seq_mask[..., None, :] | |||||
| msa_mask = feats['msa_mask'] | |||||
| m, z = self.input_embedder( | |||||
| feats['target_feat'], | |||||
| feats['msa_feat'], | |||||
| ) | |||||
| if m_1_prev is None: | |||||
| m_1_prev = m.new_zeros( | |||||
| (*batch_dims, n, self.config.input_embedder.d_msa), | |||||
| requires_grad=False, | |||||
| ) | |||||
| if z_prev is None: | |||||
| z_prev = z.new_zeros( | |||||
| (*batch_dims, n, n, self.config.input_embedder.d_pair), | |||||
| requires_grad=False, | |||||
| ) | |||||
| if x_prev is None: | |||||
| x_prev = z.new_zeros( | |||||
| (*batch_dims, n, residue_constants.atom_type_num, 3), | |||||
| requires_grad=False, | |||||
| ) | |||||
| x_prev = pseudo_beta_fn(feats['aatype'], x_prev, None) | |||||
| z += self.recycling_embedder.recyle_pos(x_prev) | |||||
| m_1_prev_emb, z_prev_emb = self.recycling_embedder( | |||||
| m_1_prev, | |||||
| z_prev, | |||||
| ) | |||||
| m[..., 0, :, :] += m_1_prev_emb | |||||
| z += z_prev_emb | |||||
| z += self.input_embedder.relpos_emb( | |||||
| feats['residue_index'].long(), | |||||
| feats.get('sym_id', None), | |||||
| feats.get('asym_id', None), | |||||
| feats.get('entity_id', None), | |||||
| feats.get('num_sym', None), | |||||
| ) | |||||
| m = m.type(self.dtype) | |||||
| z = z.type(self.dtype) | |||||
| tri_start_attn_mask, tri_end_attn_mask = gen_tri_attn_mask( | |||||
| pair_mask, self.inf) | |||||
| if self.config.template.enabled: | |||||
| template_mask = feats['template_mask'] | |||||
| if torch.any(template_mask): | |||||
| z = residual( | |||||
| z, | |||||
| self.embed_templates_pair( | |||||
| feats, | |||||
| z, | |||||
| pair_mask, | |||||
| tri_start_attn_mask, | |||||
| tri_end_attn_mask, | |||||
| templ_dim=-4, | |||||
| ), | |||||
| self.training, | |||||
| ) | |||||
| if self.config.extra_msa.enabled: | |||||
| a = self.extra_msa_embedder(build_extra_msa_feat(feats)) | |||||
| extra_msa_row_mask = gen_msa_attn_mask( | |||||
| feats['extra_msa_mask'], | |||||
| inf=self.inf, | |||||
| gen_col_mask=False, | |||||
| ) | |||||
| z = self.extra_msa_stack( | |||||
| a, | |||||
| z, | |||||
| msa_mask=feats['extra_msa_mask'], | |||||
| chunk_size=self.globals.chunk_size, | |||||
| block_size=self.globals.block_size, | |||||
| pair_mask=pair_mask, | |||||
| msa_row_attn_mask=extra_msa_row_mask, | |||||
| msa_col_attn_mask=None, | |||||
| tri_start_attn_mask=tri_start_attn_mask, | |||||
| tri_end_attn_mask=tri_end_attn_mask, | |||||
| ) | |||||
| if self.config.template.embed_angles: | |||||
| template_1d_feat, template_1d_mask = self.embed_templates_angle( | |||||
| feats) | |||||
| m = torch.cat([m, template_1d_feat], dim=-3) | |||||
| msa_mask = torch.cat([feats['msa_mask'], template_1d_mask], dim=-2) | |||||
| msa_row_mask, msa_col_mask = gen_msa_attn_mask( | |||||
| msa_mask, | |||||
| inf=self.inf, | |||||
| ) | |||||
| m, z, s = self.evoformer( | |||||
| m, | |||||
| z, | |||||
| msa_mask=msa_mask, | |||||
| pair_mask=pair_mask, | |||||
| msa_row_attn_mask=msa_row_mask, | |||||
| msa_col_attn_mask=msa_col_mask, | |||||
| tri_start_attn_mask=tri_start_attn_mask, | |||||
| tri_end_attn_mask=tri_end_attn_mask, | |||||
| chunk_size=self.globals.chunk_size, | |||||
| block_size=self.globals.block_size, | |||||
| ) | |||||
| return m, z, s, msa_mask, m_1_prev_emb, z_prev_emb | |||||
| def iteration_evoformer_structure_module(self, | |||||
| batch, | |||||
| m_1_prev, | |||||
| z_prev, | |||||
| x_prev, | |||||
| cycle_no, | |||||
| num_recycling, | |||||
| num_ensembles=1): | |||||
| z, s = 0, 0 | |||||
| n_seq = batch['msa_feat'].shape[-3] | |||||
| assert num_ensembles >= 1 | |||||
| for ensemble_no in range(num_ensembles): | |||||
| idx = cycle_no * num_ensembles + ensemble_no | |||||
| # fetch_cur_batch = lambda t: t[min(t.shape[0] - 1, idx), ...] | |||||
| def fetch_cur_batch(t): | |||||
| return t[min(t.shape[0] - 1, idx), ...] | |||||
| feats = tensor_tree_map(fetch_cur_batch, batch) | |||||
| m, z0, s0, msa_mask, m_1_prev_emb, z_prev_emb = self.iteration_evoformer( | |||||
| feats, m_1_prev, z_prev, x_prev) | |||||
| z += z0 | |||||
| s += s0 | |||||
| del z0, s0 | |||||
| if num_ensembles > 1: | |||||
| z /= float(num_ensembles) | |||||
| s /= float(num_ensembles) | |||||
| outputs = {} | |||||
| outputs['msa'] = m[..., :n_seq, :, :] | |||||
| outputs['pair'] = z | |||||
| outputs['single'] = s | |||||
| # extra outputs for the norm loss, computed on the final recycling iteration (training only) | |||||
| if (not getattr(self, 'inference', | |||||
| False)) and num_recycling == (cycle_no + 1): | |||||
| delta_msa = m | |||||
| delta_msa[..., | |||||
| 0, :, :] = delta_msa[..., | |||||
| 0, :, :] - m_1_prev_emb.detach() | |||||
| delta_pair = z - z_prev_emb.detach() | |||||
| outputs['delta_msa'] = delta_msa | |||||
| outputs['delta_pair'] = delta_pair | |||||
| outputs['msa_norm_mask'] = msa_mask | |||||
| outputs['sm'] = self.structure_module( | |||||
| s, | |||||
| z, | |||||
| feats['aatype'], | |||||
| mask=feats['seq_mask'], | |||||
| ) | |||||
| outputs['final_atom_positions'] = atom14_to_atom37( | |||||
| outputs['sm']['positions'], feats) | |||||
| outputs['final_atom_mask'] = feats['atom37_atom_exists'] | |||||
| outputs['pred_frame_tensor'] = outputs['sm']['frames'][-1] | |||||
| # use float32 for numerical stability | |||||
| if (not getattr(self, 'inference', False)): | |||||
| m_1_prev = m[..., 0, :, :].float() | |||||
| z_prev = z.float() | |||||
| x_prev = outputs['final_atom_positions'].float() | |||||
| else: | |||||
| m_1_prev = m[..., 0, :, :] | |||||
| z_prev = z | |||||
| x_prev = outputs['final_atom_positions'] | |||||
| return outputs, m_1_prev, z_prev, x_prev | |||||
| def forward(self, batch): | |||||
| m_1_prev = batch.get('m_1_prev', None) | |||||
| z_prev = batch.get('z_prev', None) | |||||
| x_prev = batch.get('x_prev', None) | |||||
| is_grad_enabled = torch.is_grad_enabled() | |||||
| num_iters = int(batch['num_recycling_iters']) + 1 | |||||
| num_ensembles = int(batch['msa_mask'].shape[0]) // num_iters | |||||
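| # The leading dim of the batch packs num_iters * num_ensembles copies of the features | |||||
| # (see the dataset collater); each recycling iteration consumes num_ensembles of them, | |||||
| # while features with a leading dim of 1 are reused across iterations. | |||||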
| if self.training: | |||||
| # don't use ensemble during training | |||||
| assert num_ensembles == 1 | |||||
| # convert dtypes in batch | |||||
| batch = self.__convert_input_dtype__(batch) | |||||
| for cycle_no in range(num_iters): | |||||
| is_final_iter = cycle_no == (num_iters - 1) | |||||
| with torch.set_grad_enabled(is_grad_enabled and is_final_iter): | |||||
| ( | |||||
| outputs, | |||||
| m_1_prev, | |||||
| z_prev, | |||||
| x_prev, | |||||
| ) = self.iteration_evoformer_structure_module( | |||||
| batch, | |||||
| m_1_prev, | |||||
| z_prev, | |||||
| x_prev, | |||||
| cycle_no=cycle_no, | |||||
| num_recycling=num_iters, | |||||
| num_ensembles=num_ensembles, | |||||
| ) | |||||
| if not is_final_iter: | |||||
| del outputs | |||||
| if 'asym_id' in batch: | |||||
| outputs['asym_id'] = batch['asym_id'][0, ...] | |||||
| outputs.update(self.aux_heads(outputs)) | |||||
| return outputs | |||||
| @@ -0,0 +1,430 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from functools import partialmethod | |||||
| from typing import List, Optional | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.modules import LayerNorm, softmax_dropout | |||||
| from unicore.utils import permute_final_dims | |||||
| from .common import Linear, chunk_layer | |||||
| def gen_attn_mask(mask, neg_inf): | |||||
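| # Convert a binary mask into an additive attention bias: 0 where attention is | |||||
| # allowed, neg_inf where it is masked out. | |||||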
| assert neg_inf < -1e4 | |||||
| attn_mask = torch.zeros_like(mask) | |||||
| attn_mask[mask == 0] = neg_inf | |||||
| return attn_mask | |||||
| class Attention(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| q_dim: int, | |||||
| k_dim: int, | |||||
| v_dim: int, | |||||
| head_dim: int, | |||||
| num_heads: int, | |||||
| gating: bool = True, | |||||
| ): | |||||
| super(Attention, self).__init__() | |||||
| self.num_heads = num_heads | |||||
| total_dim = head_dim * self.num_heads | |||||
| self.gating = gating | |||||
| self.linear_q = Linear(q_dim, total_dim, bias=False, init='glorot') | |||||
| self.linear_k = Linear(k_dim, total_dim, bias=False, init='glorot') | |||||
| self.linear_v = Linear(v_dim, total_dim, bias=False, init='glorot') | |||||
| self.linear_o = Linear(total_dim, q_dim, init='final') | |||||
| self.linear_g = None | |||||
| if self.gating: | |||||
| self.linear_g = Linear(q_dim, total_dim, init='gating') | |||||
| # precompute the 1/sqrt(head_dim) scaling factor | |||||
| self.norm = head_dim**-0.5 | |||||
| def forward( | |||||
| self, | |||||
| q: torch.Tensor, | |||||
| k: torch.Tensor, | |||||
| v: torch.Tensor, | |||||
| mask: torch.Tensor = None, | |||||
| bias: Optional[torch.Tensor] = None, | |||||
| ) -> torch.Tensor: | |||||
| g = None | |||||
| if self.linear_g is not None: | |||||
| # gating, use raw query input | |||||
| g = self.linear_g(q) | |||||
| q = self.linear_q(q) | |||||
| q *= self.norm | |||||
| k = self.linear_k(k) | |||||
| v = self.linear_v(v) | |||||
| q = q.view(q.shape[:-1] + (self.num_heads, -1)).transpose( | |||||
| -2, -3).contiguous() | |||||
| k = k.view(k.shape[:-1] + (self.num_heads, -1)).transpose( | |||||
| -2, -3).contiguous() | |||||
| v = v.view(v.shape[:-1] + (self.num_heads, -1)).transpose(-2, -3) | |||||
| attn = torch.matmul(q, k.transpose(-1, -2)) | |||||
| del q, k | |||||
| attn = softmax_dropout(attn, 0, self.training, mask=mask, bias=bias) | |||||
| o = torch.matmul(attn, v) | |||||
| del attn, v | |||||
| o = o.transpose(-2, -3).contiguous() | |||||
| o = o.view(*o.shape[:-2], -1) | |||||
| if g is not None: | |||||
| o = torch.sigmoid(g) * o | |||||
| # merge heads | |||||
| o = nn.functional.linear(o, self.linear_o.weight) | |||||
| return o | |||||
| def get_output_bias(self): | |||||
| return self.linear_o.bias | |||||
| class GlobalAttention(nn.Module): | |||||
| def __init__(self, input_dim, head_dim, num_heads, inf, eps): | |||||
| super(GlobalAttention, self).__init__() | |||||
| self.num_heads = num_heads | |||||
| self.inf = inf | |||||
| self.eps = eps | |||||
| self.linear_q = Linear( | |||||
| input_dim, head_dim * num_heads, bias=False, init='glorot') | |||||
| self.linear_k = Linear(input_dim, head_dim, bias=False, init='glorot') | |||||
| self.linear_v = Linear(input_dim, head_dim, bias=False, init='glorot') | |||||
| self.linear_g = Linear(input_dim, head_dim * num_heads, init='gating') | |||||
| self.linear_o = Linear(head_dim * num_heads, input_dim, init='final') | |||||
| self.sigmoid = nn.Sigmoid() | |||||
| # precompute the 1/sqrt(head_dim) scaling factor | |||||
| self.norm = head_dim**-0.5 | |||||
| def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: | |||||
| # gating | |||||
| g = self.sigmoid(self.linear_g(x)) | |||||
| k = self.linear_k(x) | |||||
| v = self.linear_v(x) | |||||
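| # global attention: the single query is the mask-weighted mean of x over the sequence dim | |||||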
| q = torch.sum( | |||||
| x * mask.unsqueeze(-1), dim=-2) / ( | |||||
| torch.sum(mask, dim=-1, keepdims=True) + self.eps) | |||||
| q = self.linear_q(q) | |||||
| q *= self.norm | |||||
| q = q.view(q.shape[:-1] + (self.num_heads, -1)) | |||||
| attn = torch.matmul(q, k.transpose(-1, -2)) | |||||
| del q, k | |||||
| attn_mask = gen_attn_mask(mask, -self.inf)[..., :, None, :] | |||||
| attn = softmax_dropout(attn, 0, self.training, mask=attn_mask) | |||||
| o = torch.matmul( | |||||
| attn, | |||||
| v, | |||||
| ) | |||||
| del attn, v | |||||
| g = g.view(g.shape[:-1] + (self.num_heads, -1)) | |||||
| o = o.unsqueeze(-3) * g | |||||
| del g | |||||
| # merge heads | |||||
| o = o.reshape(o.shape[:-2] + (-1, )) | |||||
| return self.linear_o(o) | |||||
| def gen_msa_attn_mask(mask, inf, gen_col_mask=True): | |||||
| row_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] | |||||
| if gen_col_mask: | |||||
| col_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, | |||||
| None, :] | |||||
| return row_mask, col_mask | |||||
| else: | |||||
| return row_mask | |||||
| class MSAAttention(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_in, | |||||
| d_hid, | |||||
| num_heads, | |||||
| pair_bias=False, | |||||
| d_pair=None, | |||||
| ): | |||||
| super(MSAAttention, self).__init__() | |||||
| self.pair_bias = pair_bias | |||||
| self.layer_norm_m = LayerNorm(d_in) | |||||
| self.layer_norm_z = None | |||||
| self.linear_z = None | |||||
| if self.pair_bias: | |||||
| self.layer_norm_z = LayerNorm(d_pair) | |||||
| self.linear_z = Linear( | |||||
| d_pair, num_heads, bias=False, init='normal') | |||||
| self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) | |||||
| @torch.jit.ignore | |||||
| def _chunk( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| bias: Optional[torch.Tensor] = None, | |||||
| chunk_size: int = None, | |||||
| ) -> torch.Tensor: | |||||
| return chunk_layer( | |||||
| self._attn_forward, | |||||
| { | |||||
| 'm': m, | |||||
| 'mask': mask, | |||||
| 'bias': bias | |||||
| }, | |||||
| chunk_size=chunk_size, | |||||
| num_batch_dims=len(m.shape[:-2]), | |||||
| ) | |||||
| @torch.jit.ignore | |||||
| def _attn_chunk_forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| bias: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = 2560, | |||||
| ) -> torch.Tensor: | |||||
| m = self.layer_norm_m(m) | |||||
| num_chunk = (m.shape[-3] + chunk_size - 1) // chunk_size | |||||
| outputs = [] | |||||
| for i in range(num_chunk): | |||||
| chunk_start = i * chunk_size | |||||
| chunk_end = min(m.shape[-3], chunk_start + chunk_size) | |||||
| cur_m = m[..., chunk_start:chunk_end, :, :] | |||||
| cur_mask = ( | |||||
| mask[..., chunk_start:chunk_end, :, :, :] | |||||
| if mask is not None else None) | |||||
| outputs.append( | |||||
| self.mha(q=cur_m, k=cur_m, v=cur_m, mask=cur_mask, bias=bias)) | |||||
| return torch.concat(outputs, dim=-3) | |||||
| def _attn_forward(self, m, mask, bias: Optional[torch.Tensor] = None): | |||||
| m = self.layer_norm_m(m) | |||||
| return self.mha(q=m, k=m, v=m, mask=mask, bias=bias) | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| z: Optional[torch.Tensor] = None, | |||||
| attn_mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
| bias = None | |||||
| if self.pair_bias: | |||||
| z = self.layer_norm_z(z) | |||||
| bias = ( | |||||
| permute_final_dims(self.linear_z(z), | |||||
| (2, 0, 1)).unsqueeze(-4).contiguous()) | |||||
| if chunk_size is not None: | |||||
| m = self._chunk(m, attn_mask, bias, chunk_size) | |||||
| else: | |||||
| attn_chunk_size = 2560 | |||||
| if m.shape[-3] <= attn_chunk_size: | |||||
| m = self._attn_forward(m, attn_mask, bias) | |||||
| else: | |||||
| # reduce the peak memory cost in extra_msa_stack | |||||
| return self._attn_chunk_forward( | |||||
| m, attn_mask, bias, chunk_size=attn_chunk_size) | |||||
| return m | |||||
| def get_output_bias(self): | |||||
| return self.mha.get_output_bias() | |||||
| class MSARowAttentionWithPairBias(MSAAttention): | |||||
| def __init__(self, d_msa, d_pair, d_hid, num_heads): | |||||
| super(MSARowAttentionWithPairBias, self).__init__( | |||||
| d_msa, | |||||
| d_hid, | |||||
| num_heads, | |||||
| pair_bias=True, | |||||
| d_pair=d_pair, | |||||
| ) | |||||
| class MSAColumnAttention(MSAAttention): | |||||
| def __init__(self, d_msa, d_hid, num_heads): | |||||
| super(MSAColumnAttention, self).__init__( | |||||
| d_in=d_msa, | |||||
| d_hid=d_hid, | |||||
| num_heads=num_heads, | |||||
| pair_bias=False, | |||||
| d_pair=None, | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| attn_mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
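| # Column attention reuses the row-attention code by transposing the MSA, so | |||||
| # attention runs across sequences at each residue position (down an MSA column). | |||||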
| m = m.transpose(-2, -3) | |||||
| m = super().forward(m, attn_mask=attn_mask, chunk_size=chunk_size) | |||||
| m = m.transpose(-2, -3) | |||||
| return m | |||||
| class MSAColumnGlobalAttention(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_in, | |||||
| d_hid, | |||||
| num_heads, | |||||
| inf=1e9, | |||||
| eps=1e-10, | |||||
| ): | |||||
| super(MSAColumnGlobalAttention, self).__init__() | |||||
| self.layer_norm_m = LayerNorm(d_in) | |||||
| self.global_attention = GlobalAttention( | |||||
| d_in, | |||||
| d_hid, | |||||
| num_heads, | |||||
| inf=inf, | |||||
| eps=eps, | |||||
| ) | |||||
| @torch.jit.ignore | |||||
| def _chunk( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| mask: torch.Tensor, | |||||
| chunk_size: int, | |||||
| ) -> torch.Tensor: | |||||
| return chunk_layer( | |||||
| self._attn_forward, | |||||
| { | |||||
| 'm': m, | |||||
| 'mask': mask | |||||
| }, | |||||
| chunk_size=chunk_size, | |||||
| num_batch_dims=len(m.shape[:-2]), | |||||
| ) | |||||
| def _attn_forward(self, m, mask): | |||||
| m = self.layer_norm_m(m) | |||||
| return self.global_attention(m, mask=mask) | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
| m = m.transpose(-2, -3) | |||||
| mask = mask.transpose(-1, -2) | |||||
| if chunk_size is not None: | |||||
| m = self._chunk(m, mask, chunk_size) | |||||
| else: | |||||
| m = self._attn_forward(m, mask=mask) | |||||
| m = m.transpose(-2, -3) | |||||
| return m | |||||
| def gen_tri_attn_mask(mask, inf): | |||||
| start_mask = gen_attn_mask(mask, -inf)[..., :, None, None, :] | |||||
| end_mask = gen_attn_mask(mask.transpose(-1, -2), -inf)[..., :, None, | |||||
| None, :] | |||||
| return start_mask, end_mask | |||||
| class TriangleAttention(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_in, | |||||
| d_hid, | |||||
| num_heads, | |||||
| starting, | |||||
| ): | |||||
| super(TriangleAttention, self).__init__() | |||||
| self.starting = starting | |||||
| self.layer_norm = LayerNorm(d_in) | |||||
| self.linear = Linear(d_in, num_heads, bias=False, init='normal') | |||||
| self.mha = Attention(d_in, d_in, d_in, d_hid, num_heads) | |||||
| @torch.jit.ignore | |||||
| def _chunk( | |||||
| self, | |||||
| x: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| bias: Optional[torch.Tensor] = None, | |||||
| chunk_size: int = None, | |||||
| ) -> torch.Tensor: | |||||
| return chunk_layer( | |||||
| self.mha, | |||||
| { | |||||
| 'q': x, | |||||
| 'k': x, | |||||
| 'v': x, | |||||
| 'mask': mask, | |||||
| 'bias': bias | |||||
| }, | |||||
| chunk_size=chunk_size, | |||||
| num_batch_dims=len(x.shape[:-2]), | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| x: torch.Tensor, | |||||
| attn_mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
| if not self.starting: | |||||
| x = x.transpose(-2, -3) | |||||
| x = self.layer_norm(x) | |||||
| triangle_bias = ( | |||||
| permute_final_dims(self.linear(x), | |||||
| (2, 0, 1)).unsqueeze(-4).contiguous()) | |||||
| if chunk_size is not None: | |||||
| x = self._chunk(x, attn_mask, triangle_bias, chunk_size) | |||||
| else: | |||||
| x = self.mha(q=x, k=x, v=x, mask=attn_mask, bias=triangle_bias) | |||||
| if not self.starting: | |||||
| x = x.transpose(-2, -3) | |||||
| return x | |||||
| def get_output_bias(self): | |||||
| return self.mha.get_output_bias() | |||||
| class TriangleAttentionStarting(TriangleAttention): | |||||
| __init__ = partialmethod(TriangleAttention.__init__, starting=True) | |||||
| class TriangleAttentionEnding(TriangleAttention): | |||||
| __init__ = partialmethod(TriangleAttention.__init__, starting=False) | |||||
| @@ -0,0 +1,171 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from typing import Dict | |||||
| import torch.nn as nn | |||||
| from unicore.modules import LayerNorm | |||||
| from .common import Linear | |||||
| from .confidence import (predicted_aligned_error, predicted_lddt, | |||||
| predicted_tm_score) | |||||
| class AuxiliaryHeads(nn.Module): | |||||
| def __init__(self, config): | |||||
| super(AuxiliaryHeads, self).__init__() | |||||
| self.plddt = PredictedLDDTHead(**config['plddt'], ) | |||||
| self.distogram = DistogramHead(**config['distogram'], ) | |||||
| self.masked_msa = MaskedMSAHead(**config['masked_msa'], ) | |||||
| if config.experimentally_resolved.enabled: | |||||
| self.experimentally_resolved = ExperimentallyResolvedHead( | |||||
| **config['experimentally_resolved'], ) | |||||
| if config.pae.enabled: | |||||
| self.pae = PredictedAlignedErrorHead(**config.pae, ) | |||||
| self.config = config | |||||
| def forward(self, outputs): | |||||
| aux_out = {} | |||||
| plddt_logits = self.plddt(outputs['sm']['single']) | |||||
| aux_out['plddt_logits'] = plddt_logits | |||||
| aux_out['plddt'] = predicted_lddt(plddt_logits.detach()) | |||||
| distogram_logits = self.distogram(outputs['pair']) | |||||
| aux_out['distogram_logits'] = distogram_logits | |||||
| masked_msa_logits = self.masked_msa(outputs['msa']) | |||||
| aux_out['masked_msa_logits'] = masked_msa_logits | |||||
| if self.config.experimentally_resolved.enabled: | |||||
| exp_res_logits = self.experimentally_resolved(outputs['single']) | |||||
| aux_out['experimentally_resolved_logits'] = exp_res_logits | |||||
| if self.config.pae.enabled: | |||||
| pae_logits = self.pae(outputs['pair']) | |||||
| aux_out['pae_logits'] = pae_logits | |||||
| pae_logits = pae_logits.detach() | |||||
| aux_out.update( | |||||
| predicted_aligned_error( | |||||
| pae_logits, | |||||
| **self.config.pae, | |||||
| )) | |||||
| aux_out['ptm'] = predicted_tm_score( | |||||
| pae_logits, interface=False, **self.config.pae) | |||||
| iptm_weight = self.config.pae.get('iptm_weight', 0.0) | |||||
| if iptm_weight > 0.0: | |||||
| aux_out['iptm'] = predicted_tm_score( | |||||
| pae_logits, | |||||
| interface=True, | |||||
| asym_id=outputs['asym_id'], | |||||
| **self.config.pae, | |||||
| ) | |||||
| aux_out['iptm+ptm'] = ( | |||||
| iptm_weight * aux_out['iptm'] + # noqa W504 | |||||
| (1.0 - iptm_weight) * aux_out['ptm']) | |||||
| return aux_out | |||||
| class PredictedLDDTHead(nn.Module): | |||||
| def __init__(self, num_bins, d_in, d_hid): | |||||
| super(PredictedLDDTHead, self).__init__() | |||||
| self.num_bins = num_bins | |||||
| self.d_in = d_in | |||||
| self.d_hid = d_hid | |||||
| self.layer_norm = LayerNorm(self.d_in) | |||||
| self.linear_1 = Linear(self.d_in, self.d_hid, init='relu') | |||||
| self.linear_2 = Linear(self.d_hid, self.d_hid, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear_3 = Linear(self.d_hid, self.num_bins, init='final') | |||||
| def forward(self, s): | |||||
| s = self.layer_norm(s) | |||||
| s = self.linear_1(s) | |||||
| s = self.act(s) | |||||
| s = self.linear_2(s) | |||||
| s = self.act(s) | |||||
| s = self.linear_3(s) | |||||
| return s | |||||
| class EnhancedHeadBase(nn.Module): | |||||
| def __init__(self, d_in, d_out, disable_enhance_head): | |||||
| super(EnhancedHeadBase, self).__init__() | |||||
| if disable_enhance_head: | |||||
| self.layer_norm = None | |||||
| self.linear_in = None | |||||
| else: | |||||
| self.layer_norm = LayerNorm(d_in) | |||||
| self.linear_in = Linear(d_in, d_in, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear = Linear(d_in, d_out, init='final') | |||||
| def apply_alphafold_original_mode(self): | |||||
| self.layer_norm = None | |||||
| self.linear_in = None | |||||
| def forward(self, x): | |||||
| if self.layer_norm is not None: | |||||
| x = self.layer_norm(x) | |||||
| x = self.act(self.linear_in(x)) | |||||
| logits = self.linear(x) | |||||
| return logits | |||||
| class DistogramHead(EnhancedHeadBase): | |||||
| def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): | |||||
| super(DistogramHead, self).__init__( | |||||
| d_in=d_pair, | |||||
| d_out=num_bins, | |||||
| disable_enhance_head=disable_enhance_head, | |||||
| ) | |||||
| def forward(self, x): | |||||
| logits = super().forward(x) | |||||
| logits = logits + logits.transpose(-2, -3) | |||||
| return logits | |||||
| class PredictedAlignedErrorHead(EnhancedHeadBase): | |||||
| def __init__(self, d_pair, num_bins, disable_enhance_head, **kwargs): | |||||
| super(PredictedAlignedErrorHead, self).__init__( | |||||
| d_in=d_pair, | |||||
| d_out=num_bins, | |||||
| disable_enhance_head=disable_enhance_head, | |||||
| ) | |||||
| class MaskedMSAHead(EnhancedHeadBase): | |||||
| def __init__(self, d_msa, d_out, disable_enhance_head, **kwargs): | |||||
| super(MaskedMSAHead, self).__init__( | |||||
| d_in=d_msa, | |||||
| d_out=d_out, | |||||
| disable_enhance_head=disable_enhance_head, | |||||
| ) | |||||
| class ExperimentallyResolvedHead(EnhancedHeadBase): | |||||
| def __init__(self, d_single, d_out, disable_enhance_head, **kwargs): | |||||
| super(ExperimentallyResolvedHead, self).__init__( | |||||
| d_in=d_single, | |||||
| d_out=d_out, | |||||
| disable_enhance_head=disable_enhance_head, | |||||
| ) | |||||
| @@ -0,0 +1,387 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from functools import partial | |||||
| from typing import Any, Callable, Dict, Iterable, List, Optional | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| import torch.nn.functional as F | |||||
| import torch.utils.checkpoint | |||||
| from unicore.modules import LayerNorm | |||||
| from unicore.utils import tensor_tree_map | |||||
| class Linear(nn.Linear): | |||||
| def __init__( | |||||
| self, | |||||
| d_in: int, | |||||
| d_out: int, | |||||
| bias: bool = True, | |||||
| init: str = 'default', | |||||
| ): | |||||
| super(Linear, self).__init__(d_in, d_out, bias=bias) | |||||
| self.use_bias = bias | |||||
| if self.use_bias: | |||||
| with torch.no_grad(): | |||||
| self.bias.fill_(0) | |||||
| if init == 'default': | |||||
| self._trunc_normal_init(1.0) | |||||
| elif init == 'relu': | |||||
| self._trunc_normal_init(2.0) | |||||
| elif init == 'glorot': | |||||
| self._glorot_uniform_init() | |||||
| elif init == 'gating': | |||||
| self._zero_init(self.use_bias) | |||||
| elif init == 'normal': | |||||
| self._normal_init() | |||||
| elif init == 'final': | |||||
| self._zero_init(False) | |||||
| else: | |||||
| raise ValueError('Invalid init method.') | |||||
| def _trunc_normal_init(self, scale=1.0): | |||||
| # Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.) | |||||
| TRUNCATED_NORMAL_STDDEV_FACTOR = 0.87962566103423978 | |||||
| _, fan_in = self.weight.shape | |||||
| scale = scale / max(1, fan_in) | |||||
| std = (scale**0.5) / TRUNCATED_NORMAL_STDDEV_FACTOR | |||||
| nn.init.trunc_normal_(self.weight, mean=0.0, std=std) | |||||
| def _glorot_uniform_init(self): | |||||
| nn.init.xavier_uniform_(self.weight, gain=1) | |||||
| def _zero_init(self, use_bias=True): | |||||
| with torch.no_grad(): | |||||
| self.weight.fill_(0.0) | |||||
| if use_bias: | |||||
| with torch.no_grad(): | |||||
| self.bias.fill_(1.0) | |||||
| def _normal_init(self): | |||||
| torch.nn.init.kaiming_normal_(self.weight, nonlinearity='linear') | |||||
| class Transition(nn.Module): | |||||
| def __init__(self, d_in, n): | |||||
| super(Transition, self).__init__() | |||||
| self.d_in = d_in | |||||
| self.n = n | |||||
| self.layer_norm = LayerNorm(self.d_in) | |||||
| self.linear_1 = Linear(self.d_in, self.n * self.d_in, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear_2 = Linear(self.n * self.d_in, d_in, init='final') | |||||
| def _transition(self, x): | |||||
| x = self.layer_norm(x) | |||||
| x = self.linear_1(x) | |||||
| x = self.act(x) | |||||
| x = self.linear_2(x) | |||||
| return x | |||||
| @torch.jit.ignore | |||||
| def _chunk( | |||||
| self, | |||||
| x: torch.Tensor, | |||||
| chunk_size: int, | |||||
| ) -> torch.Tensor: | |||||
| return chunk_layer( | |||||
| self._transition, | |||||
| {'x': x}, | |||||
| chunk_size=chunk_size, | |||||
| num_batch_dims=len(x.shape[:-2]), | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| x: torch.Tensor, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
| if chunk_size is not None: | |||||
| x = self._chunk(x, chunk_size) | |||||
| else: | |||||
| x = self._transition(x=x) | |||||
| return x | |||||
| class OuterProductMean(nn.Module): | |||||
| def __init__(self, d_msa, d_pair, d_hid, eps=1e-3): | |||||
| super(OuterProductMean, self).__init__() | |||||
| self.d_msa = d_msa | |||||
| self.d_pair = d_pair | |||||
| self.d_hid = d_hid | |||||
| self.eps = eps | |||||
| self.layer_norm = LayerNorm(d_msa) | |||||
| self.linear_1 = Linear(d_msa, d_hid) | |||||
| self.linear_2 = Linear(d_msa, d_hid) | |||||
| self.linear_out = Linear(d_hid**2, d_pair, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear_z = Linear(self.d_pair, self.d_pair, init='final') | |||||
| self.layer_norm_out = LayerNorm(self.d_pair) | |||||
| def _opm(self, a, b): | |||||
| outer = torch.einsum('...bac,...dae->...bdce', a, b) | |||||
| outer = outer.reshape(outer.shape[:-2] + (-1, )) | |||||
| outer = self.linear_out(outer) | |||||
| return outer | |||||
| @torch.jit.ignore | |||||
| def _chunk(self, a: torch.Tensor, b: torch.Tensor, | |||||
| chunk_size: int) -> torch.Tensor: | |||||
| a = a.reshape((-1, ) + a.shape[-3:]) | |||||
| b = b.reshape((-1, ) + b.shape[-3:]) | |||||
| out = [] | |||||
| # TODO: optimize this | |||||
| for a_prime, b_prime in zip(a, b): | |||||
| outer = chunk_layer( | |||||
| partial(self._opm, b=b_prime), | |||||
| {'a': a_prime}, | |||||
| chunk_size=chunk_size, | |||||
| num_batch_dims=1, | |||||
| ) | |||||
| out.append(outer) | |||||
| if len(out) == 1: | |||||
| outer = out[0].unsqueeze(0) | |||||
| else: | |||||
| outer = torch.stack(out, dim=0) | |||||
| outer = outer.reshape(a.shape[:-3] + outer.shape[1:]) | |||||
| return outer | |||||
| def apply_alphafold_original_mode(self): | |||||
| self.linear_z = None | |||||
| self.layer_norm_out = None | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
| m = self.layer_norm(m) | |||||
| mask = mask.unsqueeze(-1) | |||||
| if self.layer_norm_out is not None: | |||||
| # for numerical stability | |||||
| mask = mask * (mask.size(-2)**-0.5) | |||||
| a = self.linear_1(m) | |||||
| b = self.linear_2(m) | |||||
| if self.training: | |||||
| a = a * mask | |||||
| b = b * mask | |||||
| else: | |||||
| a *= mask | |||||
| b *= mask | |||||
| a = a.transpose(-2, -3) | |||||
| b = b.transpose(-2, -3) | |||||
| if chunk_size is not None: | |||||
| z = self._chunk(a, b, chunk_size) | |||||
| else: | |||||
| z = self._opm(a, b) | |||||
| norm = torch.einsum('...abc,...adc->...bdc', mask, mask) | |||||
| z /= self.eps + norm | |||||
| if self.layer_norm_out is not None: | |||||
| z = self.act(z) | |||||
| z = self.layer_norm_out(z) | |||||
| z = self.linear_z(z) | |||||
| return z | |||||
| def residual(residual, x, training): | |||||
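| # Out-of-place add during training (keeps autograd happy); in-place add at | |||||
| # inference to reduce peak memory. | |||||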
| if training: | |||||
| return x + residual | |||||
| else: | |||||
| residual += x | |||||
| return residual | |||||
| @torch.jit.script | |||||
| def fused_bias_dropout_add( | |||||
| x: torch.Tensor, | |||||
| bias: torch.Tensor, | |||||
| residual: torch.Tensor, | |||||
| dropmask: torch.Tensor, | |||||
| prob: float, | |||||
| ) -> torch.Tensor: | |||||
| return (x + bias) * F.dropout(dropmask, p=prob, training=True) + residual | |||||
| @torch.jit.script | |||||
| def fused_bias_dropout_add_inference( | |||||
| x: torch.Tensor, | |||||
| bias: torch.Tensor, | |||||
| residual: torch.Tensor, | |||||
| ) -> torch.Tensor: | |||||
| residual += bias + x | |||||
| return residual | |||||
| def bias_dropout_residual(module, residual, x, dropout_shared_dim, prob, | |||||
| training): | |||||
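| # Fuses the module's output bias, dropout (with a mask shared along | |||||
| # dropout_shared_dim) and the residual add; uses a simple fused add at inference. | |||||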
| bias = module.get_output_bias() | |||||
| if training: | |||||
| shape = list(x.shape) | |||||
| shape[dropout_shared_dim] = 1 | |||||
| with torch.no_grad(): | |||||
| mask = x.new_ones(shape) | |||||
| return fused_bias_dropout_add(x, bias, residual, mask, prob) | |||||
| else: | |||||
| return fused_bias_dropout_add_inference(x, bias, residual) | |||||
| @torch.jit.script | |||||
| def fused_bias_gated_dropout_add( | |||||
| x: torch.Tensor, | |||||
| bias: torch.Tensor, | |||||
| g: torch.Tensor, | |||||
| g_bias: torch.Tensor, | |||||
| residual: torch.Tensor, | |||||
| dropout_mask: torch.Tensor, | |||||
| prob: float, | |||||
| ) -> torch.Tensor: | |||||
| return (torch.sigmoid(g + g_bias) * (x + bias)) * F.dropout( | |||||
| dropout_mask, | |||||
| p=prob, | |||||
| training=True, | |||||
| ) + residual | |||||
| def tri_mul_residual( | |||||
| module, | |||||
| residual, | |||||
| outputs, | |||||
| dropout_shared_dim, | |||||
| prob, | |||||
| training, | |||||
| block_size, | |||||
| ): | |||||
| if training: | |||||
| x, g = outputs | |||||
| bias, g_bias = module.get_output_bias() | |||||
| shape = list(x.shape) | |||||
| shape[dropout_shared_dim] = 1 | |||||
| with torch.no_grad(): | |||||
| mask = x.new_ones(shape) | |||||
| return fused_bias_gated_dropout_add( | |||||
| x, | |||||
| bias, | |||||
| g, | |||||
| g_bias, | |||||
| residual, | |||||
| mask, | |||||
| prob, | |||||
| ) | |||||
| elif block_size is None: | |||||
| x, g = outputs | |||||
| bias, g_bias = module.get_output_bias() | |||||
| residual += (torch.sigmoid(g + g_bias) * (x + bias)) | |||||
| return residual | |||||
| else: | |||||
| # gating is not used here | |||||
| residual += outputs | |||||
| return residual | |||||
| class SimpleModuleList(nn.ModuleList): | |||||
| def __repr__(self): | |||||
| return str(len(self)) + ' X ...\n' + self[0].__repr__() | |||||
| def chunk_layer( | |||||
| layer: Callable, | |||||
| inputs: Dict[str, Any], | |||||
| chunk_size: int, | |||||
| num_batch_dims: int, | |||||
| ) -> Any: | |||||
| # TODO: support inplace add to output | |||||
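| # chunk_layer flattens the leading num_batch_dims dims, applies `layer` to slices of | |||||
| # at most chunk_size rows, and stitches the per-chunk outputs back to the original shape. | |||||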
| if not (len(inputs) > 0): | |||||
| raise ValueError('Must provide at least one input') | |||||
| def _dict_get_shapes(input): | |||||
| shapes = [] | |||||
| if type(input) is torch.Tensor: | |||||
| shapes.append(input.shape) | |||||
| elif type(input) is dict: | |||||
| for v in input.values(): | |||||
| shapes.extend(_dict_get_shapes(v)) | |||||
| elif isinstance(input, Iterable): | |||||
| for v in input: | |||||
| shapes.extend(_dict_get_shapes(v)) | |||||
| else: | |||||
| raise ValueError('Not supported') | |||||
| return shapes | |||||
| inputs = {k: v for k, v in inputs.items() if v is not None} | |||||
| initial_dims = [ | |||||
| shape[:num_batch_dims] for shape in _dict_get_shapes(inputs) | |||||
| ] | |||||
| orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)]) | |||||
| flat_batch_dim = 1 | |||||
| for d in orig_batch_dims: | |||||
| flat_batch_dim *= d | |||||
| num_chunks = (flat_batch_dim + chunk_size - 1) // chunk_size | |||||
| def _flat_inputs(t): | |||||
| t = t.view(-1, *t.shape[num_batch_dims:]) | |||||
| assert ( | |||||
| t.shape[0] == flat_batch_dim or t.shape[0] == 1 | |||||
| ), 'batch dimension must be 1 or equal to the flat batch dimension' | |||||
| return t | |||||
| flat_inputs = tensor_tree_map(_flat_inputs, inputs) | |||||
| out = None | |||||
| for i in range(num_chunks): | |||||
| chunk_start = i * chunk_size | |||||
| chunk_end = min((i + 1) * chunk_size, flat_batch_dim) | |||||
| def select_chunk(t): | |||||
| if t.shape[0] == 1: | |||||
| return t[0:1] | |||||
| else: | |||||
| return t[chunk_start:chunk_end] | |||||
| chunkes = tensor_tree_map(select_chunk, flat_inputs) | |||||
| output_chunk = layer(**chunkes) | |||||
| if out is None: | |||||
| out = tensor_tree_map( | |||||
| lambda t: t.new_zeros((flat_batch_dim, ) + t.shape[1:]), | |||||
| output_chunk) | |||||
| out_type = type(output_chunk) | |||||
| if out_type is tuple: | |||||
| for x, y in zip(out, output_chunk): | |||||
| x[chunk_start:chunk_end] = y | |||||
| elif out_type is torch.Tensor: | |||||
| out[chunk_start:chunk_end] = output_chunk | |||||
| else: | |||||
| raise ValueError('Not supported') | |||||
| # reshape = lambda t: t.view(orig_batch_dims + t.shape[1:]) | |||||
| def reshape(t): | |||||
| return t.view(orig_batch_dims + t.shape[1:]) | |||||
| out = tensor_tree_map(reshape, out) | |||||
| return out | |||||
| @@ -0,0 +1,159 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from typing import Dict, Optional, Tuple | |||||
| import torch | |||||
| def predicted_lddt(plddt_logits: torch.Tensor) -> torch.Tensor: | |||||
| """Computes per-residue pLDDT from logits. | |||||
| Args: | |||||
| plddt_logits: [num_res, num_bins] output from the PredictedLDDTHead. | |||||
| Returns: | |||||
| plddt: [num_res] per-residue pLDDT. | |||||
| """ | |||||
| num_bins = plddt_logits.shape[-1] | |||||
| bin_probs = torch.nn.functional.softmax(plddt_logits.float(), dim=-1) | |||||
| bin_width = 1.0 / num_bins | |||||
| bounds = torch.arange( | |||||
| start=0.5 * bin_width, | |||||
| end=1.0, | |||||
| step=bin_width, | |||||
| device=plddt_logits.device) | |||||
| plddt = torch.sum( | |||||
| bin_probs | |||||
| * bounds.view(*((1, ) * len(bin_probs.shape[:-1])), *bounds.shape), | |||||
| dim=-1, | |||||
| ) | |||||
| return plddt | |||||
| def compute_bin_values(breaks: torch.Tensor): | |||||
| """Gets the bin centers from the bin edges. | |||||
| Args: | |||||
| breaks: [num_bins - 1] the error bin edges. | |||||
| Returns: | |||||
| bin_centers: [num_bins] the error bin centers. | |||||
| """ | |||||
| step = breaks[1] - breaks[0] | |||||
| bin_values = breaks + step / 2 | |||||
| bin_values = torch.cat([bin_values, (bin_values[-1] + step).unsqueeze(-1)], | |||||
| dim=0) | |||||
| return bin_values | |||||
| def compute_predicted_aligned_error( | |||||
| bin_edges: torch.Tensor, | |||||
| bin_probs: torch.Tensor, | |||||
| ) -> torch.Tensor: | |||||
| """Calculates expected aligned distance errors for every pair of residues. | |||||
| Args: | |||||
| bin_edges: [num_bins - 1] the error bin edges. | |||||
| bin_probs: [num_res, num_res, num_bins] the predicted probs for each | |||||
| error bin, for each pair of residues. | |||||
| Returns: | |||||
| predicted_aligned_error: [num_res, num_res] the expected aligned distance | |||||
| error for each pair of residues. | |||||
| """ | |||||
| bin_values = compute_bin_values(bin_edges) | |||||
| return torch.sum(bin_probs * bin_values, dim=-1) | |||||
| def predicted_aligned_error( | |||||
| pae_logits: torch.Tensor, | |||||
| max_bin: int = 31, | |||||
| num_bins: int = 64, | |||||
| **kwargs, | |||||
| ) -> Dict[str, torch.Tensor]: | |||||
| """Computes aligned confidence metrics from logits. | |||||
| Args: | |||||
| pae_logits: [num_res, num_res, num_bins] the logits output from | |||||
| PredictedAlignedErrorHead. | |||||
| max_bin: the value of the largest error bin edge. | |||||
| num_bins: the number of error bins. | |||||
| Returns: | |||||
| aligned_error_probs_per_bin: [num_res, num_res, num_bins] the predicted | |||||
| aligned error probabilities over bins for each residue pair. | |||||
| predicted_aligned_error: [num_res, num_res] the expected aligned distance | |||||
| error for each pair of residues. | |||||
| """ | |||||
| bin_probs = torch.nn.functional.softmax(pae_logits.float(), dim=-1) | |||||
| bin_edges = torch.linspace( | |||||
| 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) | |||||
| predicted_aligned_error = compute_predicted_aligned_error( | |||||
| bin_edges=bin_edges, | |||||
| bin_probs=bin_probs, | |||||
| ) | |||||
| return { | |||||
| 'aligned_error_probs_per_bin': bin_probs, | |||||
| 'predicted_aligned_error': predicted_aligned_error, | |||||
| } | |||||
| def predicted_tm_score( | |||||
| pae_logits: torch.Tensor, | |||||
| residue_weights: Optional[torch.Tensor] = None, | |||||
| max_bin: int = 31, | |||||
| num_bins: int = 64, | |||||
| eps: float = 1e-8, | |||||
| asym_id: Optional[torch.Tensor] = None, | |||||
| interface: bool = False, | |||||
| **kwargs, | |||||
| ) -> torch.Tensor: | |||||
| """Computes predicted TM alignment or predicted interface TM alignment score. | |||||
| Args: | |||||
| pae_logits: [num_res, num_res, num_bins] the logits output from | |||||
| PredictedAlignedErrorHead. | |||||
| max_bin: the value of the largest error bin edge. | |||||
| num_bins: the number of error bins. | |||||
| residue_weights: [num_res] the per residue weights to use for the | |||||
| expectation. | |||||
| asym_id: [num_res] the asymmetric unit ID - the chain ID. Only needed for | |||||
| ipTM calculation, i.e. when interface=True. | |||||
| interface: If True, the interface predicted TM score (ipTM) is computed. | |||||
| Returns: | |||||
| ptm_score: The predicted TM (pTM) or interface predicted TM (ipTM) score. | |||||
| """ | |||||
| pae_logits = pae_logits.float() | |||||
| if residue_weights is None: | |||||
| residue_weights = pae_logits.new_ones(pae_logits.shape[:-2]) | |||||
| breaks = torch.linspace( | |||||
| 0, max_bin, steps=(num_bins - 1), device=pae_logits.device) | |||||
| def tm_kernal(nres): | |||||
| clipped_n = max(nres, 19) | |||||
| d0 = 1.24 * (clipped_n - 15)**(1.0 / 3.0) - 1.8 | |||||
| return lambda x: 1.0 / (1.0 + (x / d0)**2) | |||||
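| # d0(N) = 1.24 * (N - 15)^(1/3) - 1.8 is the standard TM-score distance scale, | |||||
| # with N clipped to at least 19 residues. | |||||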
| def rmsd_kernal(eps): # kept for computing pRMS | |||||
| return lambda x: 1. / (x + eps) | |||||
| bin_centers = compute_bin_values(breaks) | |||||
| probs = torch.nn.functional.softmax(pae_logits, dim=-1) | |||||
| tm_per_bin = tm_kernal(nres=pae_logits.shape[-2])(bin_centers) | |||||
| # tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2)) | |||||
| # rmsd_per_bin = rmsd_kernal()(bin_centers) | |||||
| predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1) | |||||
| pair_mask = predicted_tm_term.new_ones(predicted_tm_term.shape) | |||||
| if interface: | |||||
| assert asym_id is not None, 'must provide asym_id for iptm calculation.' | |||||
| pair_mask *= asym_id[..., :, None] != asym_id[..., None, :] | |||||
| predicted_tm_term *= pair_mask | |||||
| pair_residue_weights = pair_mask * ( | |||||
| residue_weights[None, :] * residue_weights[:, None]) | |||||
| normed_residue_mask = pair_residue_weights / ( | |||||
| eps + pair_residue_weights.sum(dim=-1, keepdim=True)) | |||||
| per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1) | |||||
| weighted = per_alignment * residue_weights | |||||
| ret = per_alignment.gather( | |||||
| dim=-1, index=weighted.max(dim=-1, | |||||
| keepdim=True).indices).squeeze(dim=-1) | |||||
| return ret | |||||
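For orientation, a standalone sketch of the TM-score kernel applied above, with uniform residue weights and the d0 formula taken from the code (illustrative only, not the library's entry point):

    import torch

    num_res, num_bins, max_bin = 8, 64, 31
    probs = torch.softmax(torch.randn(num_res, num_res, num_bins), dim=-1)
    edges = torch.linspace(0, max_bin, steps=num_bins - 1)
    step = edges[1] - edges[0]
    centers = torch.cat([edges + step / 2, (edges[-1] + step).unsqueeze(0)])

    clipped_n = max(num_res, 19)
    d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3.0) - 1.8
    tm_per_bin = 1.0 / (1.0 + (centers / d0) ** 2)           # per-bin TM weight

    per_pair = (probs * tm_per_bin).sum(dim=-1)              # expected TM term per residue pair
    per_alignment = per_pair.mean(dim=-1)                    # uniform residue weights
    ptm = per_alignment.max()                                # best alignment anchor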
| @@ -0,0 +1,290 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from typing import Optional, Tuple | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.modules import LayerNorm | |||||
| from unicore.utils import one_hot | |||||
| from .common import Linear, SimpleModuleList, residual | |||||
| class InputEmbedder(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| tf_dim: int, | |||||
| msa_dim: int, | |||||
| d_pair: int, | |||||
| d_msa: int, | |||||
| relpos_k: int, | |||||
| use_chain_relative: bool = False, | |||||
| max_relative_chain: Optional[int] = None, | |||||
| **kwargs, | |||||
| ): | |||||
| super(InputEmbedder, self).__init__() | |||||
| self.tf_dim = tf_dim | |||||
| self.msa_dim = msa_dim | |||||
| self.d_pair = d_pair | |||||
| self.d_msa = d_msa | |||||
| self.linear_tf_z_i = Linear(tf_dim, d_pair) | |||||
| self.linear_tf_z_j = Linear(tf_dim, d_pair) | |||||
| self.linear_tf_m = Linear(tf_dim, d_msa) | |||||
| self.linear_msa_m = Linear(msa_dim, d_msa) | |||||
| # relative positional encoding (RPE) | |||||
| self.relpos_k = relpos_k | |||||
| self.use_chain_relative = use_chain_relative | |||||
| self.max_relative_chain = max_relative_chain | |||||
| if not self.use_chain_relative: | |||||
| self.num_bins = 2 * self.relpos_k + 1 | |||||
| else: | |||||
| self.num_bins = 2 * self.relpos_k + 2 | |||||
| self.num_bins += 1 # entity id | |||||
| self.num_bins += 2 * max_relative_chain + 2 | |||||
| self.linear_relpos = Linear(self.num_bins, d_pair) | |||||
| def _relpos_indices( | |||||
| self, | |||||
| res_id: torch.Tensor, | |||||
| sym_id: Optional[torch.Tensor] = None, | |||||
| asym_id: Optional[torch.Tensor] = None, | |||||
| entity_id: Optional[torch.Tensor] = None, | |||||
| ): | |||||
| max_rel_res = self.relpos_k | |||||
| rp = res_id[..., None] - res_id[..., None, :] | |||||
| rp = rp.clip(-max_rel_res, max_rel_res) + max_rel_res | |||||
| if not self.use_chain_relative: | |||||
| return rp | |||||
| else: | |||||
| asym_id_same = asym_id[..., :, None] == asym_id[..., None, :] | |||||
| rp[~asym_id_same] = 2 * max_rel_res + 1 | |||||
| entity_id_same = entity_id[..., :, None] == entity_id[..., None, :] | |||||
| rp_entity_id = entity_id_same.type(rp.dtype)[..., None] | |||||
| rel_sym_id = sym_id[..., :, None] - sym_id[..., None, :] | |||||
| max_rel_chain = self.max_relative_chain | |||||
| clipped_rel_chain = torch.clamp( | |||||
| rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain) | |||||
| clipped_rel_chain[~entity_id_same] = 2 * max_rel_chain + 1 | |||||
| return rp, rp_entity_id, clipped_rel_chain | |||||
| def relpos_emb( | |||||
| self, | |||||
| res_id: torch.Tensor, | |||||
| sym_id: Optional[torch.Tensor] = None, | |||||
| asym_id: Optional[torch.Tensor] = None, | |||||
| entity_id: Optional[torch.Tensor] = None, | |||||
| num_sym: Optional[torch.Tensor] = None, | |||||
| ): | |||||
| dtype = self.linear_relpos.weight.dtype | |||||
| if not self.use_chain_relative: | |||||
| rp = self._relpos_indices(res_id=res_id) | |||||
| return self.linear_relpos( | |||||
| one_hot(rp, num_classes=self.num_bins, dtype=dtype)) | |||||
| else: | |||||
| rp, rp_entity_id, rp_rel_chain = self._relpos_indices( | |||||
| res_id=res_id, | |||||
| sym_id=sym_id, | |||||
| asym_id=asym_id, | |||||
| entity_id=entity_id) | |||||
| rp = one_hot(rp, num_classes=(2 * self.relpos_k + 2), dtype=dtype) | |||||
| rp_entity_id = rp_entity_id.type(dtype) | |||||
| rp_rel_chain = one_hot( | |||||
| rp_rel_chain, | |||||
| num_classes=(2 * self.max_relative_chain + 2), | |||||
| dtype=dtype) | |||||
| return self.linear_relpos( | |||||
| torch.cat([rp, rp_entity_id, rp_rel_chain], dim=-1)) | |||||
| def forward( | |||||
| self, | |||||
| tf: torch.Tensor, | |||||
| msa: torch.Tensor, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| # [*, N_res, d_pair] | |||||
| if self.tf_dim == 21: | |||||
| # the multimer model uses a 21-dim target feature | |||||
| tf = tf[..., 1:] | |||||
| # convert type if necessary | |||||
| tf = tf.type(self.linear_tf_z_i.weight.dtype) | |||||
| msa = msa.type(self.linear_tf_z_i.weight.dtype) | |||||
| n_clust = msa.shape[-3] | |||||
| msa_emb = self.linear_msa_m(msa) | |||||
| # project target_feat (aatype) into the msa representation | |||||
| tf_m = ( | |||||
| self.linear_tf_m(tf).unsqueeze(-3).expand( | |||||
| ((-1, ) * len(tf.shape[:-2]) + # noqa W504 | |||||
| (n_clust, -1, -1)))) | |||||
| msa_emb += tf_m | |||||
| tf_emb_i = self.linear_tf_z_i(tf) | |||||
| tf_emb_j = self.linear_tf_z_j(tf) | |||||
| pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] | |||||
| return msa_emb, pair_emb | |||||
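A small standalone sketch of the single-chain relative-position one-hot that relpos_emb builds before the linear projection (dimensions here are illustrative):

    import torch
    import torch.nn.functional as F

    relpos_k = 32
    res_id = torch.arange(6)
    rp = res_id[:, None] - res_id[None, :]                   # pairwise residue offsets
    rp = rp.clamp(-relpos_k, relpos_k) + relpos_k            # indices in [0, 2 * relpos_k]
    rp_one_hot = F.one_hot(rp, num_classes=2 * relpos_k + 1).float()
    # a Linear(2 * relpos_k + 1, d_pair) then maps this into the pair representation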
| class RecyclingEmbedder(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_msa: int, | |||||
| d_pair: int, | |||||
| min_bin: float, | |||||
| max_bin: float, | |||||
| num_bins: int, | |||||
| inf: float = 1e8, | |||||
| **kwargs, | |||||
| ): | |||||
| super(RecyclingEmbedder, self).__init__() | |||||
| self.d_msa = d_msa | |||||
| self.d_pair = d_pair | |||||
| self.min_bin = min_bin | |||||
| self.max_bin = max_bin | |||||
| self.num_bins = num_bins | |||||
| self.inf = inf | |||||
| self.squared_bins = None | |||||
| self.linear = Linear(self.num_bins, self.d_pair) | |||||
| self.layer_norm_m = LayerNorm(self.d_msa) | |||||
| self.layer_norm_z = LayerNorm(self.d_pair) | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| z: torch.Tensor, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| m_update = self.layer_norm_m(m) | |||||
| z_update = self.layer_norm_z(z) | |||||
| return m_update, z_update | |||||
| def recyle_pos( | |||||
| self, | |||||
| x: torch.Tensor, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| if self.squared_bins is None: | |||||
| bins = torch.linspace( | |||||
| self.min_bin, | |||||
| self.max_bin, | |||||
| self.num_bins, | |||||
| dtype=torch.float if self.training else x.dtype, | |||||
| device=x.device, | |||||
| requires_grad=False, | |||||
| ) | |||||
| self.squared_bins = bins**2 | |||||
| upper = torch.cat( | |||||
| [self.squared_bins[1:], | |||||
| self.squared_bins.new_tensor([self.inf])], | |||||
| dim=-1) | |||||
| if self.training: | |||||
| x = x.float() | |||||
| d = torch.sum( | |||||
| (x[..., None, :] - x[..., None, :, :])**2, dim=-1, keepdims=True) | |||||
| d = ((d > self.squared_bins) * # noqa W504 | |||||
| (d < upper)).type(self.linear.weight.dtype) | |||||
| d = self.linear(d) | |||||
| return d | |||||
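The recycled-position branch above is a distogram one-hot over squared pseudo-beta distances; a minimal sketch of the binning idea with illustrative parameters:

    import torch

    num_bins, min_bin, max_bin, inf = 15, 3.25, 20.75, 1e8
    x = torch.randn(10, 3)                                   # pseudo-beta coordinates
    lower = torch.linspace(min_bin, max_bin, num_bins) ** 2  # squared bin edges
    upper = torch.cat([lower[1:], lower.new_tensor([inf])])
    d = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1, keepdim=True)
    d_one_hot = ((d > lower) * (d < upper)).float()          # [10, 10, num_bins]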
| class TemplateAngleEmbedder(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_in: int, | |||||
| d_out: int, | |||||
| **kwargs, | |||||
| ): | |||||
| super(TemplateAngleEmbedder, self).__init__() | |||||
| self.d_out = d_out | |||||
| self.d_in = d_in | |||||
| self.linear_1 = Linear(self.d_in, self.d_out, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear_2 = Linear(self.d_out, self.d_out, init='relu') | |||||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
| x = self.linear_1(x.type(self.linear_1.weight.dtype)) | |||||
| x = self.act(x) | |||||
| x = self.linear_2(x) | |||||
| return x | |||||
| class TemplatePairEmbedder(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_in: int, | |||||
| v2_d_in: list, | |||||
| d_out: int, | |||||
| d_pair: int, | |||||
| v2_feature: bool = False, | |||||
| **kwargs, | |||||
| ): | |||||
| super(TemplatePairEmbedder, self).__init__() | |||||
| self.d_out = d_out | |||||
| self.v2_feature = v2_feature | |||||
| if self.v2_feature: | |||||
| self.d_in = v2_d_in | |||||
| self.linear = SimpleModuleList() | |||||
| for d_in in self.d_in: | |||||
| self.linear.append(Linear(d_in, self.d_out, init='relu')) | |||||
| self.z_layer_norm = LayerNorm(d_pair) | |||||
| self.z_linear = Linear(d_pair, self.d_out, init='relu') | |||||
| else: | |||||
| self.d_in = d_in | |||||
| self.linear = Linear(self.d_in, self.d_out, init='relu') | |||||
| def forward( | |||||
| self, | |||||
| x, | |||||
| z, | |||||
| ) -> torch.Tensor: | |||||
| if not self.v2_feature: | |||||
| x = self.linear(x.type(self.linear.weight.dtype)) | |||||
| return x | |||||
| else: | |||||
| dtype = self.z_linear.weight.dtype | |||||
| t = self.linear[0](x[0].type(dtype)) | |||||
| for i, s in enumerate(x[1:]): | |||||
| t = residual(t, self.linear[i + 1](s.type(dtype)), | |||||
| self.training) | |||||
| t = residual(t, self.z_linear(self.z_layer_norm(z)), self.training) | |||||
| return t | |||||
| class ExtraMSAEmbedder(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_in: int, | |||||
| d_out: int, | |||||
| **kwargs, | |||||
| ): | |||||
| super(ExtraMSAEmbedder, self).__init__() | |||||
| self.d_in = d_in | |||||
| self.d_out = d_out | |||||
| self.linear = Linear(self.d_in, self.d_out) | |||||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |||||
| return self.linear(x.type(self.linear.weight.dtype)) | |||||
| @@ -0,0 +1,362 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from functools import partial | |||||
| from typing import Optional, Tuple | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.utils import checkpoint_sequential | |||||
| from .attentions import (MSAColumnAttention, MSAColumnGlobalAttention, | |||||
| MSARowAttentionWithPairBias, TriangleAttentionEnding, | |||||
| TriangleAttentionStarting) | |||||
| from .common import (Linear, OuterProductMean, SimpleModuleList, Transition, | |||||
| bias_dropout_residual, residual, tri_mul_residual) | |||||
| from .triangle_multiplication import (TriangleMultiplicationIncoming, | |||||
| TriangleMultiplicationOutgoing) | |||||
| class EvoformerIteration(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_msa: int, | |||||
| d_pair: int, | |||||
| d_hid_msa_att: int, | |||||
| d_hid_opm: int, | |||||
| d_hid_mul: int, | |||||
| d_hid_pair_att: int, | |||||
| num_heads_msa: int, | |||||
| num_heads_pair: int, | |||||
| transition_n: int, | |||||
| msa_dropout: float, | |||||
| pair_dropout: float, | |||||
| outer_product_mean_first: bool, | |||||
| inf: float, | |||||
| eps: float, | |||||
| _is_extra_msa_stack: bool = False, | |||||
| ): | |||||
| super(EvoformerIteration, self).__init__() | |||||
| self._is_extra_msa_stack = _is_extra_msa_stack | |||||
| self.outer_product_mean_first = outer_product_mean_first | |||||
| self.msa_att_row = MSARowAttentionWithPairBias( | |||||
| d_msa=d_msa, | |||||
| d_pair=d_pair, | |||||
| d_hid=d_hid_msa_att, | |||||
| num_heads=num_heads_msa, | |||||
| ) | |||||
| if _is_extra_msa_stack: | |||||
| self.msa_att_col = MSAColumnGlobalAttention( | |||||
| d_in=d_msa, | |||||
| d_hid=d_hid_msa_att, | |||||
| num_heads=num_heads_msa, | |||||
| inf=inf, | |||||
| eps=eps, | |||||
| ) | |||||
| else: | |||||
| self.msa_att_col = MSAColumnAttention( | |||||
| d_msa, | |||||
| d_hid_msa_att, | |||||
| num_heads_msa, | |||||
| ) | |||||
| self.msa_transition = Transition( | |||||
| d_in=d_msa, | |||||
| n=transition_n, | |||||
| ) | |||||
| self.outer_product_mean = OuterProductMean( | |||||
| d_msa, | |||||
| d_pair, | |||||
| d_hid_opm, | |||||
| ) | |||||
| self.tri_mul_out = TriangleMultiplicationOutgoing( | |||||
| d_pair, | |||||
| d_hid_mul, | |||||
| ) | |||||
| self.tri_mul_in = TriangleMultiplicationIncoming( | |||||
| d_pair, | |||||
| d_hid_mul, | |||||
| ) | |||||
| self.tri_att_start = TriangleAttentionStarting( | |||||
| d_pair, | |||||
| d_hid_pair_att, | |||||
| num_heads_pair, | |||||
| ) | |||||
| self.tri_att_end = TriangleAttentionEnding( | |||||
| d_pair, | |||||
| d_hid_pair_att, | |||||
| num_heads_pair, | |||||
| ) | |||||
| self.pair_transition = Transition( | |||||
| d_in=d_pair, | |||||
| n=transition_n, | |||||
| ) | |||||
| self.row_dropout_share_dim = -3 | |||||
| self.col_dropout_share_dim = -2 | |||||
| self.msa_dropout = msa_dropout | |||||
| self.pair_dropout = pair_dropout | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| z: torch.Tensor, | |||||
| msa_mask: torch.Tensor, | |||||
| pair_mask: torch.Tensor, | |||||
| msa_row_attn_mask: torch.Tensor, | |||||
| msa_col_attn_mask: Optional[torch.Tensor], | |||||
| tri_start_attn_mask: torch.Tensor, | |||||
| tri_end_attn_mask: torch.Tensor, | |||||
| chunk_size: Optional[int] = None, | |||||
| block_size: Optional[int] = None, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| if self.outer_product_mean_first: | |||||
| z = residual( | |||||
| z, | |||||
| self.outer_product_mean( | |||||
| m, mask=msa_mask, chunk_size=chunk_size), self.training) | |||||
| m = bias_dropout_residual( | |||||
| self.msa_att_row, | |||||
| m, | |||||
| self.msa_att_row( | |||||
| m, z=z, attn_mask=msa_row_attn_mask, chunk_size=chunk_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.msa_dropout, | |||||
| self.training, | |||||
| ) | |||||
| if self._is_extra_msa_stack: | |||||
| m = residual( | |||||
| m, self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size), | |||||
| self.training) | |||||
| else: | |||||
| m = bias_dropout_residual( | |||||
| self.msa_att_col, | |||||
| m, | |||||
| self.msa_att_col( | |||||
| m, attn_mask=msa_col_attn_mask, chunk_size=chunk_size), | |||||
| self.col_dropout_share_dim, | |||||
| self.msa_dropout, | |||||
| self.training, | |||||
| ) | |||||
| m = residual(m, self.msa_transition(m, chunk_size=chunk_size), | |||||
| self.training) | |||||
| if not self.outer_product_mean_first: | |||||
| z = residual( | |||||
| z, | |||||
| self.outer_product_mean( | |||||
| m, mask=msa_mask, chunk_size=chunk_size), self.training) | |||||
| z = tri_mul_residual( | |||||
| self.tri_mul_out, | |||||
| z, | |||||
| self.tri_mul_out(z, mask=pair_mask, block_size=block_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.pair_dropout, | |||||
| self.training, | |||||
| block_size=block_size, | |||||
| ) | |||||
| z = tri_mul_residual( | |||||
| self.tri_mul_in, | |||||
| z, | |||||
| self.tri_mul_in(z, mask=pair_mask, block_size=block_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.pair_dropout, | |||||
| self.training, | |||||
| block_size=block_size, | |||||
| ) | |||||
| z = bias_dropout_residual( | |||||
| self.tri_att_start, | |||||
| z, | |||||
| self.tri_att_start( | |||||
| z, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.pair_dropout, | |||||
| self.training, | |||||
| ) | |||||
| z = bias_dropout_residual( | |||||
| self.tri_att_end, | |||||
| z, | |||||
| self.tri_att_end( | |||||
| z, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), | |||||
| self.col_dropout_share_dim, | |||||
| self.pair_dropout, | |||||
| self.training, | |||||
| ) | |||||
| z = residual(z, self.pair_transition(z, chunk_size=chunk_size), | |||||
| self.training) | |||||
| return m, z | |||||
| class EvoformerStack(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_msa: int, | |||||
| d_pair: int, | |||||
| d_hid_msa_att: int, | |||||
| d_hid_opm: int, | |||||
| d_hid_mul: int, | |||||
| d_hid_pair_att: int, | |||||
| d_single: int, | |||||
| num_heads_msa: int, | |||||
| num_heads_pair: int, | |||||
| num_blocks: int, | |||||
| transition_n: int, | |||||
| msa_dropout: float, | |||||
| pair_dropout: float, | |||||
| outer_product_mean_first: bool, | |||||
| inf: float, | |||||
| eps: float, | |||||
| _is_extra_msa_stack: bool = False, | |||||
| **kwargs, | |||||
| ): | |||||
| super(EvoformerStack, self).__init__() | |||||
| self._is_extra_msa_stack = _is_extra_msa_stack | |||||
| self.blocks = SimpleModuleList() | |||||
| for _ in range(num_blocks): | |||||
| self.blocks.append( | |||||
| EvoformerIteration( | |||||
| d_msa=d_msa, | |||||
| d_pair=d_pair, | |||||
| d_hid_msa_att=d_hid_msa_att, | |||||
| d_hid_opm=d_hid_opm, | |||||
| d_hid_mul=d_hid_mul, | |||||
| d_hid_pair_att=d_hid_pair_att, | |||||
| num_heads_msa=num_heads_msa, | |||||
| num_heads_pair=num_heads_pair, | |||||
| transition_n=transition_n, | |||||
| msa_dropout=msa_dropout, | |||||
| pair_dropout=pair_dropout, | |||||
| outer_product_mean_first=outer_product_mean_first, | |||||
| inf=inf, | |||||
| eps=eps, | |||||
| _is_extra_msa_stack=_is_extra_msa_stack, | |||||
| )) | |||||
| if not self._is_extra_msa_stack: | |||||
| self.linear = Linear(d_msa, d_single) | |||||
| else: | |||||
| self.linear = None | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| z: torch.Tensor, | |||||
| msa_mask: torch.Tensor, | |||||
| pair_mask: torch.Tensor, | |||||
| msa_row_attn_mask: torch.Tensor, | |||||
| msa_col_attn_mask: torch.Tensor, | |||||
| tri_start_attn_mask: torch.Tensor, | |||||
| tri_end_attn_mask: torch.Tensor, | |||||
| chunk_size: int, | |||||
| block_size: int, | |||||
| ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: | |||||
| blocks = [ | |||||
| partial( | |||||
| b, | |||||
| msa_mask=msa_mask, | |||||
| pair_mask=pair_mask, | |||||
| msa_row_attn_mask=msa_row_attn_mask, | |||||
| msa_col_attn_mask=msa_col_attn_mask, | |||||
| tri_start_attn_mask=tri_start_attn_mask, | |||||
| tri_end_attn_mask=tri_end_attn_mask, | |||||
| chunk_size=chunk_size, | |||||
| block_size=block_size) for b in self.blocks | |||||
| ] | |||||
| m, z = checkpoint_sequential( | |||||
| blocks, | |||||
| input=(m, z), | |||||
| ) | |||||
| s = None | |||||
| if not self._is_extra_msa_stack: | |||||
| seq_dim = -3 | |||||
| index = torch.tensor([0], device=m.device) | |||||
| s = self.linear(torch.index_select(m, dim=seq_dim, index=index)) | |||||
| s = s.squeeze(seq_dim) | |||||
| return m, z, s | |||||
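checkpoint_sequential above comes from unicore; as a rough, framework-agnostic analogue (not the code path used here), activation checkpointing over a list of blocks can be sketched with torch.utils.checkpoint:

    import torch
    from torch.utils.checkpoint import checkpoint

    blocks = [torch.nn.Linear(16, 16) for _ in range(4)]
    x = torch.randn(2, 16, requires_grad=True)
    for block in blocks:
        # activations are recomputed during backward instead of being stored
        x = checkpoint(block, x)
    x.sum().backward()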
| class ExtraMSAStack(EvoformerStack): | |||||
| def __init__( | |||||
| self, | |||||
| d_msa: int, | |||||
| d_pair: int, | |||||
| d_hid_msa_att: int, | |||||
| d_hid_opm: int, | |||||
| d_hid_mul: int, | |||||
| d_hid_pair_att: int, | |||||
| num_heads_msa: int, | |||||
| num_heads_pair: int, | |||||
| num_blocks: int, | |||||
| transition_n: int, | |||||
| msa_dropout: float, | |||||
| pair_dropout: float, | |||||
| outer_product_mean_first: bool, | |||||
| inf: float, | |||||
| eps: float, | |||||
| **kwargs, | |||||
| ): | |||||
| super(ExtraMSAStack, self).__init__( | |||||
| d_msa=d_msa, | |||||
| d_pair=d_pair, | |||||
| d_hid_msa_att=d_hid_msa_att, | |||||
| d_hid_opm=d_hid_opm, | |||||
| d_hid_mul=d_hid_mul, | |||||
| d_hid_pair_att=d_hid_pair_att, | |||||
| d_single=None, | |||||
| num_heads_msa=num_heads_msa, | |||||
| num_heads_pair=num_heads_pair, | |||||
| num_blocks=num_blocks, | |||||
| transition_n=transition_n, | |||||
| msa_dropout=msa_dropout, | |||||
| pair_dropout=pair_dropout, | |||||
| outer_product_mean_first=outer_product_mean_first, | |||||
| inf=inf, | |||||
| eps=eps, | |||||
| _is_extra_msa_stack=True, | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| m: torch.Tensor, | |||||
| z: torch.Tensor, | |||||
| msa_mask: Optional[torch.Tensor] = None, | |||||
| pair_mask: Optional[torch.Tensor] = None, | |||||
| msa_row_attn_mask: Optional[torch.Tensor] = None, | |||||
| msa_col_attn_mask: Optional[torch.Tensor] = None, | |||||
| tri_start_attn_mask: Optional[torch.Tensor] = None, | |||||
| tri_end_attn_mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| block_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
| _, z, _ = super().forward( | |||||
| m, | |||||
| z, | |||||
| msa_mask=msa_mask, | |||||
| pair_mask=pair_mask, | |||||
| msa_row_attn_mask=msa_row_attn_mask, | |||||
| msa_col_attn_mask=msa_col_attn_mask, | |||||
| tri_start_attn_mask=tri_start_attn_mask, | |||||
| tri_end_attn_mask=tri_end_attn_mask, | |||||
| chunk_size=chunk_size, | |||||
| block_size=block_size) | |||||
| return z | |||||
| @@ -0,0 +1,195 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from typing import Dict | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.utils import batched_gather, one_hot | |||||
| from modelscope.models.science.unifold.data import residue_constants as rc | |||||
| from .frame import Frame | |||||
| def pseudo_beta_fn(aatype, all_atom_positions, all_atom_masks): | |||||
| is_gly = aatype == rc.restype_order['G'] | |||||
| ca_idx = rc.atom_order['CA'] | |||||
| cb_idx = rc.atom_order['CB'] | |||||
| pseudo_beta = torch.where( | |||||
| is_gly[..., None].expand(*((-1, ) * len(is_gly.shape)), 3), | |||||
| all_atom_positions[..., ca_idx, :], | |||||
| all_atom_positions[..., cb_idx, :], | |||||
| ) | |||||
| if all_atom_masks is not None: | |||||
| pseudo_beta_mask = torch.where( | |||||
| is_gly, | |||||
| all_atom_masks[..., ca_idx], | |||||
| all_atom_masks[..., cb_idx], | |||||
| ) | |||||
| return pseudo_beta, pseudo_beta_mask | |||||
| else: | |||||
| return pseudo_beta | |||||
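pseudo_beta_fn picks CB coordinates and falls back to CA for glycine (which has no CB); a tiny self-contained illustration:

    import torch

    ca = torch.randn(5, 3)                                   # CA coordinates per residue
    cb = torch.randn(5, 3)                                   # CB coordinates per residue
    is_gly = torch.tensor([False, True, False, True, False])
    pseudo_beta = torch.where(is_gly[:, None], ca, cb)       # CA for glycine, CB otherwise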
| def atom14_to_atom37(atom14, batch): | |||||
| atom37_data = batched_gather( | |||||
| atom14, | |||||
| batch['residx_atom37_to_atom14'], | |||||
| dim=-2, | |||||
| num_batch_dims=len(atom14.shape[:-2]), | |||||
| ) | |||||
| atom37_data = atom37_data * batch['atom37_atom_exists'][..., None] | |||||
| return atom37_data | |||||
| def build_template_angle_feat(template_feats, v2_feature=False): | |||||
| template_aatype = template_feats['template_aatype'] | |||||
| torsion_angles_sin_cos = template_feats['template_torsion_angles_sin_cos'] | |||||
| torsion_angles_mask = template_feats['template_torsion_angles_mask'] | |||||
| if not v2_feature: | |||||
| alt_torsion_angles_sin_cos = template_feats[ | |||||
| 'template_alt_torsion_angles_sin_cos'] | |||||
| template_angle_feat = torch.cat( | |||||
| [ | |||||
| one_hot(template_aatype, 22), | |||||
| torsion_angles_sin_cos.reshape( | |||||
| *torsion_angles_sin_cos.shape[:-2], 14), | |||||
| alt_torsion_angles_sin_cos.reshape( | |||||
| *alt_torsion_angles_sin_cos.shape[:-2], 14), | |||||
| torsion_angles_mask, | |||||
| ], | |||||
| dim=-1, | |||||
| ) | |||||
| template_angle_mask = torsion_angles_mask[..., 2] | |||||
| else: | |||||
| chi_mask = torsion_angles_mask[..., 3:] | |||||
| chi_angles_sin = torsion_angles_sin_cos[..., 3:, 0] * chi_mask | |||||
| chi_angles_cos = torsion_angles_sin_cos[..., 3:, 1] * chi_mask | |||||
| template_angle_feat = torch.cat( | |||||
| [ | |||||
| one_hot(template_aatype, 22), | |||||
| chi_angles_sin, | |||||
| chi_angles_cos, | |||||
| chi_mask, | |||||
| ], | |||||
| dim=-1, | |||||
| ) | |||||
| template_angle_mask = chi_mask[..., 0] | |||||
| return template_angle_feat, template_angle_mask | |||||
| def build_template_pair_feat( | |||||
| batch, | |||||
| min_bin, | |||||
| max_bin, | |||||
| num_bins, | |||||
| eps=1e-20, | |||||
| inf=1e8, | |||||
| ): | |||||
| template_mask = batch['template_pseudo_beta_mask'] | |||||
| template_mask_2d = template_mask[..., None] * template_mask[..., None, :] | |||||
| tpb = batch['template_pseudo_beta'] | |||||
| dgram = torch.sum( | |||||
| (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) | |||||
| lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 | |||||
| upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) | |||||
| dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) | |||||
| to_concat = [dgram, template_mask_2d[..., None]] | |||||
| aatype_one_hot = nn.functional.one_hot( | |||||
| batch['template_aatype'], | |||||
| rc.restype_num + 2, | |||||
| ) | |||||
| n_res = batch['template_aatype'].shape[-1] | |||||
| to_concat.append(aatype_one_hot[..., None, :, :].expand( | |||||
| *aatype_one_hot.shape[:-2], n_res, -1, -1)) | |||||
| to_concat.append(aatype_one_hot[..., | |||||
| None, :].expand(*aatype_one_hot.shape[:-2], | |||||
| -1, n_res, -1)) | |||||
| to_concat.append(template_mask_2d.new_zeros(*template_mask_2d.shape, 3)) | |||||
| to_concat.append(template_mask_2d[..., None]) | |||||
| act = torch.cat(to_concat, dim=-1) | |||||
| act = act * template_mask_2d[..., None] | |||||
| return act | |||||
| def build_template_pair_feat_v2( | |||||
| batch, | |||||
| min_bin, | |||||
| max_bin, | |||||
| num_bins, | |||||
| multichain_mask_2d=None, | |||||
| eps=1e-20, | |||||
| inf=1e8, | |||||
| ): | |||||
| template_mask = batch['template_pseudo_beta_mask'] | |||||
| template_mask_2d = template_mask[..., None] * template_mask[..., None, :] | |||||
| if multichain_mask_2d is not None: | |||||
| template_mask_2d *= multichain_mask_2d | |||||
| tpb = batch['template_pseudo_beta'] | |||||
| dgram = torch.sum( | |||||
| (tpb[..., None, :] - tpb[..., None, :, :])**2, dim=-1, keepdim=True) | |||||
| lower = torch.linspace(min_bin, max_bin, num_bins, device=tpb.device)**2 | |||||
| upper = torch.cat([lower[1:], lower.new_tensor([inf])], dim=-1) | |||||
| dgram = ((dgram > lower) * (dgram < upper)).type(dgram.dtype) | |||||
| dgram *= template_mask_2d[..., None] | |||||
| to_concat = [dgram, template_mask_2d[..., None]] | |||||
| aatype_one_hot = one_hot( | |||||
| batch['template_aatype'], | |||||
| rc.restype_num + 2, | |||||
| ) | |||||
| n_res = batch['template_aatype'].shape[-1] | |||||
| to_concat.append(aatype_one_hot[..., None, :, :].expand( | |||||
| *aatype_one_hot.shape[:-2], n_res, -1, -1)) | |||||
| to_concat.append(aatype_one_hot[..., | |||||
| None, :].expand(*aatype_one_hot.shape[:-2], | |||||
| -1, n_res, -1)) | |||||
| n, ca, c = [rc.atom_order[a] for a in ['N', 'CA', 'C']] | |||||
| rigids = Frame.make_transform_from_reference( | |||||
| n_xyz=batch['template_all_atom_positions'][..., n, :], | |||||
| ca_xyz=batch['template_all_atom_positions'][..., ca, :], | |||||
| c_xyz=batch['template_all_atom_positions'][..., c, :], | |||||
| eps=eps, | |||||
| ) | |||||
| points = rigids.get_trans()[..., None, :, :] | |||||
| rigid_vec = rigids[..., None].invert_apply(points) | |||||
| inv_distance_scalar = torch.rsqrt(eps + torch.sum(rigid_vec**2, dim=-1)) | |||||
| t_aa_masks = batch['template_all_atom_mask'] | |||||
| backbone_mask = t_aa_masks[..., n] * t_aa_masks[..., ca] * t_aa_masks[..., | |||||
| c] | |||||
| backbone_mask_2d = backbone_mask[..., :, None] * backbone_mask[..., | |||||
| None, :] | |||||
| if multichain_mask_2d is not None: | |||||
| backbone_mask_2d *= multichain_mask_2d | |||||
| inv_distance_scalar = inv_distance_scalar * backbone_mask_2d | |||||
| unit_vector_data = rigid_vec * inv_distance_scalar[..., None] | |||||
| to_concat.extend(torch.unbind(unit_vector_data[..., None, :], dim=-1)) | |||||
| to_concat.append(backbone_mask_2d[..., None]) | |||||
| return to_concat | |||||
| def build_extra_msa_feat(batch): | |||||
| msa_1hot = one_hot(batch['extra_msa'], 23) | |||||
| msa_feat = [ | |||||
| msa_1hot, | |||||
| batch['extra_msa_has_deletion'].unsqueeze(-1), | |||||
| batch['extra_msa_deletion_value'].unsqueeze(-1), | |||||
| ] | |||||
| return torch.cat(msa_feat, dim=-1) | |||||
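A self-contained sketch of the extra-MSA feature layout built above (23-way residue one-hot plus two deletion channels, i.e. 25 features per position); the tensor shapes are illustrative:

    import torch
    import torch.nn.functional as F

    extra_msa = torch.randint(0, 23, (5, 12))                # [num_extra_seq, num_res]
    has_deletion = torch.rand(5, 12)
    deletion_value = torch.rand(5, 12)
    feat = torch.cat([
        F.one_hot(extra_msa, num_classes=23).float(),
        has_deletion.unsqueeze(-1),
        deletion_value.unsqueeze(-1),
    ], dim=-1)                                               # [5, 12, 25]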
| @@ -0,0 +1,562 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from __future__ import annotations # noqa | |||||
| from typing import Any, Callable, Iterable, Optional, Sequence, Tuple | |||||
| import numpy as np | |||||
| import torch | |||||
| def zero_translation( | |||||
| batch_dims: Tuple[int], | |||||
| dtype: Optional[torch.dtype] = torch.float, | |||||
| device: Optional[torch.device] = torch.device('cpu'), | |||||
| requires_grad: bool = False, | |||||
| ) -> torch.Tensor: | |||||
| trans = torch.zeros((*batch_dims, 3), | |||||
| dtype=dtype, | |||||
| device=device, | |||||
| requires_grad=requires_grad) | |||||
| return trans | |||||
| # pylint: disable=bad-whitespace | |||||
| _QUAT_TO_ROT = np.zeros((4, 4, 3, 3), dtype=np.float32) | |||||
| _QUAT_TO_ROT[0, 0] = [[1, 0, 0], [0, 1, 0], [0, 0, 1]] # rr | |||||
| _QUAT_TO_ROT[1, 1] = [[1, 0, 0], [0, -1, 0], [0, 0, -1]] # ii | |||||
| _QUAT_TO_ROT[2, 2] = [[-1, 0, 0], [0, 1, 0], [0, 0, -1]] # jj | |||||
| _QUAT_TO_ROT[3, 3] = [[-1, 0, 0], [0, -1, 0], [0, 0, 1]] # kk | |||||
| _QUAT_TO_ROT[1, 2] = [[0, 2, 0], [2, 0, 0], [0, 0, 0]] # ij | |||||
| _QUAT_TO_ROT[1, 3] = [[0, 0, 2], [0, 0, 0], [2, 0, 0]] # ik | |||||
| _QUAT_TO_ROT[2, 3] = [[0, 0, 0], [0, 0, 2], [0, 2, 0]] # jk | |||||
| _QUAT_TO_ROT[0, 1] = [[0, 0, 0], [0, 0, -2], [0, 2, 0]] # ir | |||||
| _QUAT_TO_ROT[0, 2] = [[0, 0, 2], [0, 0, 0], [-2, 0, 0]] # jr | |||||
| _QUAT_TO_ROT[0, 3] = [[0, -2, 0], [2, 0, 0], [0, 0, 0]] # kr | |||||
| _QUAT_TO_ROT = _QUAT_TO_ROT.reshape(4, 4, 9) | |||||
| _QUAT_TO_ROT_tensor = torch.from_numpy(_QUAT_TO_ROT) | |||||
| _QUAT_MULTIPLY = np.zeros((4, 4, 4)) | |||||
| _QUAT_MULTIPLY[:, :, 0] = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], | |||||
| [0, 0, 0, -1]] | |||||
| _QUAT_MULTIPLY[:, :, 1] = [[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 1], | |||||
| [0, 0, -1, 0]] | |||||
| _QUAT_MULTIPLY[:, :, 2] = [[0, 0, 1, 0], [0, 0, 0, -1], [1, 0, 0, 0], | |||||
| [0, 1, 0, 0]] | |||||
| _QUAT_MULTIPLY[:, :, 3] = [[0, 0, 0, 1], [0, 0, 1, 0], [0, -1, 0, 0], | |||||
| [1, 0, 0, 0]] | |||||
| _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :] | |||||
| _QUAT_MULTIPLY_BY_VEC_tensor = torch.from_numpy(_QUAT_MULTIPLY_BY_VEC) | |||||
| class Rotation: | |||||
| def __init__( | |||||
| self, | |||||
| mat: torch.Tensor, | |||||
| ): | |||||
| if mat.shape[-2:] != (3, 3): | |||||
| raise ValueError(f'incorrect rotation shape: {mat.shape}') | |||||
| self._mat = mat | |||||
| @staticmethod | |||||
| def identity( | |||||
| shape, | |||||
| dtype: Optional[torch.dtype] = torch.float, | |||||
| device: Optional[torch.device] = torch.device('cpu'), | |||||
| requires_grad: bool = False, | |||||
| ) -> Rotation: | |||||
| mat = torch.eye( | |||||
| 3, dtype=dtype, device=device, requires_grad=requires_grad) | |||||
| mat = mat.view(*((1, ) * len(shape)), 3, 3) | |||||
| mat = mat.expand(*shape, -1, -1) | |||||
| return Rotation(mat) | |||||
| @staticmethod | |||||
| def mat_mul_mat(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: | |||||
| return (a.float() @ b.float()).type(a.dtype) | |||||
| @staticmethod | |||||
| def mat_mul_vec(r: torch.Tensor, t: torch.Tensor) -> torch.Tensor: | |||||
| return (r.float() @ t.float().unsqueeze(-1)).squeeze(-1).type(t.dtype) | |||||
| def __getitem__(self, index: Any) -> Rotation: | |||||
| if not isinstance(index, tuple): | |||||
| index = (index, ) | |||||
| return Rotation(mat=self._mat[index + (slice(None), slice(None))]) | |||||
| def __mul__(self, right: Any) -> Rotation: | |||||
| if isinstance(right, (int, float)): | |||||
| return Rotation(mat=self._mat * right) | |||||
| elif isinstance(right, torch.Tensor): | |||||
| return Rotation(mat=self._mat * right[..., None, None]) | |||||
| else: | |||||
| raise TypeError( | |||||
| f'multiplicand must be a tensor or a number, got {type(right)}.' | |||||
| ) | |||||
| def __rmul__(self, left: Any) -> Rotation: | |||||
| return self.__mul__(left) | |||||
| def __matmul__(self, other: Rotation) -> Rotation: | |||||
| new_mat = Rotation.mat_mul_mat(self.rot_mat, other.rot_mat) | |||||
| return Rotation(mat=new_mat) | |||||
| @property | |||||
| def _inv_mat(self): | |||||
| return self._mat.transpose(-1, -2) | |||||
| @property | |||||
| def rot_mat(self) -> torch.Tensor: | |||||
| return self._mat | |||||
| def invert(self) -> Rotation: | |||||
| return Rotation(mat=self._inv_mat) | |||||
| def apply(self, pts: torch.Tensor) -> torch.Tensor: | |||||
| return Rotation.mat_mul_vec(self._mat, pts) | |||||
| def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: | |||||
| return Rotation.mat_mul_vec(self._inv_mat, pts) | |||||
| # inherit tensor behaviors | |||||
| @property | |||||
| def shape(self) -> torch.Size: | |||||
| s = self._mat.shape[:-2] | |||||
| return s | |||||
| @property | |||||
| def dtype(self) -> torch.dtype: | |||||
| return self._mat.dtype | |||||
| @property | |||||
| def device(self) -> torch.device: | |||||
| return self._mat.device | |||||
| @property | |||||
| def requires_grad(self) -> bool: | |||||
| return self._mat.requires_grad | |||||
| def unsqueeze(self, dim: int) -> Rotation: | |||||
| if dim >= len(self.shape): | |||||
| raise ValueError('Invalid dimension') | |||||
| rot_mats = self._mat.unsqueeze(dim if dim >= 0 else dim - 2) | |||||
| return Rotation(mat=rot_mats) | |||||
| def map_tensor_fn(self, fn: Callable[[torch.Tensor], | |||||
| torch.Tensor]) -> Rotation: | |||||
| mat = self._mat.view(self._mat.shape[:-2] + (9, )) | |||||
| mat = torch.stack(list(map(fn, torch.unbind(mat, dim=-1))), dim=-1) | |||||
| mat = mat.view(mat.shape[:-1] + (3, 3)) | |||||
| return Rotation(mat=mat) | |||||
| @staticmethod | |||||
| def cat(rs: Sequence[Rotation], dim: int) -> Rotation: | |||||
| rot_mats = [r.rot_mat for r in rs] | |||||
| rot_mats = torch.cat(rot_mats, dim=dim if dim >= 0 else dim - 2) | |||||
| return Rotation(mat=rot_mats) | |||||
| def cuda(self) -> Rotation: | |||||
| return Rotation(mat=self._mat.cuda()) | |||||
| def to(self, device: Optional[torch.device], | |||||
| dtype: Optional[torch.dtype]) -> Rotation: | |||||
| return Rotation(mat=self._mat.to(device=device, dtype=dtype)) | |||||
| def type(self, dtype: Optional[torch.dtype]) -> Rotation: | |||||
| return Rotation(mat=self._mat.type(dtype)) | |||||
| def detach(self) -> Rotation: | |||||
| return Rotation(mat=self._mat.detach()) | |||||
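A quick sanity sketch of the conventions behind apply / invert_apply: rotating points and then applying the transpose recovers them (plain tensors, not the Rotation class):

    import math
    import torch

    theta = 0.3
    rot = torch.tensor([[math.cos(theta), -math.sin(theta), 0.0],
                        [math.sin(theta),  math.cos(theta), 0.0],
                        [0.0, 0.0, 1.0]])
    pts = torch.randn(4, 3)
    rotated = (rot @ pts.unsqueeze(-1)).squeeze(-1)                           # apply
    recovered = (rot.transpose(-1, -2) @ rotated.unsqueeze(-1)).squeeze(-1)   # invert_apply
    assert torch.allclose(recovered, pts, atol=1e-5)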
| class Frame: | |||||
| def __init__( | |||||
| self, | |||||
| rotation: Optional[Rotation], | |||||
| translation: Optional[torch.Tensor], | |||||
| ): | |||||
| if rotation is None and translation is None: | |||||
| rotation = Rotation.identity((0, )) | |||||
| translation = zero_translation((0, )) | |||||
| elif translation is None: | |||||
| translation = zero_translation(rotation.shape, rotation.dtype, | |||||
| rotation.device, | |||||
| rotation.requires_grad) | |||||
| elif rotation is None: | |||||
| rotation = Rotation.identity( | |||||
| translation.shape[:-1], | |||||
| translation.dtype, | |||||
| translation.device, | |||||
| translation.requires_grad, | |||||
| ) | |||||
| if (rotation.shape != translation.shape[:-1]) or (rotation.device | |||||
| != # noqa W504 | |||||
| translation.device): | |||||
| raise ValueError('RotationMatrix and translation incompatible') | |||||
| self._r = rotation | |||||
| self._t = translation | |||||
| @staticmethod | |||||
| def identity( | |||||
| shape: Iterable[int], | |||||
| dtype: Optional[torch.dtype] = torch.float, | |||||
| device: Optional[torch.device] = torch.device('cpu'), | |||||
| requires_grad: bool = False, | |||||
| ) -> Frame: | |||||
| return Frame( | |||||
| Rotation.identity(shape, dtype, device, requires_grad), | |||||
| zero_translation(shape, dtype, device, requires_grad), | |||||
| ) | |||||
| def __getitem__( | |||||
| self, | |||||
| index: Any, | |||||
| ) -> Frame: | |||||
| if not isinstance(index, tuple): | |||||
| index = (index, ) | |||||
| return Frame( | |||||
| self._r[index], | |||||
| self._t[index + (slice(None), )], | |||||
| ) | |||||
| def __mul__( | |||||
| self, | |||||
| right: torch.Tensor, | |||||
| ) -> Frame: | |||||
| if not (isinstance(right, torch.Tensor)): | |||||
| raise TypeError('The other multiplicand must be a Tensor') | |||||
| new_rots = self._r * right | |||||
| new_trans = self._t * right[..., None] | |||||
| return Frame(new_rots, new_trans) | |||||
| def __rmul__( | |||||
| self, | |||||
| left: torch.Tensor, | |||||
| ) -> Frame: | |||||
| return self.__mul__(left) | |||||
| @property | |||||
| def shape(self) -> torch.Size: | |||||
| s = self._t.shape[:-1] | |||||
| return s | |||||
| @property | |||||
| def device(self) -> torch.device: | |||||
| return self._t.device | |||||
| def get_rots(self) -> Rotation: | |||||
| return self._r | |||||
| def get_trans(self) -> torch.Tensor: | |||||
| return self._t | |||||
| def compose( | |||||
| self, | |||||
| other: Frame, | |||||
| ) -> Frame: | |||||
| new_rot = self._r @ other._r | |||||
| new_trans = self._r.apply(other._t) + self._t | |||||
| return Frame(new_rot, new_trans) | |||||
| def apply( | |||||
| self, | |||||
| pts: torch.Tensor, | |||||
| ) -> torch.Tensor: | |||||
| rotated = self._r.apply(pts) | |||||
| return rotated + self._t | |||||
| def invert_apply(self, pts: torch.Tensor) -> torch.Tensor: | |||||
| pts = pts - self._t | |||||
| return self._r.invert_apply(pts) | |||||
| def invert(self) -> Frame: | |||||
| rot_inv = self._r.invert() | |||||
| trn_inv = rot_inv.apply(self._t) | |||||
| return Frame(rot_inv, -1 * trn_inv) | |||||
| def map_tensor_fn(self, fn: Callable[[torch.Tensor], | |||||
| torch.Tensor]) -> Frame: | |||||
| new_rots = self._r.map_tensor_fn(fn) | |||||
| new_trans = torch.stack( | |||||
| list(map(fn, torch.unbind(self._t, dim=-1))), dim=-1) | |||||
| return Frame(new_rots, new_trans) | |||||
| def to_tensor_4x4(self) -> torch.Tensor: | |||||
| tensor = self._t.new_zeros((*self.shape, 4, 4)) | |||||
| tensor[..., :3, :3] = self._r.rot_mat | |||||
| tensor[..., :3, 3] = self._t | |||||
| tensor[..., 3, 3] = 1 | |||||
| return tensor | |||||
| @staticmethod | |||||
| def from_tensor_4x4(t: torch.Tensor) -> Frame: | |||||
| if t.shape[-2:] != (4, 4): | |||||
| raise ValueError('Incorrectly shaped input tensor') | |||||
| rots = Rotation(mat=t[..., :3, :3]) | |||||
| trans = t[..., :3, 3] | |||||
| return Frame(rots, trans) | |||||
| @staticmethod | |||||
| def from_3_points( | |||||
| p_neg_x_axis: torch.Tensor, | |||||
| origin: torch.Tensor, | |||||
| p_xy_plane: torch.Tensor, | |||||
| eps: float = 1e-8, | |||||
| ) -> Frame: | |||||
| p_neg_x_axis = torch.unbind(p_neg_x_axis, dim=-1) | |||||
| origin = torch.unbind(origin, dim=-1) | |||||
| p_xy_plane = torch.unbind(p_xy_plane, dim=-1) | |||||
| e0 = [c1 - c2 for c1, c2 in zip(origin, p_neg_x_axis)] | |||||
| e1 = [c1 - c2 for c1, c2 in zip(p_xy_plane, origin)] | |||||
| denom = torch.sqrt(sum((c * c for c in e0)) + eps) | |||||
| e0 = [c / denom for c in e0] | |||||
| dot = sum((c1 * c2 for c1, c2 in zip(e0, e1))) | |||||
| e1 = [c2 - c1 * dot for c1, c2 in zip(e0, e1)] | |||||
| denom = torch.sqrt(sum((c * c for c in e1)) + eps) | |||||
| e1 = [c / denom for c in e1] | |||||
| e2 = [ | |||||
| e0[1] * e1[2] - e0[2] * e1[1], | |||||
| e0[2] * e1[0] - e0[0] * e1[2], | |||||
| e0[0] * e1[1] - e0[1] * e1[0], | |||||
| ] | |||||
| rots = torch.stack([c for tup in zip(e0, e1, e2) for c in tup], dim=-1) | |||||
| rots = rots.reshape(rots.shape[:-1] + (3, 3)) | |||||
| rot_obj = Rotation(mat=rots) | |||||
| return Frame(rot_obj, torch.stack(origin, dim=-1)) | |||||
| def unsqueeze( | |||||
| self, | |||||
| dim: int, | |||||
| ) -> Frame: | |||||
| if dim >= len(self.shape): | |||||
| raise ValueError('Invalid dimension') | |||||
| rots = self._r.unsqueeze(dim) | |||||
| trans = self._t.unsqueeze(dim if dim >= 0 else dim - 1) | |||||
| return Frame(rots, trans) | |||||
| @staticmethod | |||||
| def cat( | |||||
| Ts: Sequence[Frame], | |||||
| dim: int, | |||||
| ) -> Frame: | |||||
| rots = Rotation.cat([T._r for T in Ts], dim) | |||||
| trans = torch.cat([T._t for T in Ts], dim=dim if dim >= 0 else dim - 1) | |||||
| return Frame(rots, trans) | |||||
| def apply_rot_fn(self, fn: Callable[[Rotation], Rotation]) -> Frame: | |||||
| return Frame(fn(self._r), self._t) | |||||
| def apply_trans_fn(self, fn: Callable[[torch.Tensor], | |||||
| torch.Tensor]) -> Frame: | |||||
| return Frame(self._r, fn(self._t)) | |||||
| def scale_translation(self, trans_scale_factor: float) -> Frame: | |||||
| # fn = lambda t: t * trans_scale_factor | |||||
| def fn(t): | |||||
| return t * trans_scale_factor | |||||
| return self.apply_trans_fn(fn) | |||||
| def stop_rot_gradient(self) -> Frame: | |||||
| # fn = lambda r: r.detach() | |||||
| def fn(r): | |||||
| return r.detach() | |||||
| return self.apply_rot_fn(fn) | |||||
| @staticmethod | |||||
| def make_transform_from_reference(n_xyz, ca_xyz, c_xyz, eps=1e-20): | |||||
| input_dtype = ca_xyz.dtype | |||||
| n_xyz = n_xyz.float() | |||||
| ca_xyz = ca_xyz.float() | |||||
| c_xyz = c_xyz.float() | |||||
| n_xyz = n_xyz - ca_xyz | |||||
| c_xyz = c_xyz - ca_xyz | |||||
| c_x, c_y, c_z = [c_xyz[..., i] for i in range(3)] | |||||
| norm = torch.sqrt(eps + c_x**2 + c_y**2) | |||||
| sin_c1 = -c_y / norm | |||||
| cos_c1 = c_x / norm | |||||
| c1_rots = sin_c1.new_zeros((*sin_c1.shape, 3, 3)) | |||||
| c1_rots[..., 0, 0] = cos_c1 | |||||
| c1_rots[..., 0, 1] = -1 * sin_c1 | |||||
| c1_rots[..., 1, 0] = sin_c1 | |||||
| c1_rots[..., 1, 1] = cos_c1 | |||||
| c1_rots[..., 2, 2] = 1 | |||||
| norm = torch.sqrt(eps + c_x**2 + c_y**2 + c_z**2) | |||||
| sin_c2 = c_z / norm | |||||
| cos_c2 = torch.sqrt(c_x**2 + c_y**2) / norm | |||||
| c2_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) | |||||
| c2_rots[..., 0, 0] = cos_c2 | |||||
| c2_rots[..., 0, 2] = sin_c2 | |||||
| c2_rots[..., 1, 1] = 1 | |||||
| c2_rots[..., 2, 0] = -1 * sin_c2 | |||||
| c2_rots[..., 2, 2] = cos_c2 | |||||
| c_rots = Rotation.mat_mul_mat(c2_rots, c1_rots) | |||||
| n_xyz = Rotation.mat_mul_vec(c_rots, n_xyz) | |||||
| _, n_y, n_z = [n_xyz[..., i] for i in range(3)] | |||||
| norm = torch.sqrt(eps + n_y**2 + n_z**2) | |||||
| sin_n = -n_z / norm | |||||
| cos_n = n_y / norm | |||||
| n_rots = sin_c2.new_zeros((*sin_c2.shape, 3, 3)) | |||||
| n_rots[..., 0, 0] = 1 | |||||
| n_rots[..., 1, 1] = cos_n | |||||
| n_rots[..., 1, 2] = -1 * sin_n | |||||
| n_rots[..., 2, 1] = sin_n | |||||
| n_rots[..., 2, 2] = cos_n | |||||
| rots = Rotation.mat_mul_mat(n_rots, c_rots) | |||||
| rots = rots.transpose(-1, -2) | |||||
| rot_obj = Rotation(mat=rots.type(input_dtype)) | |||||
| return Frame(rot_obj, ca_xyz.type(input_dtype)) | |||||
| def cuda(self) -> Frame: | |||||
| return Frame(self._r.cuda(), self._t.cuda()) | |||||
| @property | |||||
| def dtype(self) -> torch.dtype: | |||||
| assert self._r.dtype == self._t.dtype | |||||
| return self._r.dtype | |||||
| def type(self, dtype) -> Frame: | |||||
| return Frame(self._r.type(dtype), self._t.type(dtype)) | |||||
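from_3_points and make_transform_from_reference both build an orthonormal backbone frame from three atoms; a generic Gram-Schmidt sketch of the idea (the exact axis conventions in the code differ slightly and are not reproduced here):

    import torch
    import torch.nn.functional as F

    n, ca, c = torch.randn(3), torch.randn(3), torch.randn(3)
    e0 = F.normalize(c - ca, dim=-1)                         # first basis vector
    e1 = n - ca
    e1 = F.normalize(e1 - (e1 @ e0) * e0, dim=-1)            # orthogonalize against e0
    e2 = torch.cross(e0, e1, dim=-1)                         # completes a right-handed frame
    rot = torch.stack([e0, e1, e2], dim=-1)                  # basis vectors as columns
    trans = ca                                               # frame origin at the CA atom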
| class Quaternion: | |||||
| def __init__(self, quaternion: torch.Tensor, translation: torch.Tensor): | |||||
| if quaternion.shape[-1] != 4: | |||||
| raise ValueError(f'incorrect quaternion shape: {quaternion.shape}') | |||||
| self._q = quaternion | |||||
| self._t = translation | |||||
| @staticmethod | |||||
| def identity( | |||||
| shape: Iterable[int], | |||||
| dtype: Optional[torch.dtype] = torch.float, | |||||
| device: Optional[torch.device] = torch.device('cpu'), | |||||
| requires_grad: bool = False, | |||||
| ) -> Quaternion: | |||||
| trans = zero_translation(shape, dtype, device, requires_grad) | |||||
| quats = torch.zeros((*shape, 4), | |||||
| dtype=dtype, | |||||
| device=device, | |||||
| requires_grad=requires_grad) | |||||
| with torch.no_grad(): | |||||
| quats[..., 0] = 1 | |||||
| return Quaternion(quats, trans) | |||||
| def get_quats(self): | |||||
| return self._q | |||||
| def get_trans(self): | |||||
| return self._t | |||||
| def get_rot_mats(self): | |||||
| quats = self.get_quats() | |||||
| rot_mats = Quaternion.quat_to_rot(quats) | |||||
| return rot_mats | |||||
| @staticmethod | |||||
| def quat_to_rot(normalized_quat): | |||||
| global _QUAT_TO_ROT_tensor | |||||
| dtype = normalized_quat.dtype | |||||
| normalized_quat = normalized_quat.float() | |||||
| if _QUAT_TO_ROT_tensor.device != normalized_quat.device: | |||||
| _QUAT_TO_ROT_tensor = _QUAT_TO_ROT_tensor.to( | |||||
| normalized_quat.device) | |||||
| rot_tensor = torch.sum( | |||||
| _QUAT_TO_ROT_tensor * normalized_quat[..., :, None, None] | |||||
| * normalized_quat[..., None, :, None], | |||||
| dim=(-3, -2), | |||||
| ) | |||||
| rot_tensor = rot_tensor.type(dtype) | |||||
| rot_tensor = rot_tensor.view(*rot_tensor.shape[:-1], 3, 3) | |||||
| return rot_tensor | |||||
| @staticmethod | |||||
| def normalize_quat(quats): | |||||
| dtype = quats.dtype | |||||
| quats = quats.float() | |||||
| quats = quats / torch.linalg.norm(quats, dim=-1, keepdim=True) | |||||
| quats = quats.type(dtype) | |||||
| return quats | |||||
| @staticmethod | |||||
| def quat_multiply_by_vec(quat, vec): | |||||
| dtype = quat.dtype | |||||
| quat = quat.float() | |||||
| vec = vec.float() | |||||
| global _QUAT_MULTIPLY_BY_VEC_tensor | |||||
| if _QUAT_MULTIPLY_BY_VEC_tensor.device != quat.device: | |||||
| _QUAT_MULTIPLY_BY_VEC_tensor = _QUAT_MULTIPLY_BY_VEC_tensor.to( | |||||
| quat.device) | |||||
| mat = _QUAT_MULTIPLY_BY_VEC_tensor | |||||
| reshaped_mat = mat.view((1, ) * len(quat.shape[:-1]) + mat.shape) | |||||
| return torch.sum( | |||||
| reshaped_mat * quat[..., :, None, None] * vec[..., None, :, None], | |||||
| dim=(-3, -2), | |||||
| ).type(dtype) | |||||
| def compose_q_update_vec(self, | |||||
| q_update_vec: torch.Tensor, | |||||
| normalize_quats: bool = True) -> torch.Tensor: | |||||
| quats = self.get_quats() | |||||
| new_quats = quats + Quaternion.quat_multiply_by_vec( | |||||
| quats, q_update_vec) | |||||
| if normalize_quats: | |||||
| new_quats = Quaternion.normalize_quat(new_quats) | |||||
| return new_quats | |||||
| def compose_update_vec( | |||||
| self, | |||||
| update_vec: torch.Tensor, | |||||
| pre_rot_mat: Rotation, | |||||
| ) -> Quaternion: | |||||
| q_vec, t_vec = update_vec[..., :3], update_vec[..., 3:] | |||||
| new_quats = self.compose_q_update_vec(q_vec) | |||||
| trans_update = pre_rot_mat.apply(t_vec) | |||||
| new_trans = self._t + trans_update | |||||
| return Quaternion(new_quats, new_trans) | |||||
| def stop_rot_gradient(self) -> Quaternion: | |||||
| return Quaternion(self._q.detach(), self._t) | |||||
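The _QUAT_TO_ROT table encodes the standard unit-quaternion-to-rotation-matrix map; written out directly for a single quaternion (w, x, y, z), the same mapping is:

    import torch
    import torch.nn.functional as F

    q = F.normalize(torch.randn(4), dim=-1)
    w, x, y, z = q
    rot = torch.stack([
        torch.stack([1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)]),
        torch.stack([2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)]),
        torch.stack([2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)]),
    ])
    # a valid rotation is orthogonal, so R @ R.T is the identity
    assert torch.allclose(rot @ rot.T, torch.eye(3), atol=1e-5)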
| @@ -0,0 +1,592 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| import math | |||||
| from typing import Tuple | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.modules import LayerNorm, softmax_dropout | |||||
| from unicore.utils import dict_multimap, one_hot, permute_final_dims | |||||
| from modelscope.models.science.unifold.data.residue_constants import ( | |||||
| restype_atom14_mask, restype_atom14_rigid_group_positions, | |||||
| restype_atom14_to_rigid_group, restype_rigid_group_default_frame) | |||||
| from .attentions import gen_attn_mask | |||||
| from .common import Linear, SimpleModuleList, residual | |||||
| from .frame import Frame, Quaternion, Rotation | |||||
| def ipa_point_weights_init_(weights): | |||||
| with torch.no_grad(): | |||||
| softplus_inverse_1 = 0.541324854612918 | |||||
| weights.fill_(softplus_inverse_1) | |||||
| def torsion_angles_to_frames( | |||||
| frame: Frame, | |||||
| alpha: torch.Tensor, | |||||
| aatype: torch.Tensor, | |||||
| default_frames: torch.Tensor, | |||||
| ): | |||||
| default_frame = Frame.from_tensor_4x4(default_frames[aatype, ...]) | |||||
| bb_rot = alpha.new_zeros((*((1, ) * len(alpha.shape[:-1])), 2)) | |||||
| bb_rot[..., 1] = 1 | |||||
| alpha = torch.cat([bb_rot.expand(*alpha.shape[:-2], -1, -1), alpha], | |||||
| dim=-2) | |||||
| all_rots = alpha.new_zeros(default_frame.get_rots().rot_mat.shape) | |||||
| all_rots[..., 0, 0] = 1 | |||||
| all_rots[..., 1, 1] = alpha[..., 1] | |||||
| all_rots[..., 1, 2] = -alpha[..., 0] | |||||
| all_rots[..., 2, 1:] = alpha | |||||
| all_rots = Frame(Rotation(mat=all_rots), None) | |||||
| all_frames = default_frame.compose(all_rots) | |||||
| chi2_frame_to_frame = all_frames[..., 5] | |||||
| chi3_frame_to_frame = all_frames[..., 6] | |||||
| chi4_frame_to_frame = all_frames[..., 7] | |||||
| chi1_frame_to_bb = all_frames[..., 4] | |||||
| chi2_frame_to_bb = chi1_frame_to_bb.compose(chi2_frame_to_frame) | |||||
| chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) | |||||
| chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) | |||||
| all_frames_to_bb = Frame.cat( | |||||
| [ | |||||
| all_frames[..., :5], | |||||
| chi2_frame_to_bb.unsqueeze(-1), | |||||
| chi3_frame_to_bb.unsqueeze(-1), | |||||
| chi4_frame_to_bb.unsqueeze(-1), | |||||
| ], | |||||
| dim=-1, | |||||
| ) | |||||
| all_frames_to_global = frame[..., None].compose(all_frames_to_bb) | |||||
| return all_frames_to_global | |||||
| def frames_and_literature_positions_to_atom14_pos( | |||||
| frame: Frame, | |||||
| aatype: torch.Tensor, | |||||
| default_frames, | |||||
| group_idx, | |||||
| atom_mask, | |||||
| lit_positions, | |||||
| ): | |||||
| group_mask = group_idx[aatype, ...] | |||||
| group_mask = one_hot( | |||||
| group_mask, | |||||
| num_classes=default_frames.shape[-3], | |||||
| ) | |||||
| t_atoms_to_global = frame[..., None, :] * group_mask | |||||
| t_atoms_to_global = t_atoms_to_global.map_tensor_fn( | |||||
| lambda x: torch.sum(x, dim=-1)) | |||||
| atom_mask = atom_mask[aatype, ...].unsqueeze(-1) | |||||
| lit_positions = lit_positions[aatype, ...] | |||||
| pred_positions = t_atoms_to_global.apply(lit_positions) | |||||
| pred_positions = pred_positions * atom_mask | |||||
| return pred_positions | |||||
| class SideChainAngleResnetIteration(nn.Module): | |||||
| def __init__(self, d_hid): | |||||
| super(SideChainAngleResnetIteration, self).__init__() | |||||
| self.d_hid = d_hid | |||||
| self.linear_1 = Linear(self.d_hid, self.d_hid, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear_2 = Linear(self.d_hid, self.d_hid, init='final') | |||||
| def forward(self, s: torch.Tensor) -> torch.Tensor: | |||||
| x = self.act(s) | |||||
| x = self.linear_1(x) | |||||
| x = self.act(x) | |||||
| x = self.linear_2(x) | |||||
| return residual(s, x, self.training) | |||||
| class SidechainAngleResnet(nn.Module): | |||||
| def __init__(self, d_in, d_hid, num_blocks, num_angles): | |||||
| super(SidechainAngleResnet, self).__init__() | |||||
| self.linear_in = Linear(d_in, d_hid) | |||||
| self.act = nn.GELU() | |||||
| self.linear_initial = Linear(d_in, d_hid) | |||||
| self.layers = SimpleModuleList() | |||||
| for _ in range(num_blocks): | |||||
| self.layers.append(SideChainAngleResnetIteration(d_hid=d_hid)) | |||||
| self.linear_out = Linear(d_hid, num_angles * 2) | |||||
| def forward(self, s: torch.Tensor, | |||||
| initial_s: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: | |||||
| initial_s = self.linear_initial(self.act(initial_s)) | |||||
| s = self.linear_in(self.act(s)) | |||||
| s = s + initial_s | |||||
| for layer in self.layers: | |||||
| s = layer(s) | |||||
| s = self.linear_out(self.act(s)) | |||||
| s = s.view(s.shape[:-1] + (-1, 2)) | |||||
| unnormalized_s = s | |||||
| norm_denom = torch.sqrt( | |||||
| torch.clamp( | |||||
| torch.sum(s.float()**2, dim=-1, keepdim=True), | |||||
| min=1e-12, | |||||
| )) | |||||
| s = s.float() / norm_denom | |||||
| return unnormalized_s, s.type(unnormalized_s.dtype) | |||||
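The resnet head predicts each torsion angle as an unnormalized (sin, cos) pair that is then projected onto the unit circle, as in the last lines above; a tiny sketch of that final step:

    import torch

    raw = torch.randn(7, 2)                                  # [num_angles, (sin, cos)]
    norm = torch.sqrt(torch.clamp((raw ** 2).sum(-1, keepdim=True), min=1e-12))
    normalized = raw / norm                                  # now on the unit circle
    angles = torch.atan2(normalized[..., 0], normalized[..., 1])  # radians, if needed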
| class InvariantPointAttention(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_single: int, | |||||
| d_pair: int, | |||||
| d_hid: int, | |||||
| num_heads: int, | |||||
| num_qk_points: int, | |||||
| num_v_points: int, | |||||
| separate_kv: bool = False, | |||||
| bias: bool = True, | |||||
| eps: float = 1e-8, | |||||
| ): | |||||
| super(InvariantPointAttention, self).__init__() | |||||
| self.d_hid = d_hid | |||||
| self.num_heads = num_heads | |||||
| self.num_qk_points = num_qk_points | |||||
| self.num_v_points = num_v_points | |||||
| self.eps = eps | |||||
| hc = self.d_hid * self.num_heads | |||||
| self.linear_q = Linear(d_single, hc, bias=bias) | |||||
| self.separate_kv = separate_kv | |||||
| if self.separate_kv: | |||||
| self.linear_k = Linear(d_single, hc, bias=bias) | |||||
| self.linear_v = Linear(d_single, hc, bias=bias) | |||||
| else: | |||||
| self.linear_kv = Linear(d_single, 2 * hc, bias=bias) | |||||
| hpq = self.num_heads * self.num_qk_points * 3 | |||||
| self.linear_q_points = Linear(d_single, hpq) | |||||
| hpk = self.num_heads * self.num_qk_points * 3 | |||||
| hpv = self.num_heads * self.num_v_points * 3 | |||||
| if self.separate_kv: | |||||
| self.linear_k_points = Linear(d_single, hpk) | |||||
| self.linear_v_points = Linear(d_single, hpv) | |||||
| else: | |||||
| hpkv = hpk + hpv | |||||
| self.linear_kv_points = Linear(d_single, hpkv) | |||||
| self.linear_b = Linear(d_pair, self.num_heads) | |||||
| self.head_weights = nn.Parameter(torch.zeros((num_heads))) | |||||
| ipa_point_weights_init_(self.head_weights) | |||||
| concat_out_dim = self.num_heads * ( | |||||
| d_pair + self.d_hid + self.num_v_points * 4) | |||||
| self.linear_out = Linear(concat_out_dim, d_single, init='final') | |||||
| self.softplus = nn.Softplus() | |||||
| def forward( | |||||
| self, | |||||
| s: torch.Tensor, | |||||
| z: torch.Tensor, | |||||
| f: Frame, | |||||
| square_mask: torch.Tensor, | |||||
| ) -> torch.Tensor: | |||||
| q = self.linear_q(s) | |||||
| q = q.view(q.shape[:-1] + (self.num_heads, -1)) | |||||
| if self.separate_kv: | |||||
| k = self.linear_k(s) | |||||
| v = self.linear_v(s) | |||||
| k = k.view(k.shape[:-1] + (self.num_heads, -1)) | |||||
| v = v.view(v.shape[:-1] + (self.num_heads, -1)) | |||||
| else: | |||||
| kv = self.linear_kv(s) | |||||
| kv = kv.view(kv.shape[:-1] + (self.num_heads, -1)) | |||||
| k, v = torch.split(kv, self.d_hid, dim=-1) | |||||
| q_pts = self.linear_q_points(s) | |||||
| def process_points(pts, no_points): | |||||
| shape = pts.shape[:-1] + (pts.shape[-1] // 3, 3) | |||||
| if self.separate_kv: | |||||
| # alphafold-multimer uses a different layout | |||||
| pts = pts.view(pts.shape[:-1] | |||||
| + (self.num_heads, no_points * 3)) | |||||
| pts = torch.split(pts, pts.shape[-1] // 3, dim=-1) | |||||
| pts = torch.stack(pts, dim=-1).view(*shape) | |||||
| pts = f[..., None].apply(pts) | |||||
| pts = pts.view(pts.shape[:-2] + (self.num_heads, no_points, 3)) | |||||
| return pts | |||||
| q_pts = process_points(q_pts, self.num_qk_points) | |||||
| if self.separate_kv: | |||||
| k_pts = self.linear_k_points(s) | |||||
| v_pts = self.linear_v_points(s) | |||||
| k_pts = process_points(k_pts, self.num_qk_points) | |||||
| v_pts = process_points(v_pts, self.num_v_points) | |||||
| else: | |||||
| kv_pts = self.linear_kv_points(s) | |||||
| kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1) | |||||
| kv_pts = torch.stack(kv_pts, dim=-1) | |||||
| kv_pts = f[..., None].apply(kv_pts) | |||||
| kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.num_heads, -1, 3)) | |||||
| k_pts, v_pts = torch.split( | |||||
| kv_pts, [self.num_qk_points, self.num_v_points], dim=-2) | |||||
| bias = self.linear_b(z) | |||||
| attn = torch.matmul( | |||||
| permute_final_dims(q, (1, 0, 2)), | |||||
| permute_final_dims(k, (1, 2, 0)), | |||||
| ) | |||||
| if self.training: | |||||
| attn = attn * math.sqrt(1.0 / (3 * self.d_hid)) | |||||
| attn = attn + ( | |||||
| math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) | |||||
| pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) | |||||
| pt_att = pt_att.float()**2 | |||||
| else: | |||||
| attn *= math.sqrt(1.0 / (3 * self.d_hid)) | |||||
| attn += (math.sqrt(1.0 / 3) * permute_final_dims(bias, (2, 0, 1))) | |||||
| pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) | |||||
| pt_att *= pt_att | |||||
| pt_att = pt_att.sum(dim=-1) | |||||
| head_weights = self.softplus(self.head_weights).view( | |||||
| *((1, ) * len(pt_att.shape[:-2]) + (-1, 1))) | |||||
| head_weights = head_weights * math.sqrt( | |||||
| 1.0 / (3 * (self.num_qk_points * 9.0 / 2))) | |||||
| pt_att *= head_weights * (-0.5) | |||||
| pt_att = torch.sum(pt_att, dim=-1) | |||||
| pt_att = permute_final_dims(pt_att, (2, 0, 1)) | |||||
| attn += square_mask | |||||
| attn = softmax_dropout( | |||||
| attn, 0, self.training, bias=pt_att.type(attn.dtype)) | |||||
| del pt_att, q_pts, k_pts, bias | |||||
| o = torch.matmul(attn, v.transpose(-2, -3)).transpose(-2, -3) | |||||
| o = o.contiguous().view(*o.shape[:-2], -1) | |||||
| del q, k, v | |||||
| o_pts = torch.sum( | |||||
| (attn[..., None, :, :, None] | |||||
| * permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]), | |||||
| dim=-2, | |||||
| ) | |||||
| o_pts = permute_final_dims(o_pts, (2, 0, 3, 1)) | |||||
| o_pts = f[..., None, None].invert_apply(o_pts) | |||||
| if self.training: | |||||
| o_pts_norm = torch.sqrt( | |||||
| torch.sum(o_pts.float()**2, dim=-1) + self.eps).type( | |||||
| o_pts.dtype) | |||||
| else: | |||||
| o_pts_norm = torch.sqrt(torch.sum(o_pts**2, dim=-1) | |||||
| + self.eps).type(o_pts.dtype) | |||||
| o_pts_norm = o_pts_norm.view(*o_pts_norm.shape[:-2], -1) | |||||
| o_pts = o_pts.view(*o_pts.shape[:-3], -1, 3) | |||||
| o_pair = torch.matmul(attn.transpose(-2, -3), z) | |||||
| o_pair = o_pair.view(*o_pair.shape[:-2], -1) | |||||
| s = self.linear_out( | |||||
| torch.cat((o, *torch.unbind(o_pts, dim=-1), o_pts_norm, o_pair), | |||||
| dim=-1)) | |||||
| return s | |||||
| class BackboneUpdate(nn.Module): | |||||
| def __init__(self, d_single): | |||||
| super(BackboneUpdate, self).__init__() | |||||
| self.linear = Linear(d_single, 6, init='final') | |||||
| def forward(self, s: torch.Tensor): | |||||
| return self.linear(s) | |||||
| class StructureModuleTransitionLayer(nn.Module): | |||||
| def __init__(self, c): | |||||
| super(StructureModuleTransitionLayer, self).__init__() | |||||
| self.linear_1 = Linear(c, c, init='relu') | |||||
| self.linear_2 = Linear(c, c, init='relu') | |||||
| self.act = nn.GELU() | |||||
| self.linear_3 = Linear(c, c, init='final') | |||||
| def forward(self, s): | |||||
| s_old = s | |||||
| s = self.linear_1(s) | |||||
| s = self.act(s) | |||||
| s = self.linear_2(s) | |||||
| s = self.act(s) | |||||
| s = self.linear_3(s) | |||||
| s = residual(s_old, s, self.training) | |||||
| return s | |||||
| class StructureModuleTransition(nn.Module): | |||||
| def __init__(self, c, num_layers, dropout_rate): | |||||
| super(StructureModuleTransition, self).__init__() | |||||
| self.num_layers = num_layers | |||||
| self.dropout_rate = dropout_rate | |||||
| self.layers = SimpleModuleList() | |||||
| for _ in range(self.num_layers): | |||||
| self.layers.append(StructureModuleTransitionLayer(c)) | |||||
| self.dropout = nn.Dropout(self.dropout_rate) | |||||
| self.layer_norm = LayerNorm(c) | |||||
| def forward(self, s): | |||||
| for layer in self.layers: | |||||
| s = layer(s) | |||||
| s = self.dropout(s) | |||||
| s = self.layer_norm(s) | |||||
| return s | |||||
| class StructureModule(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_single, | |||||
| d_pair, | |||||
| d_ipa, | |||||
| d_angle, | |||||
| num_heads_ipa, | |||||
| num_qk_points, | |||||
| num_v_points, | |||||
| dropout_rate, | |||||
| num_blocks, | |||||
| no_transition_layers, | |||||
| num_resnet_blocks, | |||||
| num_angles, | |||||
| trans_scale_factor, | |||||
| separate_kv, | |||||
| ipa_bias, | |||||
| epsilon, | |||||
| inf, | |||||
| **kwargs, | |||||
| ): | |||||
| super(StructureModule, self).__init__() | |||||
| self.num_blocks = num_blocks | |||||
| self.trans_scale_factor = trans_scale_factor | |||||
| self.default_frames = None | |||||
| self.group_idx = None | |||||
| self.atom_mask = None | |||||
| self.lit_positions = None | |||||
| self.inf = inf | |||||
| self.layer_norm_s = LayerNorm(d_single) | |||||
| self.layer_norm_z = LayerNorm(d_pair) | |||||
| self.linear_in = Linear(d_single, d_single) | |||||
| self.ipa = InvariantPointAttention( | |||||
| d_single, | |||||
| d_pair, | |||||
| d_ipa, | |||||
| num_heads_ipa, | |||||
| num_qk_points, | |||||
| num_v_points, | |||||
| separate_kv=separate_kv, | |||||
| bias=ipa_bias, | |||||
| eps=epsilon, | |||||
| ) | |||||
| self.ipa_dropout = nn.Dropout(dropout_rate) | |||||
| self.layer_norm_ipa = LayerNorm(d_single) | |||||
| self.transition = StructureModuleTransition( | |||||
| d_single, | |||||
| no_transition_layers, | |||||
| dropout_rate, | |||||
| ) | |||||
| self.bb_update = BackboneUpdate(d_single) | |||||
| self.angle_resnet = SidechainAngleResnet( | |||||
| d_single, | |||||
| d_angle, | |||||
| num_resnet_blocks, | |||||
| num_angles, | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| s, | |||||
| z, | |||||
| aatype, | |||||
| mask=None, | |||||
| ): | |||||
| if mask is None: | |||||
| mask = s.new_ones(s.shape[:-1]) | |||||
| # generate square mask | |||||
| square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2) | |||||
| square_mask = gen_attn_mask(square_mask, -self.inf).unsqueeze(-3) | |||||
| s = self.layer_norm_s(s) | |||||
| z = self.layer_norm_z(z) | |||||
| initial_s = s | |||||
| s = self.linear_in(s) | |||||
| quat_encoder = Quaternion.identity( | |||||
| s.shape[:-1], | |||||
| s.dtype, | |||||
| s.device, | |||||
| requires_grad=False, | |||||
| ) | |||||
| backb_to_global = Frame( | |||||
| Rotation(mat=quat_encoder.get_rot_mats(), ), | |||||
| quat_encoder.get_trans(), | |||||
| ) | |||||
| outputs = [] | |||||
| for i in range(self.num_blocks): | |||||
| s = residual(s, self.ipa(s, z, backb_to_global, square_mask), | |||||
| self.training) | |||||
| s = self.ipa_dropout(s) | |||||
| s = self.layer_norm_ipa(s) | |||||
| s = self.transition(s) | |||||
| # update quaternion encoder | |||||
| # use backb_to_global to avoid quat-to-rot conversion | |||||
| quat_encoder = quat_encoder.compose_update_vec( | |||||
| self.bb_update(s), pre_rot_mat=backb_to_global.get_rots()) | |||||
| # the initial single representation is always fed to the angle resnet | |||||
| unnormalized_angles, angles = self.angle_resnet(s, initial_s) | |||||
| # convert quaternion to rotation matrix | |||||
| backb_to_global = Frame( | |||||
| Rotation(mat=quat_encoder.get_rot_mats(), ), | |||||
| quat_encoder.get_trans(), | |||||
| ) | |||||
| if i == self.num_blocks - 1: | |||||
| all_frames_to_global = self.torsion_angles_to_frames( | |||||
| backb_to_global.scale_translation(self.trans_scale_factor), | |||||
| angles, | |||||
| aatype, | |||||
| ) | |||||
| pred_positions = self.frames_and_literature_positions_to_atom14_pos( | |||||
| all_frames_to_global, | |||||
| aatype, | |||||
| ) | |||||
| preds = { | |||||
| 'frames': | |||||
| backb_to_global.scale_translation( | |||||
| self.trans_scale_factor).to_tensor_4x4(), | |||||
| 'unnormalized_angles': | |||||
| unnormalized_angles, | |||||
| 'angles': | |||||
| angles, | |||||
| } | |||||
| outputs.append(preds) | |||||
| if i < (self.num_blocks - 1): | |||||
| # stop gradient in iteration | |||||
| quat_encoder = quat_encoder.stop_rot_gradient() | |||||
| backb_to_global = backb_to_global.stop_rot_gradient() | |||||
| outputs = dict_multimap(torch.stack, outputs) | |||||
| outputs['sidechain_frames'] = all_frames_to_global.to_tensor_4x4() | |||||
| outputs['positions'] = pred_positions | |||||
| outputs['single'] = s | |||||
| return outputs | |||||
| def _init_residue_constants(self, float_dtype, device): | |||||
| if self.default_frames is None: | |||||
| self.default_frames = torch.tensor( | |||||
| restype_rigid_group_default_frame, | |||||
| dtype=float_dtype, | |||||
| device=device, | |||||
| requires_grad=False, | |||||
| ) | |||||
| if self.group_idx is None: | |||||
| self.group_idx = torch.tensor( | |||||
| restype_atom14_to_rigid_group, | |||||
| device=device, | |||||
| requires_grad=False, | |||||
| ) | |||||
| if self.atom_mask is None: | |||||
| self.atom_mask = torch.tensor( | |||||
| restype_atom14_mask, | |||||
| dtype=float_dtype, | |||||
| device=device, | |||||
| requires_grad=False, | |||||
| ) | |||||
| if self.lit_positions is None: | |||||
| self.lit_positions = torch.tensor( | |||||
| restype_atom14_rigid_group_positions, | |||||
| dtype=float_dtype, | |||||
| device=device, | |||||
| requires_grad=False, | |||||
| ) | |||||
| def torsion_angles_to_frames(self, frame, alpha, aatype): | |||||
| self._init_residue_constants(alpha.dtype, alpha.device) | |||||
| return torsion_angles_to_frames(frame, alpha, aatype, | |||||
| self.default_frames) | |||||
| def frames_and_literature_positions_to_atom14_pos(self, frame, aatype): | |||||
| self._init_residue_constants(frame.get_rots().dtype, | |||||
| frame.get_rots().device) | |||||
| return frames_and_literature_positions_to_atom14_pos( | |||||
| frame, | |||||
| aatype, | |||||
| self.default_frames, | |||||
| self.group_idx, | |||||
| self.atom_mask, | |||||
| self.lit_positions, | |||||
| ) | |||||
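| # Minimal usage sketch for StructureModule. The hyper-parameters below are | |||||
| # typical AlphaFold-style monomer values, not values taken from this change; | |||||
| # adjust them to the actual config. Shapes follow the forward() above. | |||||
| # | |||||
| # sm = StructureModule( | |||||
| #     d_single=384, d_pair=128, d_ipa=16, d_angle=128, num_heads_ipa=12, | |||||
| #     num_qk_points=4, num_v_points=8, dropout_rate=0.1, num_blocks=8, | |||||
| #     no_transition_layers=1, num_resnet_blocks=2, num_angles=7, | |||||
| #     trans_scale_factor=10.0, separate_kv=False, ipa_bias=True, | |||||
| #     epsilon=1e-8, inf=1e5) | |||||
| # n = 64 | |||||
| # out = sm( | |||||
| #     s=torch.randn(1, n, 384),         # single representation | |||||
| #     z=torch.randn(1, n, n, 128),      # pair representation | |||||
| #     aatype=torch.zeros(1, n, dtype=torch.long)) | |||||
| # out keys: 'frames', 'unnormalized_angles', 'angles', 'sidechain_frames', | |||||
| #           'positions', 'single' | |||||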
| @@ -0,0 +1,330 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| import math | |||||
| from functools import partial | |||||
| from typing import List, Optional, Tuple | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.modules import LayerNorm | |||||
| from unicore.utils import (checkpoint_sequential, permute_final_dims, | |||||
| tensor_tree_map) | |||||
| from .attentions import (Attention, TriangleAttentionEnding, | |||||
| TriangleAttentionStarting, gen_attn_mask) | |||||
| from .common import (Linear, SimpleModuleList, Transition, | |||||
| bias_dropout_residual, chunk_layer, residual, | |||||
| tri_mul_residual) | |||||
| from .featurization import build_template_pair_feat_v2 | |||||
| from .triangle_multiplication import (TriangleMultiplicationIncoming, | |||||
| TriangleMultiplicationOutgoing) | |||||
| class TemplatePointwiseAttention(nn.Module): | |||||
| def __init__(self, d_template, d_pair, d_hid, num_heads, inf, **kwargs): | |||||
| super(TemplatePointwiseAttention, self).__init__() | |||||
| self.inf = inf | |||||
| self.mha = Attention( | |||||
| d_pair, | |||||
| d_template, | |||||
| d_template, | |||||
| d_hid, | |||||
| num_heads, | |||||
| gating=False, | |||||
| ) | |||||
| def _chunk( | |||||
| self, | |||||
| z: torch.Tensor, | |||||
| t: torch.Tensor, | |||||
| mask: torch.Tensor, | |||||
| chunk_size: int, | |||||
| ) -> torch.Tensor: | |||||
| mha_inputs = { | |||||
| 'q': z, | |||||
| 'k': t, | |||||
| 'v': t, | |||||
| 'mask': mask, | |||||
| } | |||||
| return chunk_layer( | |||||
| self.mha, | |||||
| mha_inputs, | |||||
| chunk_size=chunk_size, | |||||
| num_batch_dims=len(z.shape[:-2]), | |||||
| ) | |||||
| def forward( | |||||
| self, | |||||
| t: torch.Tensor, | |||||
| z: torch.Tensor, | |||||
| template_mask: Optional[torch.Tensor] = None, | |||||
| chunk_size: Optional[int] = None, | |||||
| ) -> torch.Tensor: | |||||
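| # Each pair position attends, as a single query, over the stack of | |||||
| # template embeddings (templates form the key/value dimension), pooling | |||||
| # the per-template activations into one update of the pair representation. | |||||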
| if template_mask is None: | |||||
| template_mask = t.new_ones(t.shape[:-3]) | |||||
| mask = gen_attn_mask(template_mask, -self.inf)[..., None, None, None, | |||||
| None, :] | |||||
| z = z.unsqueeze(-2) | |||||
| t = permute_final_dims(t, (1, 2, 0, 3)) | |||||
| if chunk_size is not None: | |||||
| z = self._chunk(z, t, mask, chunk_size) | |||||
| else: | |||||
| z = self.mha(z, t, t, mask=mask) | |||||
| z = z.squeeze(-2) | |||||
| return z | |||||
| class TemplateProjection(nn.Module): | |||||
| def __init__(self, d_template, d_pair, **kwargs): | |||||
| super(TemplateProjection, self).__init__() | |||||
| self.d_pair = d_pair | |||||
| self.d_template = d_template | |||||
| self.act = nn.ReLU() | |||||
| self.output_linear = Linear(d_template, d_pair, init='relu') | |||||
| def forward(self, t, z) -> torch.Tensor: | |||||
| if t is None: | |||||
| # handle the non-template case with an all-zero template embedding; | |||||
| # torch.Size is immutable, and the channel dim must be d_template to | |||||
| # match the input of output_linear below | |||||
| shape = z.shape[:-1] + (self.d_template, ) | |||||
| t = torch.zeros(shape, dtype=z.dtype, device=z.device) | |||||
| t = self.act(t) | |||||
| z_t = self.output_linear(t) | |||||
| return z_t | |||||
| class TemplatePairStackBlock(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_template: int, | |||||
| d_hid_tri_att: int, | |||||
| d_hid_tri_mul: int, | |||||
| num_heads: int, | |||||
| pair_transition_n: int, | |||||
| dropout_rate: float, | |||||
| tri_attn_first: bool, | |||||
| inf: float, | |||||
| **kwargs, | |||||
| ): | |||||
| super(TemplatePairStackBlock, self).__init__() | |||||
| self.tri_att_start = TriangleAttentionStarting( | |||||
| d_template, | |||||
| d_hid_tri_att, | |||||
| num_heads, | |||||
| ) | |||||
| self.tri_att_end = TriangleAttentionEnding( | |||||
| d_template, | |||||
| d_hid_tri_att, | |||||
| num_heads, | |||||
| ) | |||||
| self.tri_mul_out = TriangleMultiplicationOutgoing( | |||||
| d_template, | |||||
| d_hid_tri_mul, | |||||
| ) | |||||
| self.tri_mul_in = TriangleMultiplicationIncoming( | |||||
| d_template, | |||||
| d_hid_tri_mul, | |||||
| ) | |||||
| self.pair_transition = Transition( | |||||
| d_template, | |||||
| pair_transition_n, | |||||
| ) | |||||
| self.tri_attn_first = tri_attn_first | |||||
| self.dropout = dropout_rate | |||||
| self.row_dropout_share_dim = -3 | |||||
| self.col_dropout_share_dim = -2 | |||||
| def forward( | |||||
| self, | |||||
| s: torch.Tensor, | |||||
| mask: torch.Tensor, | |||||
| tri_start_attn_mask: torch.Tensor, | |||||
| tri_end_attn_mask: torch.Tensor, | |||||
| chunk_size: Optional[int] = None, | |||||
| block_size: Optional[int] = None, | |||||
| ): | |||||
| if self.tri_attn_first: | |||||
| s = bias_dropout_residual( | |||||
| self.tri_att_start, | |||||
| s, | |||||
| self.tri_att_start( | |||||
| s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| ) | |||||
| s = bias_dropout_residual( | |||||
| self.tri_att_end, | |||||
| s, | |||||
| self.tri_att_end( | |||||
| s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), | |||||
| self.col_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| ) | |||||
| s = tri_mul_residual( | |||||
| self.tri_mul_out, | |||||
| s, | |||||
| self.tri_mul_out(s, mask=mask, block_size=block_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| block_size=block_size, | |||||
| ) | |||||
| s = tri_mul_residual( | |||||
| self.tri_mul_in, | |||||
| s, | |||||
| self.tri_mul_in(s, mask=mask, block_size=block_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| block_size=block_size, | |||||
| ) | |||||
| else: | |||||
| s = tri_mul_residual( | |||||
| self.tri_mul_out, | |||||
| s, | |||||
| self.tri_mul_out(s, mask=mask, block_size=block_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| block_size=block_size, | |||||
| ) | |||||
| s = tri_mul_residual( | |||||
| self.tri_mul_in, | |||||
| s, | |||||
| self.tri_mul_in(s, mask=mask, block_size=block_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| block_size=block_size, | |||||
| ) | |||||
| s = bias_dropout_residual( | |||||
| self.tri_att_start, | |||||
| s, | |||||
| self.tri_att_start( | |||||
| s, attn_mask=tri_start_attn_mask, chunk_size=chunk_size), | |||||
| self.row_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| ) | |||||
| s = bias_dropout_residual( | |||||
| self.tri_att_end, | |||||
| s, | |||||
| self.tri_att_end( | |||||
| s, attn_mask=tri_end_attn_mask, chunk_size=chunk_size), | |||||
| self.col_dropout_share_dim, | |||||
| self.dropout, | |||||
| self.training, | |||||
| ) | |||||
| s = residual(s, self.pair_transition( | |||||
| s, | |||||
| chunk_size=chunk_size, | |||||
| ), self.training) | |||||
| return s | |||||
| class TemplatePairStack(nn.Module): | |||||
| def __init__( | |||||
| self, | |||||
| d_template, | |||||
| d_hid_tri_att, | |||||
| d_hid_tri_mul, | |||||
| num_blocks, | |||||
| num_heads, | |||||
| pair_transition_n, | |||||
| dropout_rate, | |||||
| tri_attn_first, | |||||
| inf=1e9, | |||||
| **kwargs, | |||||
| ): | |||||
| super(TemplatePairStack, self).__init__() | |||||
| self.blocks = SimpleModuleList() | |||||
| for _ in range(num_blocks): | |||||
| self.blocks.append( | |||||
| TemplatePairStackBlock( | |||||
| d_template=d_template, | |||||
| d_hid_tri_att=d_hid_tri_att, | |||||
| d_hid_tri_mul=d_hid_tri_mul, | |||||
| num_heads=num_heads, | |||||
| pair_transition_n=pair_transition_n, | |||||
| dropout_rate=dropout_rate, | |||||
| inf=inf, | |||||
| tri_attn_first=tri_attn_first, | |||||
| )) | |||||
| self.layer_norm = LayerNorm(d_template) | |||||
| def forward( | |||||
| self, | |||||
| single_templates: Tuple[torch.Tensor], | |||||
| mask: torch.Tensor, | |||||
| tri_start_attn_mask: torch.Tensor, | |||||
| tri_end_attn_mask: torch.Tensor, | |||||
| templ_dim: int, | |||||
| chunk_size: int, | |||||
| block_size: int, | |||||
| return_mean: bool, | |||||
| ): | |||||
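| # return_mean=True averages the processed template embeddings over all | |||||
| # templates into a single tensor; return_mean=False stacks them along | |||||
| # templ_dim, keeping one entry per template. | |||||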
| def one_template(i): | |||||
| (s, ) = checkpoint_sequential( | |||||
| functions=[ | |||||
| partial( | |||||
| b, | |||||
| mask=mask, | |||||
| tri_start_attn_mask=tri_start_attn_mask, | |||||
| tri_end_attn_mask=tri_end_attn_mask, | |||||
| chunk_size=chunk_size, | |||||
| block_size=block_size, | |||||
| ) for b in self.blocks | |||||
| ], | |||||
| input=(single_templates[i], ), | |||||
| ) | |||||
| return s | |||||
| n_templ = len(single_templates) | |||||
| if n_templ > 0: | |||||
| new_single_templates = [one_template(0)] | |||||
| if return_mean: | |||||
| t = self.layer_norm(new_single_templates[0]) | |||||
| for i in range(1, n_templ): | |||||
| s = one_template(i) | |||||
| if return_mean: | |||||
| t = residual(t, self.layer_norm(s), self.training) | |||||
| else: | |||||
| new_single_templates.append(s) | |||||
| if return_mean: | |||||
| if n_templ > 0: | |||||
| t /= n_templ | |||||
| else: | |||||
| t = None | |||||
| else: | |||||
| t = torch.cat( | |||||
| [s.unsqueeze(templ_dim) for s in new_single_templates], | |||||
| dim=templ_dim) | |||||
| t = self.layer_norm(t) | |||||
| return t | |||||
| @@ -0,0 +1,158 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| from functools import partialmethod | |||||
| from typing import List, Optional | |||||
| import torch | |||||
| import torch.nn as nn | |||||
| from unicore.modules import LayerNorm | |||||
| from unicore.utils import permute_final_dims | |||||
| from .common import Linear | |||||
| class TriangleMultiplication(nn.Module): | |||||
| def __init__(self, d_pair, d_hid, outgoing=True): | |||||
| super(TriangleMultiplication, self).__init__() | |||||
| self.outgoing = outgoing | |||||
| self.linear_ab_p = Linear(d_pair, d_hid * 2) | |||||
| self.linear_ab_g = Linear(d_pair, d_hid * 2, init='gating') | |||||
| self.linear_g = Linear(d_pair, d_pair, init='gating') | |||||
| self.linear_z = Linear(d_hid, d_pair, init='final') | |||||
| self.layer_norm_in = LayerNorm(d_pair) | |||||
| self.layer_norm_out = LayerNorm(d_hid) | |||||
| self._alphafold_original_mode = False | |||||
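| # With outgoing=True this block computes the "outgoing edges" update | |||||
| # z'_ij = sum_k a_ik * b_jk; with outgoing=False it computes the | |||||
| # "incoming edges" update z'_ij = sum_k a_ki * b_kj, where a and b are | |||||
| # the two masked, gated projections of the layer-normed pair activations. | |||||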
| def _chunk_2d( | |||||
| self, | |||||
| z: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| block_size: int = None, | |||||
| ) -> torch.Tensor: | |||||
| # avoid too small chunk size | |||||
| # block_size = max(block_size, 256) | |||||
| new_z = z.new_zeros(z.shape) | |||||
| dim1 = z.shape[-3] | |||||
| def _slice_linear(z, linear: Linear, a=True): | |||||
| d_hid = linear.bias.shape[0] // 2 | |||||
| index = 0 if a else d_hid | |||||
| p = ( | |||||
| nn.functional.linear(z, linear.weight[index:index + d_hid]) | |||||
| + linear.bias[index:index + d_hid]) | |||||
| return p | |||||
| def _chunk_projection(z, mask, a=True): | |||||
| p = _slice_linear(z, self.linear_ab_p, a) * mask | |||||
| p *= torch.sigmoid(_slice_linear(z, self.linear_ab_g, a)) | |||||
| return p | |||||
| num_chunk = (dim1 + block_size - 1) // block_size | |||||
| for i in range(num_chunk): | |||||
| chunk_start = i * block_size | |||||
| chunk_end = min(chunk_start + block_size, dim1) | |||||
| if self.outgoing: | |||||
| a_chunk = _chunk_projection( | |||||
| z[..., chunk_start:chunk_end, :, :], | |||||
| mask[..., chunk_start:chunk_end, :, :], | |||||
| a=True, | |||||
| ) | |||||
| a_chunk = permute_final_dims(a_chunk, (2, 0, 1)) | |||||
| else: | |||||
| a_chunk = _chunk_projection( | |||||
| z[..., :, chunk_start:chunk_end, :], | |||||
| mask[..., :, chunk_start:chunk_end, :], | |||||
| a=True, | |||||
| ) | |||||
| a_chunk = a_chunk.transpose(-1, -3) | |||||
| for j in range(num_chunk): | |||||
| j_chunk_start = j * block_size | |||||
| j_chunk_end = min(j_chunk_start + block_size, dim1) | |||||
| if self.outgoing: | |||||
| b_chunk = _chunk_projection( | |||||
| z[..., j_chunk_start:j_chunk_end, :, :], | |||||
| mask[..., j_chunk_start:j_chunk_end, :, :], | |||||
| a=False, | |||||
| ) | |||||
| b_chunk = b_chunk.transpose(-1, -3) | |||||
| else: | |||||
| b_chunk = _chunk_projection( | |||||
| z[..., :, j_chunk_start:j_chunk_end, :], | |||||
| mask[..., :, j_chunk_start:j_chunk_end, :], | |||||
| a=False, | |||||
| ) | |||||
| b_chunk = permute_final_dims(b_chunk, (2, 0, 1)) | |||||
| x_chunk = torch.matmul(a_chunk, b_chunk) | |||||
| del b_chunk | |||||
| x_chunk = permute_final_dims(x_chunk, (1, 2, 0)) | |||||
| x_chunk = self.layer_norm_out(x_chunk) | |||||
| x_chunk = self.linear_z(x_chunk) | |||||
| x_chunk *= torch.sigmoid( | |||||
| self.linear_g(z[..., chunk_start:chunk_end, | |||||
| j_chunk_start:j_chunk_end, :])) | |||||
| new_z[..., chunk_start:chunk_end, | |||||
| j_chunk_start:j_chunk_end, :] = x_chunk | |||||
| del x_chunk | |||||
| del a_chunk | |||||
| return new_z | |||||
| def forward( | |||||
| self, | |||||
| z: torch.Tensor, | |||||
| mask: Optional[torch.Tensor] = None, | |||||
| block_size=None, | |||||
| ) -> torch.Tensor: | |||||
| mask = mask.unsqueeze(-1) | |||||
| if not self._alphafold_original_mode: | |||||
| # scale the mask by 1/sqrt(dim) for numerical stability | |||||
| mask = mask * (mask.shape[-2]**-0.5) | |||||
| z = self.layer_norm_in(z) | |||||
| if not self.training and block_size is not None: | |||||
| return self._chunk_2d(z, mask, block_size=block_size) | |||||
| g = nn.functional.linear(z, self.linear_g.weight) | |||||
| if self.training: | |||||
| ab = self.linear_ab_p(z) * mask * torch.sigmoid( | |||||
| self.linear_ab_g(z)) | |||||
| else: | |||||
| ab = self.linear_ab_p(z) | |||||
| ab *= mask | |||||
| ab *= torch.sigmoid(self.linear_ab_g(z)) | |||||
| a, b = torch.chunk(ab, 2, dim=-1) | |||||
| del z, ab | |||||
| if self.outgoing: | |||||
| a = permute_final_dims(a, (2, 0, 1)) | |||||
| b = b.transpose(-1, -3) | |||||
| else: | |||||
| b = permute_final_dims(b, (2, 0, 1)) | |||||
| a = a.transpose(-1, -3) | |||||
| x = torch.matmul(a, b) | |||||
| del a, b | |||||
| x = permute_final_dims(x, (1, 2, 0)) | |||||
| x = self.layer_norm_out(x) | |||||
| x = nn.functional.linear(x, self.linear_z.weight) | |||||
| return x, g | |||||
| def get_output_bias(self): | |||||
| return self.linear_z.bias, self.linear_g.bias | |||||
| class TriangleMultiplicationOutgoing(TriangleMultiplication): | |||||
| __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=True) | |||||
| class TriangleMultiplicationIncoming(TriangleMultiplication): | |||||
| __init__ = partialmethod(TriangleMultiplication.__init__, outgoing=False) | |||||
| @@ -0,0 +1 @@ | |||||
| """ Scripts for MSA & template searching. """ | |||||
| @@ -0,0 +1,483 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Parses the mmCIF file format.""" | |||||
| import collections | |||||
| import dataclasses | |||||
| import functools | |||||
| import io | |||||
| from typing import Any, Mapping, Optional, Sequence, Tuple | |||||
| from absl import logging | |||||
| from Bio import PDB | |||||
| from Bio.Data import SCOPData | |||||
| from Bio.PDB.MMCIF2Dict import MMCIF2Dict | |||||
| from Bio.PDB.MMCIFParser import MMCIFParser | |||||
| # Type aliases: | |||||
| ChainId = str | |||||
| PdbHeader = Mapping[str, Any] | |||||
| PdbStructure = PDB.Structure.Structure | |||||
| SeqRes = str | |||||
| MmCIFDict = Mapping[str, Sequence[str]] | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class Monomer: | |||||
| id: str | |||||
| num: int | |||||
| # Note - mmCIF format provides no guarantees on the type of author-assigned | |||||
| # sequence numbers. They need not be integers. | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class AtomSite: | |||||
| residue_name: str | |||||
| author_chain_id: str | |||||
| mmcif_chain_id: str | |||||
| author_seq_num: str | |||||
| mmcif_seq_num: int | |||||
| insertion_code: str | |||||
| hetatm_atom: str | |||||
| model_num: int | |||||
| # Used to map SEQRES index to a residue in the structure. | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class ResiduePosition: | |||||
| chain_id: str | |||||
| residue_number: int | |||||
| insertion_code: str | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class ResidueAtPosition: | |||||
| position: Optional[ResiduePosition] | |||||
| name: str | |||||
| is_missing: bool | |||||
| hetflag: str | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class MmcifObject: | |||||
| """Representation of a parsed mmCIF file. | |||||
| Contains: | |||||
| file_id: A meaningful name, e.g. a pdb_id. Should be unique amongst all | |||||
| files being processed. | |||||
| header: Biopython header. | |||||
| structure: Biopython structure. | |||||
| chain_to_seqres: Dict mapping chain_id to 1 letter amino acid sequence. E.g. | |||||
| {'A': 'ABCDEFG'} | |||||
| seqres_to_structure: Dict; for each chain_id contains a mapping between | |||||
| SEQRES index and a ResidueAtPosition. e.g. {'A': {0: ResidueAtPosition, 1: ResidueAtPosition, ...}} | |||||
| raw_string: The raw string used to construct the MmcifObject. | |||||
| """ | |||||
| file_id: str | |||||
| header: PdbHeader | |||||
| structure: PdbStructure | |||||
| chain_to_seqres: Mapping[ChainId, SeqRes] | |||||
| seqres_to_structure: Mapping[ChainId, Mapping[int, ResidueAtPosition]] | |||||
| raw_string: Any | |||||
| mmcif_to_author_chain_id: Mapping[ChainId, ChainId] | |||||
| valid_chains: Mapping[ChainId, str] | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class ParsingResult: | |||||
| """Returned by the parse function. | |||||
| Contains: | |||||
| mmcif_object: A MmcifObject, may be None if no chain could be successfully | |||||
| parsed. | |||||
| errors: A dict mapping (file_id, chain_id) to any exception generated. | |||||
| """ | |||||
| mmcif_object: Optional[MmcifObject] | |||||
| errors: Mapping[Tuple[str, str], Any] | |||||
| class ParseError(Exception): | |||||
| """An error indicating that an mmCIF file could not be parsed.""" | |||||
| def mmcif_loop_to_list(prefix: str, | |||||
| parsed_info: MmCIFDict) -> Sequence[Mapping[str, str]]: | |||||
| """Extracts loop associated with a prefix from mmCIF data as a list. | |||||
| Reference for loop_ in mmCIF: | |||||
| http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html | |||||
| Args: | |||||
| prefix: Prefix shared by each of the data items in the loop. | |||||
| e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, | |||||
| _entity_poly_seq.mon_id. Should include the trailing period. | |||||
| parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython | |||||
| parser. | |||||
| Returns: | |||||
| Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. | |||||
| """ | |||||
| cols = [] | |||||
| data = [] | |||||
| for key, value in parsed_info.items(): | |||||
| if key.startswith(prefix): | |||||
| cols.append(key) | |||||
| data.append(value) | |||||
| assert all([ | |||||
| len(xs) == len(data[0]) for xs in data | |||||
| ]), ('mmCIF error: Not all loops are the same length: %s' % cols) | |||||
| return [dict(zip(cols, xs)) for xs in zip(*data)] | |||||
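| # Illustrative example (hypothetical values): given parsed_info with | |||||
| #   '_entity_poly_seq.num': ['1', '2'] and | |||||
| #   '_entity_poly_seq.mon_id': ['MET', 'ALA'], | |||||
| # mmcif_loop_to_list('_entity_poly_seq.', parsed_info) returns | |||||
| #   [{'_entity_poly_seq.num': '1', '_entity_poly_seq.mon_id': 'MET'}, | |||||
| #    {'_entity_poly_seq.num': '2', '_entity_poly_seq.mon_id': 'ALA'}] | |||||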
| def mmcif_loop_to_dict( | |||||
| prefix: str, | |||||
| index: str, | |||||
| parsed_info: MmCIFDict, | |||||
| ) -> Mapping[str, Mapping[str, str]]: | |||||
| """Extracts loop associated with a prefix from mmCIF data as a dictionary. | |||||
| Args: | |||||
| prefix: Prefix shared by each of the data items in the loop. | |||||
| e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, | |||||
| _entity_poly_seq.mon_id. Should include the trailing period. | |||||
| index: Which item of loop data should serve as the key. | |||||
| parsed_info: A dict of parsed mmCIF data, e.g. _mmcif_dict from a Biopython | |||||
| parser. | |||||
| Returns: | |||||
| Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, | |||||
| indexed by the index column. | |||||
| """ | |||||
| entries = mmcif_loop_to_list(prefix, parsed_info) | |||||
| return {entry[index]: entry for entry in entries} | |||||
| @functools.lru_cache(16, typed=False) | |||||
| def fast_parse(*, | |||||
| file_id: str, | |||||
| mmcif_string: str, | |||||
| catch_all_errors: bool = True) -> ParsingResult: | |||||
| """Entry point, parses an mmcif_string. | |||||
| Args: | |||||
| file_id: A string identifier for this file. Should be unique within the | |||||
| collection of files being processed. | |||||
| mmcif_string: Contents of an mmCIF file. | |||||
| catch_all_errors: If True, all exceptions are caught and error messages are | |||||
| returned as part of the ParsingResult. If False exceptions will be allowed | |||||
| to propagate. | |||||
| Returns: | |||||
| A ParsingResult. | |||||
| """ | |||||
| errors = {} | |||||
| try: | |||||
| # Only extract the raw mmCIF key/value dictionary; skip building the | |||||
| # (slow) full Biopython structure. | |||||
| parsed_info = MMCIF2Dict(io.StringIO(mmcif_string)) | |||||
| # Ensure all values are lists, even if singletons. | |||||
| for key, value in parsed_info.items(): | |||||
| if not isinstance(value, list): | |||||
| parsed_info[key] = [value] | |||||
| header = _get_header(parsed_info) | |||||
| # Determine the protein chains, and their start numbers according to the | |||||
| # internal mmCIF numbering scheme (likely but not guaranteed to be 1). | |||||
| valid_chains = _get_protein_chains(parsed_info=parsed_info) | |||||
| if not valid_chains: | |||||
| return ParsingResult( | |||||
| None, {(file_id, ''): 'No protein chains found in this file.'}) | |||||
| mmcif_to_author_chain_id = {} | |||||
| # seq_to_structure_mappings = {} | |||||
| for atom in _get_atom_site_list(parsed_info): | |||||
| if atom.model_num != '1': | |||||
| # We only process the first model at the moment. | |||||
| continue | |||||
| mmcif_to_author_chain_id[ | |||||
| atom.mmcif_chain_id] = atom.author_chain_id | |||||
| mmcif_object = MmcifObject( | |||||
| file_id=file_id, | |||||
| header=header, | |||||
| structure=None, | |||||
| chain_to_seqres=None, | |||||
| seqres_to_structure=None, | |||||
| raw_string=parsed_info, | |||||
| mmcif_to_author_chain_id=mmcif_to_author_chain_id, | |||||
| valid_chains=valid_chains, | |||||
| ) | |||||
| return ParsingResult(mmcif_object=mmcif_object, errors=errors) | |||||
| except Exception as e: # pylint:disable=broad-except | |||||
| errors[(file_id, '')] = e | |||||
| if not catch_all_errors: | |||||
| raise | |||||
| return ParsingResult(mmcif_object=None, errors=errors) | |||||
| @functools.lru_cache(16, typed=False) | |||||
| def parse(*, | |||||
| file_id: str, | |||||
| mmcif_string: str, | |||||
| catch_all_errors: bool = True) -> ParsingResult: | |||||
| """Entry point, parses an mmcif_string. | |||||
| Args: | |||||
| file_id: A string identifier for this file. Should be unique within the | |||||
| collection of files being processed. | |||||
| mmcif_string: Contents of an mmCIF file. | |||||
| catch_all_errors: If True, all exceptions are caught and error messages are | |||||
| returned as part of the ParsingResult. If False exceptions will be allowed | |||||
| to propagate. | |||||
| Returns: | |||||
| A ParsingResult. | |||||
| """ | |||||
| errors = {} | |||||
| try: | |||||
| parser = PDB.MMCIFParser(QUIET=True) | |||||
| handle = io.StringIO(mmcif_string) | |||||
| full_structure = parser.get_structure('', handle) | |||||
| first_model_structure = _get_first_model(full_structure) | |||||
| # Extract the _mmcif_dict from the parser, which contains useful fields not | |||||
| # reflected in the Biopython structure. | |||||
| parsed_info = parser._mmcif_dict # pylint:disable=protected-access | |||||
| # Ensure all values are lists, even if singletons. | |||||
| for key, value in parsed_info.items(): | |||||
| if not isinstance(value, list): | |||||
| parsed_info[key] = [value] | |||||
| header = _get_header(parsed_info) | |||||
| # Determine the protein chains, and their start numbers according to the | |||||
| # internal mmCIF numbering scheme (likely but not guaranteed to be 1). | |||||
| valid_chains = _get_protein_chains(parsed_info=parsed_info) | |||||
| if not valid_chains: | |||||
| return ParsingResult( | |||||
| None, {(file_id, ''): 'No protein chains found in this file.'}) | |||||
| seq_start_num = { | |||||
| chain_id: min([monomer.num for monomer in seq]) | |||||
| for chain_id, seq in valid_chains.items() | |||||
| } | |||||
| # Loop over the atoms for which we have coordinates. Populate two mappings: | |||||
| # -mmcif_to_author_chain_id (maps internal mmCIF chain ids to chain ids used | |||||
| # the authors / Biopython). | |||||
| # -seq_to_structure_mappings (maps idx into sequence to ResidueAtPosition). | |||||
| mmcif_to_author_chain_id = {} | |||||
| seq_to_structure_mappings = {} | |||||
| for atom in _get_atom_site_list(parsed_info): | |||||
| if atom.model_num != '1': | |||||
| # We only process the first model at the moment. | |||||
| continue | |||||
| mmcif_to_author_chain_id[ | |||||
| atom.mmcif_chain_id] = atom.author_chain_id | |||||
| if atom.mmcif_chain_id in valid_chains: | |||||
| hetflag = ' ' | |||||
| if atom.hetatm_atom == 'HETATM': | |||||
| # Water atoms are assigned a special hetflag of W in Biopython. We | |||||
| # need to do the same, so that this hetflag can be used to fetch | |||||
| # a residue from the Biopython structure by id. | |||||
| if atom.residue_name in ('HOH', 'WAT'): | |||||
| hetflag = 'W' | |||||
| else: | |||||
| hetflag = 'H_' + atom.residue_name | |||||
| insertion_code = atom.insertion_code | |||||
| if not _is_set(atom.insertion_code): | |||||
| insertion_code = ' ' | |||||
| position = ResiduePosition( | |||||
| chain_id=atom.author_chain_id, | |||||
| residue_number=int(atom.author_seq_num), | |||||
| insertion_code=insertion_code, | |||||
| ) | |||||
| seq_idx = int( | |||||
| atom.mmcif_seq_num) - seq_start_num[atom.mmcif_chain_id] | |||||
| current = seq_to_structure_mappings.get( | |||||
| atom.author_chain_id, {}) | |||||
| current[seq_idx] = ResidueAtPosition( | |||||
| position=position, | |||||
| name=atom.residue_name, | |||||
| is_missing=False, | |||||
| hetflag=hetflag, | |||||
| ) | |||||
| seq_to_structure_mappings[atom.author_chain_id] = current | |||||
| # Add missing residue information to seq_to_structure_mappings. | |||||
| for chain_id, seq_info in valid_chains.items(): | |||||
| author_chain = mmcif_to_author_chain_id[chain_id] | |||||
| current_mapping = seq_to_structure_mappings[author_chain] | |||||
| for idx, monomer in enumerate(seq_info): | |||||
| if idx not in current_mapping: | |||||
| current_mapping[idx] = ResidueAtPosition( | |||||
| position=None, | |||||
| name=monomer.id, | |||||
| is_missing=True, | |||||
| hetflag=' ') | |||||
| author_chain_to_sequence = {} | |||||
| for chain_id, seq_info in valid_chains.items(): | |||||
| author_chain = mmcif_to_author_chain_id[chain_id] | |||||
| seq = [] | |||||
| for monomer in seq_info: | |||||
| code = SCOPData.protein_letters_3to1.get(monomer.id, 'X') | |||||
| seq.append(code if len(code) == 1 else 'X') | |||||
| seq = ''.join(seq) | |||||
| author_chain_to_sequence[author_chain] = seq | |||||
| mmcif_object = MmcifObject( | |||||
| file_id=file_id, | |||||
| header=header, | |||||
| structure=first_model_structure, | |||||
| chain_to_seqres=author_chain_to_sequence, | |||||
| seqres_to_structure=seq_to_structure_mappings, | |||||
| raw_string=parsed_info, | |||||
| mmcif_to_author_chain_id=mmcif_to_author_chain_id, | |||||
| valid_chains=valid_chains, | |||||
| ) | |||||
| return ParsingResult(mmcif_object=mmcif_object, errors=errors) | |||||
| except Exception as e: # pylint:disable=broad-except | |||||
| errors[(file_id, '')] = e | |||||
| if not catch_all_errors: | |||||
| raise | |||||
| return ParsingResult(mmcif_object=None, errors=errors) | |||||
| def _get_first_model(structure: PdbStructure) -> PdbStructure: | |||||
| """Returns the first model in a Biopython structure.""" | |||||
| return next(structure.get_models()) | |||||
| _MIN_LENGTH_OF_CHAIN_TO_BE_COUNTED_AS_PEPTIDE = 21 | |||||
| def get_release_date(parsed_info: MmCIFDict) -> str: | |||||
| """Returns the oldest revision date.""" | |||||
| revision_dates = parsed_info['_pdbx_audit_revision_history.revision_date'] | |||||
| return min(revision_dates) | |||||
| def _get_header(parsed_info: MmCIFDict) -> PdbHeader: | |||||
| """Returns a basic header containing method, release date and resolution.""" | |||||
| header = {} | |||||
| experiments = mmcif_loop_to_list('_exptl.', parsed_info) | |||||
| header['structure_method'] = ','.join( | |||||
| [experiment['_exptl.method'].lower() for experiment in experiments]) | |||||
| # Note: The release_date here corresponds to the oldest revision. We prefer to | |||||
| # use this for dataset filtering over the deposition_date. | |||||
| if '_pdbx_audit_revision_history.revision_date' in parsed_info: | |||||
| header['release_date'] = get_release_date(parsed_info) | |||||
| else: | |||||
| logging.warning('Could not determine release_date: %s', | |||||
| parsed_info['_entry.id']) | |||||
| header['resolution'] = 0.00 | |||||
| for res_key in ( | |||||
| '_refine.ls_d_res_high', | |||||
| '_em_3d_reconstruction.resolution', | |||||
| '_reflns.d_resolution_high', | |||||
| ): | |||||
| if res_key in parsed_info: | |||||
| try: | |||||
| raw_resolution = parsed_info[res_key][0] | |||||
| header['resolution'] = float(raw_resolution) | |||||
| except ValueError: | |||||
| logging.debug('Invalid resolution format: %s', | |||||
| parsed_info[res_key]) | |||||
| return header | |||||
| def _get_atom_site_list(parsed_info: MmCIFDict) -> Sequence[AtomSite]: | |||||
| """Returns list of atom sites; contains data not present in the structure.""" | |||||
| return [ | |||||
| AtomSite(*site) for site in zip( # pylint:disable=g-complex-comprehension | |||||
| parsed_info['_atom_site.label_comp_id'], | |||||
| parsed_info['_atom_site.auth_asym_id'], | |||||
| parsed_info['_atom_site.label_asym_id'], | |||||
| parsed_info['_atom_site.auth_seq_id'], | |||||
| parsed_info['_atom_site.label_seq_id'], | |||||
| parsed_info['_atom_site.pdbx_PDB_ins_code'], | |||||
| parsed_info['_atom_site.group_PDB'], | |||||
| parsed_info['_atom_site.pdbx_PDB_model_num'], | |||||
| ) | |||||
| ] | |||||
| def _get_protein_chains( | |||||
| *, parsed_info: Mapping[str, | |||||
| Any]) -> Mapping[ChainId, Sequence[Monomer]]: | |||||
| """Extracts polymer information for protein chains only. | |||||
| Args: | |||||
| parsed_info: _mmcif_dict produced by the Biopython parser. | |||||
| Returns: | |||||
| A dict mapping mmcif chain id to a list of Monomers. | |||||
| """ | |||||
| # Get polymer information for each entity in the structure. | |||||
| entity_poly_seqs = mmcif_loop_to_list('_entity_poly_seq.', parsed_info) | |||||
| polymers = collections.defaultdict(list) | |||||
| for entity_poly_seq in entity_poly_seqs: | |||||
| polymers[entity_poly_seq['_entity_poly_seq.entity_id']].append( | |||||
| Monomer( | |||||
| id=entity_poly_seq['_entity_poly_seq.mon_id'], | |||||
| num=int(entity_poly_seq['_entity_poly_seq.num']), | |||||
| )) | |||||
| # Get chemical compositions. Will allow us to identify which of these polymers | |||||
| # are proteins. | |||||
| chem_comps = mmcif_loop_to_dict('_chem_comp.', '_chem_comp.id', | |||||
| parsed_info) | |||||
| # Get chains information for each entity. Necessary so that we can return a | |||||
| # dict keyed on chain id rather than entity. | |||||
| struct_asyms = mmcif_loop_to_list('_struct_asym.', parsed_info) | |||||
| entity_to_mmcif_chains = collections.defaultdict(list) | |||||
| for struct_asym in struct_asyms: | |||||
| chain_id = struct_asym['_struct_asym.id'] | |||||
| entity_id = struct_asym['_struct_asym.entity_id'] | |||||
| entity_to_mmcif_chains[entity_id].append(chain_id) | |||||
| # Identify and return the valid protein chains. | |||||
| valid_chains = {} | |||||
| for entity_id, seq_info in polymers.items(): | |||||
| chain_ids = entity_to_mmcif_chains[entity_id] | |||||
| # Reject polymers without any peptide-like components, such as DNA/RNA. | |||||
| if any([ | |||||
| 'peptide' in chem_comps[monomer.id]['_chem_comp.type'] | |||||
| for monomer in seq_info | |||||
| ]): | |||||
| for chain_id in chain_ids: | |||||
| valid_chains[chain_id] = seq_info | |||||
| return valid_chains | |||||
| def _is_set(data: str) -> bool: | |||||
| """Returns False if data is a special mmCIF character indicating 'unset'.""" | |||||
| return data not in ('.', '?') | |||||
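| # Minimal usage sketch (hypothetical file name): | |||||
| # | |||||
| # with open('1abc.cif') as f: | |||||
| #     result = parse(file_id='1abc', mmcif_string=f.read()) | |||||
| # if result.mmcif_object is not None: | |||||
| #     print(result.mmcif_object.chain_to_seqres)  # e.g. {'A': 'MKV...'} | |||||
| # else: | |||||
| #     print(result.errors) | |||||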
| @@ -0,0 +1,88 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Utilities for extracting identifiers from MSA sequence descriptions.""" | |||||
| import dataclasses | |||||
| import re | |||||
| from typing import Optional | |||||
| # Sequences coming from UniProtKB database come in the | |||||
| # `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` | |||||
| # or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). | |||||
| _UNIPROT_PATTERN = re.compile( | |||||
| r""" | |||||
| ^ | |||||
| # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot | |||||
| (?:tr|sp) | |||||
| \| | |||||
| # A primary accession number of the UniProtKB entry. | |||||
| (?P<AccessionIdentifier>[A-Za-z0-9]{6,10}) | |||||
| # Occasionally there is a _0 or _1 isoform suffix, which we ignore. | |||||
| (?:_\d)? | |||||
| \| | |||||
| # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic | |||||
| # protein ID code. | |||||
| (?:[A-Za-z0-9]+) | |||||
| _ | |||||
| # A mnemonic species identification code. | |||||
| (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5}) | |||||
| # Small BFD uses a final value after an underscore, which we ignore. | |||||
| (?:_\d+)? | |||||
| $ | |||||
| """, | |||||
| re.VERBOSE, | |||||
| ) | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class Identifiers: | |||||
| species_id: str = '' | |||||
| def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: | |||||
| """Gets accession id and species from an msa sequence identifier. | |||||
| The sequence identifier has the format specified by | |||||
| _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. | |||||
| An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` | |||||
| Args: | |||||
| msa_sequence_identifier: a sequence identifier. | |||||
| Returns: | |||||
| An `Identifiers` instance with a species_id. These | |||||
| can be empty in the case where no identifier was found. | |||||
| """ | |||||
| matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) | |||||
| if matches: | |||||
| return Identifiers(species_id=matches.group('SpeciesIdentifier')) | |||||
| return Identifiers() | |||||
| def _extract_sequence_identifier(description: str) -> Optional[str]: | |||||
| """Extracts sequence identifier from description. Returns None if no match.""" | |||||
| split_description = description.split() | |||||
| if split_description: | |||||
| return split_description[0].partition('/')[0] | |||||
| else: | |||||
| return None | |||||
| def get_identifiers(description: str) -> Identifiers: | |||||
| """Computes extra MSA features from the description.""" | |||||
| sequence_identifier = _extract_sequence_identifier(description) | |||||
| if sequence_identifier is None: | |||||
| return Identifiers() | |||||
| else: | |||||
| return _parse_sequence_identifier(sequence_identifier) | |||||
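| # Illustrative example: | |||||
| #   get_identifiers('tr|A0A146SKV9|A0A146SKV9_FUNHE/11-74') returns | |||||
| #   Identifiers(species_id='FUNHE'); descriptions that do not match the | |||||
| #   UniProt pattern yield Identifiers(species_id=''). | |||||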
| @@ -0,0 +1,627 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Functions for parsing various file formats.""" | |||||
| import collections | |||||
| import dataclasses | |||||
| import itertools | |||||
| import re | |||||
| import string | |||||
| from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple | |||||
| DeletionMatrix = Sequence[Sequence[int]] | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class Msa: | |||||
| """Class representing a parsed MSA file.""" | |||||
| sequences: Sequence[str] | |||||
| deletion_matrix: DeletionMatrix | |||||
| descriptions: Sequence[str] | |||||
| def __post_init__(self): | |||||
| if not (len(self.sequences) == len(self.deletion_matrix) == len( | |||||
| self.descriptions)): | |||||
| raise ValueError( | |||||
| 'All fields for an MSA must have the same length. ' | |||||
| f'Got {len(self.sequences)} sequences, ' | |||||
| f'{len(self.deletion_matrix)} rows in the deletion matrix and ' | |||||
| f'{len(self.descriptions)} descriptions.') | |||||
| def __len__(self): | |||||
| return len(self.sequences) | |||||
| def truncate(self, max_seqs: int): | |||||
| return Msa( | |||||
| sequences=self.sequences[:max_seqs], | |||||
| deletion_matrix=self.deletion_matrix[:max_seqs], | |||||
| descriptions=self.descriptions[:max_seqs], | |||||
| ) | |||||
| @dataclasses.dataclass(frozen=True) | |||||
| class TemplateHit: | |||||
| """Class representing a template hit.""" | |||||
| index: int | |||||
| name: str | |||||
| aligned_cols: int | |||||
| sum_probs: Optional[float] | |||||
| query: str | |||||
| hit_sequence: str | |||||
| indices_query: List[int] | |||||
| indices_hit: List[int] | |||||
| def parse_fasta(fasta_string: str) -> Tuple[Sequence[str], Sequence[str]]: | |||||
| """Parses FASTA string and returns list of strings with amino-acid sequences. | |||||
| Arguments: | |||||
| fasta_string: The string contents of a FASTA file. | |||||
| Returns: | |||||
| A tuple of two lists: | |||||
| * A list of sequences. | |||||
| * A list of sequence descriptions taken from the comment lines. In the | |||||
| same order as the sequences. | |||||
| """ | |||||
| sequences = [] | |||||
| descriptions = [] | |||||
| index = -1 | |||||
| for line in fasta_string.splitlines(): | |||||
| line = line.strip() | |||||
| if line.startswith('>'): | |||||
| index += 1 | |||||
| descriptions.append(line[1:]) # Remove the '>' at the beginning. | |||||
| sequences.append('') | |||||
| continue | |||||
| elif not line: | |||||
| continue # Skip blank lines. | |||||
| sequences[index] += line | |||||
| return sequences, descriptions | |||||
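| # Illustrative example (hypothetical input): | |||||
| #   parse_fasta('>seq1 first\nMKV\nLL\n>seq2\nGHK\n') | |||||
| # returns (['MKVLL', 'GHK'], ['seq1 first', 'seq2']). | |||||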
| def parse_stockholm(stockholm_string: str) -> Msa: | |||||
| """Parses sequences and deletion matrix from stockholm format alignment. | |||||
| Args: | |||||
| stockholm_string: The string contents of a stockholm file. The first | |||||
| sequence in the file should be the query sequence. | |||||
| Returns: | |||||
| An Msa object containing: | |||||
| * A list of sequences that have been aligned to the query. These | |||||
| might contain duplicates. | |||||
| * The deletion matrix for the alignment as a list of lists. The element | |||||
| at `deletion_matrix[i][j]` is the number of residues deleted from | |||||
| the aligned sequence i at residue position j. | |||||
| * The names of the targets matched, including the jackhmmer subsequence | |||||
| suffix. | |||||
| """ | |||||
| name_to_sequence = collections.OrderedDict() | |||||
| for line in stockholm_string.splitlines(): | |||||
| line = line.strip() | |||||
| if not line or line.startswith(('#', '//')): | |||||
| continue | |||||
| name, sequence = line.split() | |||||
| if name not in name_to_sequence: | |||||
| name_to_sequence[name] = '' | |||||
| name_to_sequence[name] += sequence | |||||
| msa = [] | |||||
| deletion_matrix = [] | |||||
| query = '' | |||||
| keep_columns = [] | |||||
| for seq_index, sequence in enumerate(name_to_sequence.values()): | |||||
| if seq_index == 0: | |||||
| # Gather the columns with gaps from the query | |||||
| query = sequence | |||||
| keep_columns = [i for i, res in enumerate(query) if res != '-'] | |||||
| # Remove the columns with gaps in the query from all sequences. | |||||
| aligned_sequence = ''.join([sequence[c] for c in keep_columns]) | |||||
| msa.append(aligned_sequence) | |||||
| # Count the number of deletions w.r.t. query. | |||||
| deletion_vec = [] | |||||
| deletion_count = 0 | |||||
| for seq_res, query_res in zip(sequence, query): | |||||
| if seq_res != '-' or query_res != '-': | |||||
| if query_res == '-': | |||||
| deletion_count += 1 | |||||
| else: | |||||
| deletion_vec.append(deletion_count) | |||||
| deletion_count = 0 | |||||
| deletion_matrix.append(deletion_vec) | |||||
| return Msa( | |||||
| sequences=msa, | |||||
| deletion_matrix=deletion_matrix, | |||||
| descriptions=list(name_to_sequence.keys()), | |||||
| ) | |||||
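| # Illustrative example (hypothetical two-row alignment, markup omitted): | |||||
| #   query MK-V | |||||
| #   hit1  MKLV | |||||
| # The query-gap column is dropped, giving sequences ['MKV', 'MKV'], | |||||
| # deletion_matrix [[0, 0, 0], [0, 0, 1]] (one hit1 residue deleted before | |||||
| # its final V) and descriptions ['query', 'hit1']. | |||||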
| def parse_a3m(a3m_string: str) -> Msa: | |||||
| """Parses sequences and deletion matrix from a3m format alignment. | |||||
| Args: | |||||
| a3m_string: The string contents of a a3m file. The first sequence in the | |||||
| file should be the query sequence. | |||||
| Returns: | |||||
| An Msa object containing: | |||||
| * A list of sequences that have been aligned to the query. These | |||||
| might contain duplicates. | |||||
| * The deletion matrix for the alignment as a list of lists. The element | |||||
| at `deletion_matrix[i][j]` is the number of residues deleted from | |||||
| the aligned sequence i at residue position j. | |||||
| * A list of descriptions, one per sequence, from the a3m file. | |||||
| """ | |||||
| sequences, descriptions = parse_fasta(a3m_string) | |||||
| deletion_matrix = [] | |||||
| for msa_sequence in sequences: | |||||
| deletion_vec = [] | |||||
| deletion_count = 0 | |||||
| for j in msa_sequence: | |||||
| if j.islower(): | |||||
| deletion_count += 1 | |||||
| else: | |||||
| deletion_vec.append(deletion_count) | |||||
| deletion_count = 0 | |||||
| deletion_matrix.append(deletion_vec) | |||||
| # Make the MSA matrix out of aligned (deletion-free) sequences. | |||||
| deletion_table = str.maketrans('', '', string.ascii_lowercase) | |||||
| aligned_sequences = [s.translate(deletion_table) for s in sequences] | |||||
| return Msa( | |||||
| sequences=aligned_sequences, | |||||
| deletion_matrix=deletion_matrix, | |||||
| descriptions=descriptions, | |||||
| ) | |||||
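| # Illustrative example (hypothetical input): | |||||
| #   parse_a3m('>q\nMKV\n>h\nMKlV\n') | |||||
| # returns Msa(sequences=['MKV', 'MKV'], | |||||
| #             deletion_matrix=[[0, 0, 0], [0, 0, 1]], | |||||
| #             descriptions=['q', 'h']); the lowercase 'l' is an insertion | |||||
| #             in h and is stripped from the aligned sequence. | |||||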
| def _convert_sto_seq_to_a3m(query_non_gaps: Sequence[bool], | |||||
| sto_seq: str) -> Iterable[str]: | |||||
| for is_query_res_non_gap, sequence_res in zip(query_non_gaps, sto_seq): | |||||
| if is_query_res_non_gap: | |||||
| yield sequence_res | |||||
| elif sequence_res != '-': | |||||
| yield sequence_res.lower() | |||||
| def convert_stockholm_to_a3m( | |||||
| stockholm_format: str, | |||||
| max_sequences: Optional[int] = None, | |||||
| remove_first_row_gaps: bool = True, | |||||
| ) -> str: | |||||
| """Converts MSA in Stockholm format to the A3M format.""" | |||||
| descriptions = {} | |||||
| sequences = {} | |||||
| reached_max_sequences = False | |||||
| for line in stockholm_format.splitlines(): | |||||
| reached_max_sequences = max_sequences and len( | |||||
| sequences) >= max_sequences | |||||
| if line.strip() and not line.startswith(('#', '//')): | |||||
| # Ignore blank lines, markup and end symbols - remainder are alignment | |||||
| # sequence parts. | |||||
| seqname, aligned_seq = line.split(maxsplit=1) | |||||
| if seqname not in sequences: | |||||
| if reached_max_sequences: | |||||
| continue | |||||
| sequences[seqname] = '' | |||||
| sequences[seqname] += aligned_seq | |||||
| for line in stockholm_format.splitlines(): | |||||
| if line[:4] == '#=GS': | |||||
| # Description row - example format is: | |||||
| # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... | |||||
| columns = line.split(maxsplit=3) | |||||
| seqname, feature = columns[1:3] | |||||
| value = columns[3] if len(columns) == 4 else '' | |||||
| if feature != 'DE': | |||||
| continue | |||||
| if reached_max_sequences and seqname not in sequences: | |||||
| continue | |||||
| descriptions[seqname] = value | |||||
| if len(descriptions) == len(sequences): | |||||
| break | |||||
| # Convert sto format to a3m line by line | |||||
| a3m_sequences = {} | |||||
| if remove_first_row_gaps: | |||||
| # query_sequence is assumed to be the first sequence | |||||
| query_sequence = next(iter(sequences.values())) | |||||
| query_non_gaps = [res != '-' for res in query_sequence] | |||||
| for seqname, sto_sequence in sequences.items(): | |||||
| # Dots are optional in a3m format and are commonly removed. | |||||
| out_sequence = sto_sequence.replace('.', '') | |||||
| if remove_first_row_gaps: | |||||
| out_sequence = ''.join( | |||||
| _convert_sto_seq_to_a3m(query_non_gaps, out_sequence)) | |||||
| a3m_sequences[seqname] = out_sequence | |||||
| fasta_chunks = (f">{k} {descriptions.get(k, '')}\n{a3m_sequences[k]}" | |||||
| for k in a3m_sequences) | |||||
| return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. | |||||
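| # Illustrative example (hypothetical alignment): for rows | |||||
| #   query MK-V | |||||
| #   hit1  MKLV | |||||
| # the query-gap column is dropped from the query and lowercased in the | |||||
| # hit, producing: | |||||
| #   >query | |||||
| #   MKV | |||||
| #   >hit1 | |||||
| #   MKlV | |||||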
| def _keep_line(line: str, seqnames: Set[str]) -> bool: | |||||
| """Function to decide which lines to keep.""" | |||||
| if not line.strip(): | |||||
| return True | |||||
| if line.strip() == '//': # End tag | |||||
| return True | |||||
| if line.startswith('# STOCKHOLM'): # Start tag | |||||
| return True | |||||
| if line.startswith('#=GC RF'): # Reference Annotation Line | |||||
| return True | |||||
| if line[:4] == '#=GS': # Description lines - keep if sequence in list. | |||||
| _, seqname, _ = line.split(maxsplit=2) | |||||
| return seqname in seqnames | |||||
| elif line.startswith('#'): # Other markup - filter out | |||||
| return False | |||||
| else: # Alignment data - keep if sequence in list. | |||||
| seqname = line.partition(' ')[0] | |||||
| return seqname in seqnames | |||||
| def truncate_stockholm_msa(stockholm_msa: str, max_sequences: int) -> str: | |||||
| """Truncates a stockholm file to a maximum number of sequences.""" | |||||
| seqnames = set() | |||||
| filtered_lines = [] | |||||
| for line in stockholm_msa.splitlines(): | |||||
| if line.strip() and not line.startswith(('#', '//')): | |||||
| # Ignore blank lines, markup and end symbols - remainder are alignment | |||||
| # sequence parts. | |||||
| seqname = line.partition(' ')[0] | |||||
| seqnames.add(seqname) | |||||
| if len(seqnames) >= max_sequences: | |||||
| break | |||||
| for line in stockholm_msa.splitlines(): | |||||
| if _keep_line(line, seqnames): | |||||
| filtered_lines.append(line) | |||||
| return '\n'.join(filtered_lines) + '\n' | |||||
| def remove_empty_columns_from_stockholm_msa(stockholm_msa: str) -> str: | |||||
| """Removes empty columns (dashes-only) from a Stockholm MSA.""" | |||||
| processed_lines = {} | |||||
| unprocessed_lines = {} | |||||
| for i, line in enumerate(stockholm_msa.splitlines()): | |||||
| if line.startswith('#=GC RF'): | |||||
| reference_annotation_i = i | |||||
| reference_annotation_line = line | |||||
| # Reached the end of this chunk of the alignment. Process chunk. | |||||
| _, _, first_alignment = line.rpartition(' ') | |||||
| mask = [] | |||||
| for j in range(len(first_alignment)): | |||||
| for _, unprocessed_line in unprocessed_lines.items(): | |||||
| prefix, _, alignment = unprocessed_line.rpartition(' ') | |||||
| if alignment[j] != '-': | |||||
| mask.append(True) | |||||
| break | |||||
| else: # Every row contained a hyphen - empty column. | |||||
| mask.append(False) | |||||
| # Add reference annotation for processing with mask. | |||||
| unprocessed_lines[ | |||||
| reference_annotation_i] = reference_annotation_line | |||||
| if not any( | |||||
| mask | |||||
| ): # All columns were empty. Output empty lines for chunk. | |||||
| for line_index in unprocessed_lines: | |||||
| processed_lines[line_index] = '' | |||||
| else: | |||||
| for line_index, unprocessed_line in unprocessed_lines.items(): | |||||
| prefix, _, alignment = unprocessed_line.rpartition(' ') | |||||
| masked_alignment = ''.join( | |||||
| itertools.compress(alignment, mask)) | |||||
| processed_lines[ | |||||
| line_index] = f'{prefix} {masked_alignment}' | |||||
| # Clear raw_alignments. | |||||
| unprocessed_lines = {} | |||||
| elif line.strip() and not line.startswith(('#', '//')): | |||||
| unprocessed_lines[i] = line | |||||
| else: | |||||
| processed_lines[i] = line | |||||
| return '\n'.join((processed_lines[i] for i in range(len(processed_lines)))) | |||||
| def deduplicate_stockholm_msa(stockholm_msa: str) -> str: | |||||
| """Remove duplicate sequences (ignoring insertions wrt query).""" | |||||
| sequence_dict = collections.defaultdict(str) | |||||
| # First we must extract all sequences from the MSA. | |||||
| for line in stockholm_msa.splitlines(): | |||||
| # Only consider the alignments - ignore reference annotation, empty lines, | |||||
| # descriptions or markup. | |||||
| if line.strip() and not line.startswith(('#', '//')): | |||||
| line = line.strip() | |||||
| seqname, alignment = line.split() | |||||
| sequence_dict[seqname] += alignment | |||||
| seen_sequences = set() | |||||
| seqnames = set() | |||||
| # First alignment is the query. | |||||
| query_align = next(iter(sequence_dict.values())) | |||||
| mask = [c != '-' for c in query_align] # Mask is False for insertions. | |||||
| for seqname, alignment in sequence_dict.items(): | |||||
| # Apply mask to remove all insertions from the string. | |||||
| masked_alignment = ''.join(itertools.compress(alignment, mask)) | |||||
| if masked_alignment in seen_sequences: | |||||
| continue | |||||
| else: | |||||
| seen_sequences.add(masked_alignment) | |||||
| seqnames.add(seqname) | |||||
| filtered_lines = [] | |||||
| for line in stockholm_msa.splitlines(): | |||||
| if _keep_line(line, seqnames): | |||||
| filtered_lines.append(line) | |||||
| return '\n'.join(filtered_lines) + '\n' | |||||
| def _get_hhr_line_regex_groups(regex_pattern: str, | |||||
| line: str) -> Sequence[Optional[str]]: | |||||
| match = re.match(regex_pattern, line) | |||||
| if match is None: | |||||
| raise RuntimeError(f'Could not parse query line {line}') | |||||
| return match.groups() | |||||
| def _update_hhr_residue_indices_list(sequence: str, start_index: int, | |||||
| indices_list: List[int]): | |||||
| """Computes the relative indices for each residue with respect to the original sequence.""" | |||||
| counter = start_index | |||||
| for symbol in sequence: | |||||
| if symbol == '-': | |||||
| indices_list.append(-1) | |||||
| else: | |||||
| indices_list.append(counter) | |||||
| counter += 1 | |||||
| def _parse_hhr_hit(detailed_lines: Sequence[str]) -> TemplateHit: | |||||
| """Parses the detailed HMM HMM comparison section for a single Hit. | |||||
| This works on .hhr files generated from both HHBlits and HHSearch. | |||||
| Args: | |||||
| detailed_lines: A list of lines from a single comparison section between 2 | |||||
| sequences (which each have their own HMMs). | |||||
| Returns: | |||||
| A dictionary with the information from that detailed comparison section | |||||
| Raises: | |||||
| RuntimeError: If a certain line cannot be processed | |||||
| """ | |||||
| # Parse first 2 lines. | |||||
| number_of_hit = int(detailed_lines[0].split()[-1]) | |||||
| name_hit = detailed_lines[1][1:] | |||||
| # Parse the summary line. | |||||
| pattern = ( | |||||
| 'Probab=(.*)[\t ]*E-value=(.*)[\t ]*Score=(.*)[\t ]*Aligned_cols=(.*)[\t' | |||||
| ' ]*Identities=(.*)%[\t ]*Similarity=(.*)[\t ]*Sum_probs=(.*)[\t ' | |||||
| ']*Template_Neff=(.*)') | |||||
| match = re.match(pattern, detailed_lines[2]) | |||||
| if match is None: | |||||
| raise RuntimeError( | |||||
| 'Could not parse section: %s. Expected this: \n%s to contain summary.' | |||||
| % (detailed_lines, detailed_lines[2])) | |||||
| (_, _, _, aligned_cols, _, _, sum_probs, | |||||
| _) = [float(x) for x in match.groups()] | |||||
| # The next section reads the detailed comparisons. These are in a 'human | |||||
| # readable' format which has a fixed length. The strategy employed is to | |||||
| # assume that each block starts with the query sequence line, and to parse | |||||
| # that with a regexp in order to deduce the fixed length used for that block. | |||||
| query = '' | |||||
| hit_sequence = '' | |||||
| indices_query = [] | |||||
| indices_hit = [] | |||||
| length_block = None | |||||
| for line in detailed_lines[3:]: | |||||
| # Parse the query sequence line | |||||
| if (line.startswith('Q ') and not line.startswith('Q ss_dssp') | |||||
| and not line.startswith('Q ss_pred') | |||||
| and not line.startswith('Q Consensus')): | |||||
| # Thus the first 17 characters must be 'Q <query_name> ', and we can parse | |||||
| # everything after that. | |||||
| # start sequence end total_sequence_length | |||||
| patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*([0-9]*) \([0-9]*\)' | |||||
| groups = _get_hhr_line_regex_groups(patt, line[17:]) | |||||
| # Get the length of the parsed block using the start and finish indices, | |||||
| # and ensure it is the same as the actual block length. | |||||
| start = int(groups[0]) - 1 # Make index zero based. | |||||
| delta_query = groups[1] | |||||
| end = int(groups[2]) | |||||
| num_insertions = len([x for x in delta_query if x == '-']) | |||||
| length_block = end - start + num_insertions | |||||
| assert length_block == len(delta_query) | |||||
| # Update the query sequence and indices list. | |||||
| query += delta_query | |||||
| _update_hhr_residue_indices_list(delta_query, start, indices_query) | |||||
| elif line.startswith('T '): | |||||
| # Parse the hit sequence. | |||||
| if (not line.startswith('T ss_dssp') | |||||
| and not line.startswith('T ss_pred') | |||||
| and not line.startswith('T Consensus')): | |||||
| # Thus the first 17 characters must be 'T <hit_name> ', and we can | |||||
| # parse everything after that. | |||||
| # start sequence end total_sequence_length | |||||
| patt = r'[\t ]*([0-9]*) ([A-Z-]*)[\t ]*[0-9]* \([0-9]*\)' | |||||
| groups = _get_hhr_line_regex_groups(patt, line[17:]) | |||||
| start = int(groups[0]) - 1 # Make index zero based. | |||||
| delta_hit_sequence = groups[1] | |||||
| assert length_block == len(delta_hit_sequence) | |||||
| # Update the hit sequence and indices list. | |||||
| hit_sequence += delta_hit_sequence | |||||
| _update_hhr_residue_indices_list(delta_hit_sequence, start, | |||||
| indices_hit) | |||||
| return TemplateHit( | |||||
| index=number_of_hit, | |||||
| name=name_hit, | |||||
| aligned_cols=int(aligned_cols), | |||||
| sum_probs=sum_probs, | |||||
| query=query, | |||||
| hit_sequence=hit_sequence, | |||||
| indices_query=indices_query, | |||||
| indices_hit=indices_hit, | |||||
| ) | |||||
| def parse_hhr(hhr_string: str) -> Sequence[TemplateHit]: | |||||
| """Parses the content of an entire HHR file.""" | |||||
| lines = hhr_string.splitlines() | |||||
| # Each .hhr file starts with a results table, then has a sequence of hit | |||||
| # "paragraphs", each paragraph starting with a line 'No <hit number>'. We | |||||
| # iterate through each paragraph to parse each hit. | |||||
| block_starts = [ | |||||
| i for i, line in enumerate(lines) if line.startswith('No ') | |||||
| ] | |||||
| hits = [] | |||||
| if block_starts: | |||||
| block_starts.append(len(lines)) # Add the end of the final block. | |||||
| for i in range(len(block_starts) - 1): | |||||
| hits.append( | |||||
| _parse_hhr_hit(lines[block_starts[i]:block_starts[i + 1]])) | |||||
| return hits | |||||
| def parse_e_values_from_tblout(tblout: str) -> Dict[str, float]: | |||||
| """Parse target to e-value mapping parsed from Jackhmmer tblout string.""" | |||||
| e_values = {'query': 0} | |||||
| lines = [line for line in tblout.splitlines() if line[0] != '#'] | |||||
| # As per http://eddylab.org/software/hmmer/Userguide.pdf fields are | |||||
| # space-delimited. Relevant fields are (1) target name: and | |||||
| # (5) E-value (full sequence) (numbering from 1). | |||||
| for line in lines: | |||||
| fields = line.split() | |||||
| e_value = fields[4] | |||||
| target_name = fields[0] | |||||
| e_values[target_name] = float(e_value) | |||||
| return e_values | |||||
| def _get_indices(sequence: str, start: int) -> List[int]: | |||||
| """Returns indices for non-gap/insert residues starting at the given index.""" | |||||
| indices = [] | |||||
| counter = start | |||||
| for symbol in sequence: | |||||
| # Skip gaps but add a placeholder so that the alignment is preserved. | |||||
| if symbol == '-': | |||||
| indices.append(-1) | |||||
| # Skip deleted residues, but increase the counter. | |||||
| elif symbol.islower(): | |||||
| counter += 1 | |||||
| # Normal aligned residue. Increase the counter and append to indices. | |||||
| else: | |||||
| indices.append(counter) | |||||
| counter += 1 | |||||
| return indices | |||||
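| # Review note (editorial addition, not part of the original file): a worked | |||||
| # example of the index bookkeeping above. _get_indices('AB-cD', start=0) | |||||
| # returns [0, 1, -1, 3]: the gap keeps a -1 placeholder, and the lowercase | |||||
| # insertion 'c' only advances the counter, so 'D' maps to index 3. | |||||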
| @dataclasses.dataclass(frozen=True) | |||||
| class HitMetadata: | |||||
| pdb_id: str | |||||
| chain: str | |||||
| start: int | |||||
| end: int | |||||
| length: int | |||||
| text: str | |||||
| def _parse_hmmsearch_description(description: str) -> HitMetadata: | |||||
| """Parses the hmmsearch A3M sequence description line.""" | |||||
| # Example 1: >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text | |||||
| # Example 2: >5g3r_A/1-55 [subseq from] mol:protein length:352 | |||||
| match = re.match( | |||||
| r'^>?([a-z0-9]+)_(\w+)/([0-9]+)-([0-9]+).*protein length:([0-9]+) *(.*)$', | |||||
| description.strip(), | |||||
| ) | |||||
| if not match: | |||||
| raise ValueError(f'Could not parse description: "{description}".') | |||||
| return HitMetadata( | |||||
| pdb_id=match[1], | |||||
| chain=match[2], | |||||
| start=int(match[3]), | |||||
| end=int(match[4]), | |||||
| length=int(match[5]), | |||||
| text=match[6], | |||||
| ) | |||||
| def parse_hmmsearch_a3m(query_sequence: str, | |||||
| a3m_string: str, | |||||
| skip_first: bool = True) -> Sequence[TemplateHit]: | |||||
| """Parses an a3m string produced by hmmsearch. | |||||
| Args: | |||||
| query_sequence: The query sequence. | |||||
| a3m_string: The a3m string produced by hmmsearch. | |||||
| skip_first: Whether to skip the first sequence in the a3m string. | |||||
| Returns: | |||||
| A sequence of `TemplateHit` results. | |||||
| """ | |||||
| # Zip the descriptions and MSAs together, skip the first query sequence. | |||||
| parsed_a3m = list(zip(*parse_fasta(a3m_string))) | |||||
| if skip_first: | |||||
| parsed_a3m = parsed_a3m[1:] | |||||
| indices_query = _get_indices(query_sequence, start=0) | |||||
| hits = [] | |||||
| for i, (hit_sequence, hit_description) in enumerate(parsed_a3m, start=1): | |||||
| if 'mol:protein' not in hit_description: | |||||
| continue # Skip non-protein chains. | |||||
| metadata = _parse_hmmsearch_description(hit_description) | |||||
| # Aligned columns are only the match states. | |||||
| aligned_cols = sum([r.isupper() and r != '-' for r in hit_sequence]) | |||||
| indices_hit = _get_indices(hit_sequence, start=metadata.start - 1) | |||||
| hit = TemplateHit( | |||||
| index=i, | |||||
| name=f'{metadata.pdb_id}_{metadata.chain}', | |||||
| aligned_cols=aligned_cols, | |||||
| sum_probs=None, | |||||
| query=query_sequence, | |||||
| hit_sequence=hit_sequence.upper(), | |||||
| indices_query=indices_query, | |||||
| indices_hit=indices_hit, | |||||
| ) | |||||
| hits.append(hit) | |||||
| return hits | |||||
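| # Review note: a small editorial sketch (not part of the original file) showing | |||||
| # how parse_hmmsearch_a3m turns hmmsearch output into TemplateHit objects; the | |||||
| # query and the single hit below are made up. | |||||
| def _example_parse_hmmsearch_a3m() -> Sequence[TemplateHit]: | |||||
|     query = 'MKVLA' | |||||
|     a3m = '>1abc_A/10-14 [subseq from] mol:protein length:120\nMKV-A\n' | |||||
|     # Yields one TemplateHit named '1abc_A' with aligned_cols=4 and | |||||
|     # indices_hit=[9, 10, 11, -1, 12] (zero-based; -1 marks the gap). | |||||
|     return parse_hmmsearch_a3m(query, a3m, skip_first=False) | |||||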
| @@ -0,0 +1,282 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Functions for building the input features for the unifold model.""" | |||||
| import os | |||||
| from typing import Any, Mapping, MutableMapping, Optional, Sequence, Union | |||||
| import numpy as np | |||||
| from absl import logging | |||||
| from modelscope.models.science.unifold.data import residue_constants | |||||
| from modelscope.models.science.unifold.msa import (msa_identifiers, parsers, | |||||
| templates) | |||||
| from modelscope.models.science.unifold.msa.tools import (hhblits, hhsearch, | |||||
| hmmsearch, jackhmmer) | |||||
| FeatureDict = MutableMapping[str, np.ndarray] | |||||
| TemplateSearcher = Union[hhsearch.HHSearch, hmmsearch.Hmmsearch] | |||||
| def make_sequence_features(sequence: str, description: str, | |||||
| num_res: int) -> FeatureDict: | |||||
| """Constructs a feature dict of sequence features.""" | |||||
| features = {} | |||||
| features['aatype'] = residue_constants.sequence_to_onehot( | |||||
| sequence=sequence, | |||||
| mapping=residue_constants.restype_order_with_x, | |||||
| map_unknown_to_x=True, | |||||
| ) | |||||
| features['between_segment_residues'] = np.zeros((num_res, ), | |||||
| dtype=np.int32) | |||||
| features['domain_name'] = np.array([description.encode('utf-8')], | |||||
| dtype=np.object_) | |||||
| features['residue_index'] = np.array(range(num_res), dtype=np.int32) | |||||
| features['seq_length'] = np.array([num_res] * num_res, dtype=np.int32) | |||||
| features['sequence'] = np.array([sequence.encode('utf-8')], | |||||
| dtype=np.object_) | |||||
| return features | |||||
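| # Review note: editorial sketch (not part of the original file) of the shapes | |||||
| # produced above; the sequence and description are hypothetical. | |||||
| def _example_sequence_features() -> FeatureDict: | |||||
|     seq = 'MKVLA' | |||||
|     feats = make_sequence_features(sequence=seq, description='toy', num_res=len(seq)) | |||||
|     # 'aatype' is a (5, 21) one-hot array (20 amino acids plus X), | |||||
|     # 'residue_index' is [0, 1, 2, 3, 4] and 'seq_length' repeats 5 per residue. | |||||
|     return feats | |||||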
| def make_msa_features(msas: Sequence[parsers.Msa]) -> FeatureDict: | |||||
| """Constructs a feature dict of MSA features.""" | |||||
| if not msas: | |||||
| raise ValueError('At least one MSA must be provided.') | |||||
| int_msa = [] | |||||
| deletion_matrix = [] | |||||
| species_ids = [] | |||||
| seen_sequences = set() | |||||
| for msa_index, msa in enumerate(msas): | |||||
| if not msa: | |||||
| raise ValueError( | |||||
| f'MSA {msa_index} must contain at least one sequence.') | |||||
| for sequence_index, sequence in enumerate(msa.sequences): | |||||
| if sequence in seen_sequences: | |||||
| continue | |||||
| seen_sequences.add(sequence) | |||||
| int_msa.append( | |||||
| [residue_constants.HHBLITS_AA_TO_ID[res] for res in sequence]) | |||||
| deletion_matrix.append(msa.deletion_matrix[sequence_index]) | |||||
| identifiers = msa_identifiers.get_identifiers( | |||||
| msa.descriptions[sequence_index]) | |||||
| species_ids.append(identifiers.species_id.encode('utf-8')) | |||||
| num_res = len(msas[0].sequences[0]) | |||||
| num_alignments = len(int_msa) | |||||
| features = {} | |||||
| features['deletion_matrix_int'] = np.array(deletion_matrix, dtype=np.int32) | |||||
| features['msa'] = np.array(int_msa, dtype=np.int32) | |||||
| features['num_alignments'] = np.array( | |||||
| [num_alignments] * num_res, dtype=np.int32) | |||||
| features['msa_species_identifiers'] = np.array( | |||||
| species_ids, dtype=np.object_) | |||||
| return features | |||||
| def run_msa_tool( | |||||
| msa_runner, | |||||
| input_fasta_path: str, | |||||
| msa_out_path: str, | |||||
| msa_format: str, | |||||
| use_precomputed_msas: bool, | |||||
| ) -> Mapping[str, Any]: | |||||
| """Runs an MSA tool, checking if output already exists first.""" | |||||
| if not use_precomputed_msas or not os.path.exists(msa_out_path): | |||||
| result = msa_runner.query(input_fasta_path)[0] | |||||
| with open(msa_out_path, 'w') as f: | |||||
| f.write(result[msa_format]) | |||||
| else: | |||||
| logging.warning('Reading MSA from file %s', msa_out_path) | |||||
| with open(msa_out_path, 'r') as f: | |||||
| result = {msa_format: f.read()} | |||||
| return result | |||||
| class DataPipeline: | |||||
| """Runs the alignment tools and assembles the input features.""" | |||||
| def __init__( | |||||
| self, | |||||
| jackhmmer_binary_path: str, | |||||
| hhblits_binary_path: str, | |||||
| uniref90_database_path: str, | |||||
| mgnify_database_path: str, | |||||
| bfd_database_path: Optional[str], | |||||
| uniclust30_database_path: Optional[str], | |||||
| small_bfd_database_path: Optional[str], | |||||
| uniprot_database_path: Optional[str], | |||||
| template_searcher: TemplateSearcher, | |||||
| template_featurizer: templates.TemplateHitFeaturizer, | |||||
| use_small_bfd: bool, | |||||
| mgnify_max_hits: int = 501, | |||||
| uniref_max_hits: int = 10000, | |||||
| use_precomputed_msas: bool = False, | |||||
| ): | |||||
| """Initializes the data pipeline.""" | |||||
| self._use_small_bfd = use_small_bfd | |||||
| self.jackhmmer_uniref90_runner = jackhmmer.Jackhmmer( | |||||
| binary_path=jackhmmer_binary_path, | |||||
| database_path=uniref90_database_path) | |||||
| if use_small_bfd: | |||||
| self.jackhmmer_small_bfd_runner = jackhmmer.Jackhmmer( | |||||
| binary_path=jackhmmer_binary_path, | |||||
| database_path=small_bfd_database_path) | |||||
| else: | |||||
| self.hhblits_bfd_uniclust_runner = hhblits.HHBlits( | |||||
| binary_path=hhblits_binary_path, | |||||
| databases=[bfd_database_path, uniclust30_database_path], | |||||
| ) | |||||
| self.jackhmmer_mgnify_runner = jackhmmer.Jackhmmer( | |||||
| binary_path=jackhmmer_binary_path, | |||||
| database_path=mgnify_database_path) | |||||
| self.jackhmmer_uniprot_runner = jackhmmer.Jackhmmer( | |||||
| binary_path=jackhmmer_binary_path, | |||||
| database_path=uniprot_database_path) | |||||
| self.template_searcher = template_searcher | |||||
| self.template_featurizer = template_featurizer | |||||
| self.mgnify_max_hits = mgnify_max_hits | |||||
| self.uniref_max_hits = uniref_max_hits | |||||
| self.use_precomputed_msas = use_precomputed_msas | |||||
| def process(self, input_fasta_path: str, | |||||
| msa_output_dir: str) -> FeatureDict: | |||||
| """Runs alignment tools on the input sequence and creates features.""" | |||||
| with open(input_fasta_path) as f: | |||||
| input_fasta_str = f.read() | |||||
| input_seqs, input_descs = parsers.parse_fasta(input_fasta_str) | |||||
| if len(input_seqs) != 1: | |||||
| raise ValueError( | |||||
| f'Exactly one input sequence expected in {input_fasta_path}, ' | |||||
| f'got {len(input_seqs)}.') | |||||
| input_sequence = input_seqs[0] | |||||
| input_description = input_descs[0] | |||||
| num_res = len(input_sequence) | |||||
| uniref90_out_path = os.path.join(msa_output_dir, 'uniref90_hits.sto') | |||||
| jackhmmer_uniref90_result = run_msa_tool( | |||||
| self.jackhmmer_uniref90_runner, | |||||
| input_fasta_path, | |||||
| uniref90_out_path, | |||||
| 'sto', | |||||
| self.use_precomputed_msas, | |||||
| ) | |||||
| mgnify_out_path = os.path.join(msa_output_dir, 'mgnify_hits.sto') | |||||
| jackhmmer_mgnify_result = run_msa_tool( | |||||
| self.jackhmmer_mgnify_runner, | |||||
| input_fasta_path, | |||||
| mgnify_out_path, | |||||
| 'sto', | |||||
| self.use_precomputed_msas, | |||||
| ) | |||||
| msa_for_templates = jackhmmer_uniref90_result['sto'] | |||||
| msa_for_templates = parsers.truncate_stockholm_msa( | |||||
| msa_for_templates, max_sequences=self.uniref_max_hits) | |||||
| msa_for_templates = parsers.deduplicate_stockholm_msa( | |||||
| msa_for_templates) | |||||
| msa_for_templates = parsers.remove_empty_columns_from_stockholm_msa( | |||||
| msa_for_templates) | |||||
| if self.template_searcher.input_format == 'sto': | |||||
| pdb_templates_result = self.template_searcher.query( | |||||
| msa_for_templates) | |||||
| elif self.template_searcher.input_format == 'a3m': | |||||
| uniref90_msa_as_a3m = parsers.convert_stockholm_to_a3m( | |||||
| msa_for_templates) | |||||
| pdb_templates_result = self.template_searcher.query( | |||||
| uniref90_msa_as_a3m) | |||||
| else: | |||||
| raise ValueError('Unrecognized template input format: ' | |||||
| f'{self.template_searcher.input_format}') | |||||
| pdb_hits_out_path = os.path.join( | |||||
| msa_output_dir, f'pdb_hits.{self.template_searcher.output_format}') | |||||
| with open(pdb_hits_out_path, 'w') as f: | |||||
| f.write(pdb_templates_result) | |||||
| uniref90_msa = parsers.parse_stockholm( | |||||
| jackhmmer_uniref90_result['sto']) | |||||
| uniref90_msa = uniref90_msa.truncate(max_seqs=self.uniref_max_hits) | |||||
| mgnify_msa = parsers.parse_stockholm(jackhmmer_mgnify_result['sto']) | |||||
| mgnify_msa = mgnify_msa.truncate(max_seqs=self.mgnify_max_hits) | |||||
| pdb_template_hits = self.template_searcher.get_template_hits( | |||||
| output_string=pdb_templates_result, input_sequence=input_sequence) | |||||
| if self._use_small_bfd: | |||||
| bfd_out_path = os.path.join(msa_output_dir, 'small_bfd_hits.sto') | |||||
| jackhmmer_small_bfd_result = run_msa_tool( | |||||
| self.jackhmmer_small_bfd_runner, | |||||
| input_fasta_path, | |||||
| bfd_out_path, | |||||
| 'sto', | |||||
| self.use_precomputed_msas, | |||||
| ) | |||||
| bfd_msa = parsers.parse_stockholm( | |||||
| jackhmmer_small_bfd_result['sto']) | |||||
| else: | |||||
| bfd_out_path = os.path.join(msa_output_dir, | |||||
| 'bfd_uniclust_hits.a3m') | |||||
| hhblits_bfd_uniclust_result = run_msa_tool( | |||||
| self.hhblits_bfd_uniclust_runner, | |||||
| input_fasta_path, | |||||
| bfd_out_path, | |||||
| 'a3m', | |||||
| self.use_precomputed_msas, | |||||
| ) | |||||
| bfd_msa = parsers.parse_a3m(hhblits_bfd_uniclust_result['a3m']) | |||||
| templates_result = self.template_featurizer.get_templates( | |||||
| query_sequence=input_sequence, hits=pdb_template_hits) | |||||
| sequence_features = make_sequence_features( | |||||
| sequence=input_sequence, | |||||
| description=input_description, | |||||
| num_res=num_res) | |||||
| msa_features = make_msa_features((uniref90_msa, bfd_msa, mgnify_msa)) | |||||
| logging.info('Uniref90 MSA size: %d sequences.', len(uniref90_msa)) | |||||
| logging.info('BFD MSA size: %d sequences.', len(bfd_msa)) | |||||
| logging.info('MGnify MSA size: %d sequences.', len(mgnify_msa)) | |||||
| logging.info( | |||||
| 'Final (deduplicated) MSA size: %d sequences.', | |||||
| msa_features['num_alignments'][0], | |||||
| ) | |||||
| logging.info( | |||||
| 'Total number of templates (NB: this can include bad ' | |||||
| 'templates and is later filtered to top 4): %d.', | |||||
| templates_result.features['template_domain_names'].shape[0], | |||||
| ) | |||||
| return { | |||||
| **sequence_features, | |||||
| **msa_features, | |||||
| **templates_result.features | |||||
| } | |||||
| def process_uniprot(self, input_fasta_path: str, | |||||
| msa_output_dir: str) -> FeatureDict: | |||||
| uniprot_path = os.path.join(msa_output_dir, 'uniprot_hits.sto') | |||||
| uniprot_result = run_msa_tool( | |||||
| self.jackhmmer_uniprot_runner, | |||||
| input_fasta_path, | |||||
| uniprot_path, | |||||
| 'sto', | |||||
| self.use_precomputed_msas, | |||||
| ) | |||||
| msa = parsers.parse_stockholm(uniprot_result['sto']) | |||||
| msa = msa.truncate(max_seqs=50000) | |||||
| all_seq_dict = make_msa_features([msa]) | |||||
| return all_seq_dict | |||||
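| # Review note: an editorial construction sketch, not part of the original file. | |||||
| # Every path below is a hypothetical placeholder, and the template featurizer is | |||||
| # assumed to be built from the templates module elsewhere in this PR. | |||||
| def _example_build_pipeline( | |||||
|         template_featurizer: templates.TemplateHitFeaturizer) -> DataPipeline: | |||||
|     searcher = hmmsearch.Hmmsearch( | |||||
|         binary_path='/path/to/hmmsearch', | |||||
|         hmmbuild_binary_path='/path/to/hmmbuild', | |||||
|         database_path='/path/to/pdb_seqres.fasta') | |||||
|     return DataPipeline( | |||||
|         jackhmmer_binary_path='/path/to/jackhmmer', | |||||
|         hhblits_binary_path='/path/to/hhblits', | |||||
|         uniref90_database_path='/path/to/uniref90.fasta', | |||||
|         mgnify_database_path='/path/to/mgnify.fasta', | |||||
|         bfd_database_path=None, | |||||
|         uniclust30_database_path=None, | |||||
|         small_bfd_database_path='/path/to/small_bfd.fasta', | |||||
|         uniprot_database_path='/path/to/uniprot.fasta', | |||||
|         template_searcher=searcher, | |||||
|         template_featurizer=template_featurizer, | |||||
|         use_small_bfd=True) | |||||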
| @@ -0,0 +1,14 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Python wrappers for third party tools.""" | |||||
| @@ -0,0 +1,170 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Library to run HHblits from Python.""" | |||||
| import glob | |||||
| import os | |||||
| import subprocess | |||||
| from typing import Any, List, Mapping, Optional, Sequence | |||||
| from absl import logging | |||||
| from . import utils | |||||
| _HHBLITS_DEFAULT_P = 20 | |||||
| _HHBLITS_DEFAULT_Z = 500 | |||||
| class HHBlits: | |||||
| """Python wrapper of the HHblits binary.""" | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| binary_path: str, | |||||
| databases: Sequence[str], | |||||
| n_cpu: int = 4, | |||||
| n_iter: int = 3, | |||||
| e_value: float = 0.001, | |||||
| maxseq: int = 1_000_000, | |||||
| realign_max: int = 100_000, | |||||
| maxfilt: int = 100_000, | |||||
| min_prefilter_hits: int = 1000, | |||||
| all_seqs: bool = False, | |||||
| alt: Optional[int] = None, | |||||
| p: int = _HHBLITS_DEFAULT_P, | |||||
| z: int = _HHBLITS_DEFAULT_Z, | |||||
| ): | |||||
| """Initializes the Python HHblits wrapper. | |||||
| Args: | |||||
| binary_path: The path to the HHblits executable. | |||||
| databases: A sequence of HHblits database paths. This should be the | |||||
| common prefix for the database files (i.e. up to but not including | |||||
| _hhm.ffindex etc.) | |||||
| n_cpu: The number of CPUs to give HHblits. | |||||
| n_iter: The number of HHblits iterations. | |||||
| e_value: The E-value, see HHblits docs for more details. | |||||
| maxseq: The maximum number of rows in an input alignment. Note that this | |||||
| parameter is only supported in HHBlits version 3.1 and higher. | |||||
| realign_max: Max number of HMM-HMM hits to realign. HHblits default: 500. | |||||
| maxfilt: Max number of hits allowed to pass the 2nd prefilter. | |||||
| HHblits default: 20000. | |||||
| min_prefilter_hits: Min number of hits to pass prefilter. | |||||
| HHblits default: 100. | |||||
| all_seqs: Return all sequences in the MSA / Do not filter the result MSA. | |||||
| HHblits default: False. | |||||
| alt: Show up to this many alternative alignments. | |||||
| p: Minimum Prob for a hit to be included in the output hhr file. | |||||
| HHblits default: 20. | |||||
| z: Hard cap on number of hits reported in the hhr file. | |||||
| HHblits default: 500. NB: The relevant HHblits flag is -Z not -z. | |||||
| Raises: | |||||
| RuntimeError: If HHblits binary not found within the path. | |||||
| """ | |||||
| self.binary_path = binary_path | |||||
| self.databases = databases | |||||
| for database_path in self.databases: | |||||
| if not glob.glob(database_path + '_*'): | |||||
| logging.error('Could not find HHBlits database %s', | |||||
| database_path) | |||||
| raise ValueError( | |||||
| f'Could not find HHBlits database {database_path}') | |||||
| self.n_cpu = n_cpu | |||||
| self.n_iter = n_iter | |||||
| self.e_value = e_value | |||||
| self.maxseq = maxseq | |||||
| self.realign_max = realign_max | |||||
| self.maxfilt = maxfilt | |||||
| self.min_prefilter_hits = min_prefilter_hits | |||||
| self.all_seqs = all_seqs | |||||
| self.alt = alt | |||||
| self.p = p | |||||
| self.z = z | |||||
| def query(self, input_fasta_path: str) -> List[Mapping[str, Any]]: | |||||
| """Queries the database using HHblits.""" | |||||
| with utils.tmpdir_manager() as query_tmp_dir: | |||||
| a3m_path = os.path.join(query_tmp_dir, 'output.a3m') | |||||
| db_cmd = [] | |||||
| for db_path in self.databases: | |||||
| db_cmd.append('-d') | |||||
| db_cmd.append(db_path) | |||||
| cmd = [ | |||||
| self.binary_path, | |||||
| '-i', | |||||
| input_fasta_path, | |||||
| '-cpu', | |||||
| str(self.n_cpu), | |||||
| '-oa3m', | |||||
| a3m_path, | |||||
| '-o', | |||||
| '/dev/null', | |||||
| '-n', | |||||
| str(self.n_iter), | |||||
| '-e', | |||||
| str(self.e_value), | |||||
| '-maxseq', | |||||
| str(self.maxseq), | |||||
| '-realign_max', | |||||
| str(self.realign_max), | |||||
| '-maxfilt', | |||||
| str(self.maxfilt), | |||||
| '-min_prefilter_hits', | |||||
| str(self.min_prefilter_hits), | |||||
| ] | |||||
| if self.all_seqs: | |||||
| cmd += ['-all'] | |||||
| if self.alt: | |||||
| cmd += ['-alt', str(self.alt)] | |||||
| if self.p != _HHBLITS_DEFAULT_P: | |||||
| cmd += ['-p', str(self.p)] | |||||
| if self.z != _HHBLITS_DEFAULT_Z: | |||||
| cmd += ['-Z', str(self.z)] | |||||
| cmd += db_cmd | |||||
| logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||||
| process = subprocess.Popen( | |||||
| cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||||
| with utils.timing('HHblits query'): | |||||
| stdout, stderr = process.communicate() | |||||
| retcode = process.wait() | |||||
| if retcode: | |||||
| # Logs have a 15k character limit, so log HHblits error line by line. | |||||
| logging.error('HHblits failed. HHblits stderr begin:') | |||||
| for error_line in stderr.decode('utf-8').splitlines(): | |||||
| if error_line.strip(): | |||||
| logging.error(error_line.strip()) | |||||
| logging.error('HHblits stderr end') | |||||
| raise RuntimeError( | |||||
| 'HHblits failed\nstdout:\n%s\n\nstderr:\n%s\n' % | |||||
| (stdout.decode('utf-8'), stderr[:500_000].decode('utf-8'))) | |||||
| with open(a3m_path) as f: | |||||
| a3m = f.read() | |||||
| raw_output = dict( | |||||
| a3m=a3m, | |||||
| output=stdout, | |||||
| stderr=stderr, | |||||
| n_iter=self.n_iter, | |||||
| e_value=self.e_value, | |||||
| ) | |||||
| return [raw_output] | |||||
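| # Review note: editorial usage sketch, not part of the original file; the binary | |||||
| # and database paths are hypothetical placeholders. | |||||
| def _example_hhblits_query(input_fasta_path: str) -> str: | |||||
|     runner = HHBlits( | |||||
|         binary_path='/path/to/hhblits', | |||||
|         databases=['/path/to/uniclust30/uniclust30_2018_08']) | |||||
|     # query() returns a one-element list; the MSA itself is under the 'a3m' key. | |||||
|     return runner.query(input_fasta_path)[0]['a3m'] | |||||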
| @@ -0,0 +1,111 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Library to run HHsearch from Python.""" | |||||
| import glob | |||||
| import os | |||||
| import subprocess | |||||
| from typing import Sequence | |||||
| from absl import logging | |||||
| from modelscope.models.science.unifold.msa import parsers | |||||
| from . import utils | |||||
| class HHSearch: | |||||
| """Python wrapper of the HHsearch binary.""" | |||||
| def __init__(self, | |||||
| *, | |||||
| binary_path: str, | |||||
| databases: Sequence[str], | |||||
| maxseq: int = 1_000_000): | |||||
| """Initializes the Python HHsearch wrapper. | |||||
| Args: | |||||
| binary_path: The path to the HHsearch executable. | |||||
| databases: A sequence of HHsearch database paths. This should be the | |||||
| common prefix for the database files (i.e. up to but not including | |||||
| _hhm.ffindex etc.) | |||||
| maxseq: The maximum number of rows in an input alignment. Note that this | |||||
| parameter is only supported in HHBlits version 3.1 and higher. | |||||
| Raises: | |||||
| RuntimeError: If HHsearch binary not found within the path. | |||||
| """ | |||||
| self.binary_path = binary_path | |||||
| self.databases = databases | |||||
| self.maxseq = maxseq | |||||
| for database_path in self.databases: | |||||
| if not glob.glob(database_path + '_*'): | |||||
| logging.error('Could not find HHsearch database %s', | |||||
| database_path) | |||||
| raise ValueError( | |||||
| f'Could not find HHsearch database {database_path}') | |||||
| @property | |||||
| def output_format(self) -> str: | |||||
| return 'hhr' | |||||
| @property | |||||
| def input_format(self) -> str: | |||||
| return 'a3m' | |||||
| def query(self, a3m: str) -> str: | |||||
| """Queries the database using HHsearch using a given a3m.""" | |||||
| with utils.tmpdir_manager() as query_tmp_dir: | |||||
| input_path = os.path.join(query_tmp_dir, 'query.a3m') | |||||
| hhr_path = os.path.join(query_tmp_dir, 'output.hhr') | |||||
| with open(input_path, 'w') as f: | |||||
| f.write(a3m) | |||||
| db_cmd = [] | |||||
| for db_path in self.databases: | |||||
| db_cmd.append('-d') | |||||
| db_cmd.append(db_path) | |||||
| cmd = [ | |||||
| self.binary_path, | |||||
| '-i', | |||||
| input_path, | |||||
| '-o', | |||||
| hhr_path, | |||||
| '-maxseq', | |||||
| str(self.maxseq), | |||||
| ] + db_cmd | |||||
| logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||||
| process = subprocess.Popen( | |||||
| cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||||
| with utils.timing('HHsearch query'): | |||||
| stdout, stderr = process.communicate() | |||||
| retcode = process.wait() | |||||
| if retcode: | |||||
| # Stderr is truncated to prevent proto size errors in Beam. | |||||
| raise RuntimeError( | |||||
| 'HHSearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % | |||||
| (stdout.decode('utf-8'), stderr[:100_000].decode('utf-8'))) | |||||
| with open(hhr_path) as f: | |||||
| hhr = f.read() | |||||
| return hhr | |||||
| def get_template_hits( | |||||
| self, output_string: str, | |||||
| input_sequence: str) -> Sequence[parsers.TemplateHit]: | |||||
| """Gets parsed template hits from the raw string output by the tool.""" | |||||
| del input_sequence # Used by hmmsearch but not needed for hhsearch. | |||||
| return parsers.parse_hhr(output_string) | |||||
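| # Review note: editorial usage sketch, not part of the original file; the binary | |||||
| # and database paths are hypothetical placeholders. | |||||
| def _example_hhsearch_templates(uniref90_a3m: str, | |||||
|                                 query_sequence: str) -> Sequence[parsers.TemplateHit]: | |||||
|     searcher = HHSearch( | |||||
|         binary_path='/path/to/hhsearch', | |||||
|         databases=['/path/to/pdb70/pdb70']) | |||||
|     hhr = searcher.query(uniref90_a3m) | |||||
|     return searcher.get_template_hits( | |||||
|         output_string=hhr, input_sequence=query_sequence) | |||||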
| @@ -0,0 +1,143 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """A Python wrapper for hmmbuild - construct HMM profiles from MSA.""" | |||||
| import os | |||||
| import re | |||||
| import subprocess | |||||
| from absl import logging | |||||
| from . import utils | |||||
| class Hmmbuild(object): | |||||
| """Python wrapper of the hmmbuild binary.""" | |||||
| def __init__(self, *, binary_path: str, singlemx: bool = False): | |||||
| """Initializes the Python hmmbuild wrapper. | |||||
| Args: | |||||
| binary_path: The path to the hmmbuild executable. | |||||
| singlemx: Whether to use --singlemx flag. If True, it forces HMMBuild to | |||||
| just use a common substitution score matrix. | |||||
| Raises: | |||||
| RuntimeError: If hmmbuild binary not found within the path. | |||||
| """ | |||||
| self.binary_path = binary_path | |||||
| self.singlemx = singlemx | |||||
| def build_profile_from_sto(self, | |||||
| sto: str, | |||||
| model_construction='fast') -> str: | |||||
| """Builds a HHM for the aligned sequences given as an A3M string. | |||||
| Args: | |||||
| sto: A string with the aligned sequences in the Stockholm format. | |||||
| model_construction: Whether to use reference annotation in the msa to | |||||
| determine consensus columns ('hand') or default ('fast'). | |||||
| Returns: | |||||
| A string with the profile in the HMM format. | |||||
| Raises: | |||||
| RuntimeError: If hmmbuild fails. | |||||
| """ | |||||
| return self._build_profile(sto, model_construction=model_construction) | |||||
| def build_profile_from_a3m(self, a3m: str) -> str: | |||||
| """Builds a HHM for the aligned sequences given as an A3M string. | |||||
| Args: | |||||
| a3m: A string with the aligned sequences in the A3M format. | |||||
| Returns: | |||||
| A string with the profile in the HMM format. | |||||
| Raises: | |||||
| RuntimeError: If hmmbuild fails. | |||||
| """ | |||||
| lines = [] | |||||
| for line in a3m.splitlines(): | |||||
| if not line.startswith('>'): | |||||
| line = re.sub('[a-z]+', '', line) # Remove inserted residues. | |||||
| lines.append(line + '\n') | |||||
| msa = ''.join(lines) | |||||
| return self._build_profile(msa, model_construction='fast') | |||||
| def _build_profile(self, | |||||
| msa: str, | |||||
| model_construction: str = 'fast') -> str: | |||||
| """Builds a HMM for the aligned sequences given as an MSA string. | |||||
| Args: | |||||
| msa: A string with the aligned sequences, in A3M or STO format. | |||||
| model_construction: Whether to use reference annotation in the msa to | |||||
| determine consensus columns ('hand') or default ('fast'). | |||||
| Returns: | |||||
| A string with the profile in the HMM format. | |||||
| Raises: | |||||
| RuntimeError: If hmmbuild fails. | |||||
| ValueError: If unspecified arguments are provided. | |||||
| """ | |||||
| if model_construction not in {'hand', 'fast'}: | |||||
| raise ValueError( | |||||
| f'Invalid model_construction {model_construction} - only' | |||||
| ' hand and fast supported.') | |||||
| with utils.tmpdir_manager() as query_tmp_dir: | |||||
| input_query = os.path.join(query_tmp_dir, 'query.msa') | |||||
| output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm') | |||||
| with open(input_query, 'w') as f: | |||||
| f.write(msa) | |||||
| cmd = [self.binary_path] | |||||
| # If adding flags, we have to do so before the output and input: | |||||
| if model_construction == 'hand': | |||||
| cmd.append(f'--{model_construction}') | |||||
| if self.singlemx: | |||||
| cmd.append('--singlemx') | |||||
| cmd.extend([ | |||||
| '--amino', | |||||
| output_hmm_path, | |||||
| input_query, | |||||
| ]) | |||||
| logging.info('Launching subprocess %s', cmd) | |||||
| process = subprocess.Popen( | |||||
| cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||||
| with utils.timing('hmmbuild query'): | |||||
| stdout, stderr = process.communicate() | |||||
| retcode = process.wait() | |||||
| logging.info( | |||||
| 'hmmbuild stdout:\n%s\n\nstderr:\n%s\n', | |||||
| stdout.decode('utf-8'), | |||||
| stderr.decode('utf-8'), | |||||
| ) | |||||
| if retcode: | |||||
| raise RuntimeError( | |||||
| 'hmmbuild failed\nstdout:\n%s\n\nstderr:\n%s\n' % | |||||
| (stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||||
| with open(output_hmm_path, encoding='utf-8') as f: | |||||
| hmm = f.read() | |||||
| return hmm | |||||
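| # Review note: editorial usage sketch, not part of the original file; the binary | |||||
| # path is a hypothetical placeholder. | |||||
| def _example_hmmbuild_profile(sto_msa: str) -> str: | |||||
|     builder = Hmmbuild(binary_path='/path/to/hmmbuild') | |||||
|     # 'fast' lets hmmbuild pick consensus columns; pass 'hand' to honour the | |||||
|     # #=GC RF reference annotation instead. | |||||
|     return builder.build_profile_from_sto(sto_msa, model_construction='fast') | |||||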
| @@ -0,0 +1,146 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """A Python wrapper for hmmsearch - search profile against a sequence db.""" | |||||
| import os | |||||
| import subprocess | |||||
| from typing import Optional, Sequence | |||||
| from absl import logging | |||||
| from modelscope.models.science.unifold.msa import parsers | |||||
| from . import hmmbuild, utils | |||||
| class Hmmsearch(object): | |||||
| """Python wrapper of the hmmsearch binary.""" | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| binary_path: str, | |||||
| hmmbuild_binary_path: str, | |||||
| database_path: str, | |||||
| flags: Optional[Sequence[str]] = None, | |||||
| ): | |||||
| """Initializes the Python hmmsearch wrapper. | |||||
| Args: | |||||
| binary_path: The path to the hmmsearch executable. | |||||
| hmmbuild_binary_path: The path to the hmmbuild executable. Used to build | |||||
| an hmm from an input a3m. | |||||
| database_path: The path to the hmmsearch database (FASTA format). | |||||
| flags: List of flags to be used by hmmsearch. | |||||
| Raises: | |||||
| RuntimeError: If hmmsearch binary not found within the path. | |||||
| """ | |||||
| self.binary_path = binary_path | |||||
| self.hmmbuild_runner = hmmbuild.Hmmbuild( | |||||
| binary_path=hmmbuild_binary_path) | |||||
| self.database_path = database_path | |||||
| if flags is None: | |||||
| # Default hmmsearch run settings. | |||||
| flags = [ | |||||
| '--F1', | |||||
| '0.1', | |||||
| '--F2', | |||||
| '0.1', | |||||
| '--F3', | |||||
| '0.1', | |||||
| '--incE', | |||||
| '100', | |||||
| '-E', | |||||
| '100', | |||||
| '--domE', | |||||
| '100', | |||||
| '--incdomE', | |||||
| '100', | |||||
| ] | |||||
| self.flags = flags | |||||
| if not os.path.exists(self.database_path): | |||||
| logging.error('Could not find hmmsearch database %s', | |||||
| database_path) | |||||
| raise ValueError( | |||||
| f'Could not find hmmsearch database {database_path}') | |||||
| @property | |||||
| def output_format(self) -> str: | |||||
| return 'sto' | |||||
| @property | |||||
| def input_format(self) -> str: | |||||
| return 'sto' | |||||
| def query(self, msa_sto: str) -> str: | |||||
| """Queries the database using hmmsearch using a given stockholm msa.""" | |||||
| hmm = self.hmmbuild_runner.build_profile_from_sto( | |||||
| msa_sto, model_construction='hand') | |||||
| return self.query_with_hmm(hmm) | |||||
| def query_with_hmm(self, hmm: str) -> str: | |||||
| """Queries the database using hmmsearch using a given hmm.""" | |||||
| with utils.tmpdir_manager() as query_tmp_dir: | |||||
| hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') | |||||
| out_path = os.path.join(query_tmp_dir, 'output.sto') | |||||
| with open(hmm_input_path, 'w') as f: | |||||
| f.write(hmm) | |||||
| cmd = [ | |||||
| self.binary_path, | |||||
| '--noali', # Don't include the alignment in stdout. | |||||
| '--cpu', | |||||
| '8', | |||||
| ] | |||||
| # If adding flags, we have to do so before the output and input: | |||||
| if self.flags: | |||||
| cmd.extend(self.flags) | |||||
| cmd.extend([ | |||||
| '-A', | |||||
| out_path, | |||||
| hmm_input_path, | |||||
| self.database_path, | |||||
| ]) | |||||
| logging.info('Launching sub-process %s', cmd) | |||||
| process = subprocess.Popen( | |||||
| cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||||
| with utils.timing( | |||||
| f'hmmsearch ({os.path.basename(self.database_path)}) query' | |||||
| ): | |||||
| stdout, stderr = process.communicate() | |||||
| retcode = process.wait() | |||||
| if retcode: | |||||
| raise RuntimeError( | |||||
| 'hmmsearch failed:\nstdout:\n%s\n\nstderr:\n%s\n' % | |||||
| (stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||||
| with open(out_path) as f: | |||||
| out_msa = f.read() | |||||
| return out_msa | |||||
| def get_template_hits( | |||||
| self, output_string: str, | |||||
| input_sequence: str) -> Sequence[parsers.TemplateHit]: | |||||
| """Gets parsed template hits from the raw string output by the tool.""" | |||||
| a3m_string = parsers.convert_stockholm_to_a3m( | |||||
| output_string, remove_first_row_gaps=False) | |||||
| template_hits = parsers.parse_hmmsearch_a3m( | |||||
| query_sequence=input_sequence, | |||||
| a3m_string=a3m_string, | |||||
| skip_first=False) | |||||
| return template_hits | |||||
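| # Review note: editorial usage sketch, not part of the original file; the binary | |||||
| # and database paths are hypothetical placeholders. | |||||
| def _example_hmmsearch_templates(msa_sto: str, | |||||
|                                  query_sequence: str) -> Sequence[parsers.TemplateHit]: | |||||
|     searcher = Hmmsearch( | |||||
|         binary_path='/path/to/hmmsearch', | |||||
|         hmmbuild_binary_path='/path/to/hmmbuild', | |||||
|         database_path='/path/to/pdb_seqres.fasta') | |||||
|     # query() builds an HMM from the Stockholm MSA via hmmbuild, then returns | |||||
|     # the hits as a Stockholm string, which get_template_hits parses. | |||||
|     sto_hits = searcher.query(msa_sto) | |||||
|     return searcher.get_template_hits( | |||||
|         output_string=sto_hits, input_sequence=query_sequence) | |||||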
| @@ -0,0 +1,224 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Library to run Jackhmmer from Python.""" | |||||
| import glob | |||||
| import os | |||||
| import subprocess | |||||
| from concurrent import futures | |||||
| from typing import Any, Callable, Mapping, Optional, Sequence | |||||
| from urllib import request | |||||
| from absl import logging | |||||
| from . import utils | |||||
| class Jackhmmer: | |||||
| """Python wrapper of the Jackhmmer binary.""" | |||||
| def __init__( | |||||
| self, | |||||
| *, | |||||
| binary_path: str, | |||||
| database_path: str, | |||||
| n_cpu: int = 8, | |||||
| n_iter: int = 1, | |||||
| e_value: float = 0.0001, | |||||
| z_value: Optional[int] = None, | |||||
| get_tblout: bool = False, | |||||
| filter_f1: float = 0.0005, | |||||
| filter_f2: float = 0.00005, | |||||
| filter_f3: float = 0.0000005, | |||||
| incdom_e: Optional[float] = None, | |||||
| dom_e: Optional[float] = None, | |||||
| num_streamed_chunks: Optional[int] = None, | |||||
| streaming_callback: Optional[Callable[[int], None]] = None, | |||||
| ): | |||||
| """Initializes the Python Jackhmmer wrapper. | |||||
| Args: | |||||
| binary_path: The path to the jackhmmer executable. | |||||
| database_path: The path to the jackhmmer database (FASTA format). | |||||
| n_cpu: The number of CPUs to give Jackhmmer. | |||||
| n_iter: The number of Jackhmmer iterations. | |||||
| e_value: The E-value, see Jackhmmer docs for more details. | |||||
| z_value: The Z-value, see Jackhmmer docs for more details. | |||||
| get_tblout: Whether to save tblout string. | |||||
| filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. | |||||
| filter_f2: Viterbi pre-filter, set to >1.0 to turn off. | |||||
| filter_f3: Forward pre-filter, set to >1.0 to turn off. | |||||
| incdom_e: Domain e-value criteria for inclusion of domains in MSA/next | |||||
| round. | |||||
| dom_e: Domain e-value criteria for inclusion in tblout. | |||||
| num_streamed_chunks: Number of database chunks to stream over. | |||||
| streaming_callback: Callback function run after each chunk iteration with | |||||
| the iteration number as argument. | |||||
| """ | |||||
| self.binary_path = binary_path | |||||
| self.database_path = database_path | |||||
| self.num_streamed_chunks = num_streamed_chunks | |||||
| if not os.path.exists( | |||||
| self.database_path) and num_streamed_chunks is None: | |||||
| logging.error('Could not find Jackhmmer database %s', | |||||
| database_path) | |||||
| raise ValueError( | |||||
| f'Could not find Jackhmmer database {database_path}') | |||||
| self.n_cpu = n_cpu | |||||
| self.n_iter = n_iter | |||||
| self.e_value = e_value | |||||
| self.z_value = z_value | |||||
| self.filter_f1 = filter_f1 | |||||
| self.filter_f2 = filter_f2 | |||||
| self.filter_f3 = filter_f3 | |||||
| self.incdom_e = incdom_e | |||||
| self.dom_e = dom_e | |||||
| self.get_tblout = get_tblout | |||||
| self.streaming_callback = streaming_callback | |||||
| def _query_chunk(self, input_fasta_path: str, | |||||
| database_path: str) -> Mapping[str, Any]: | |||||
| """Queries the database chunk using Jackhmmer.""" | |||||
| with utils.tmpdir_manager() as query_tmp_dir: | |||||
| sto_path = os.path.join(query_tmp_dir, 'output.sto') | |||||
| # The F1/F2/F3 are the expected proportion to pass each of the filtering | |||||
| # stages (which get progressively more expensive), reducing these | |||||
| # speeds up the pipeline at the expense of sensitivity. They are | |||||
| # currently set very low to make querying Mgnify run in a reasonable | |||||
| # amount of time. | |||||
| cmd_flags = [ | |||||
| # Don't pollute stdout with Jackhmmer output. | |||||
| '-o', | |||||
| '/dev/null', | |||||
| '-A', | |||||
| sto_path, | |||||
| '--noali', | |||||
| '--F1', | |||||
| str(self.filter_f1), | |||||
| '--F2', | |||||
| str(self.filter_f2), | |||||
| '--F3', | |||||
| str(self.filter_f3), | |||||
| '--incE', | |||||
| str(self.e_value), | |||||
| # Report only sequences with E-values <= x in per-sequence output. | |||||
| '-E', | |||||
| str(self.e_value), | |||||
| '--cpu', | |||||
| str(self.n_cpu), | |||||
| '-N', | |||||
| str(self.n_iter), | |||||
| ] | |||||
| if self.get_tblout: | |||||
| tblout_path = os.path.join(query_tmp_dir, 'tblout.txt') | |||||
| cmd_flags.extend(['--tblout', tblout_path]) | |||||
| if self.z_value: | |||||
| cmd_flags.extend(['-Z', str(self.z_value)]) | |||||
| if self.dom_e is not None: | |||||
| cmd_flags.extend(['--domE', str(self.dom_e)]) | |||||
| if self.incdom_e is not None: | |||||
| cmd_flags.extend(['--incdomE', str(self.incdom_e)]) | |||||
| cmd = [self.binary_path | |||||
| ] + cmd_flags + [input_fasta_path, database_path] | |||||
| logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||||
| process = subprocess.Popen( | |||||
| cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||||
| with utils.timing( | |||||
| f'Jackhmmer ({os.path.basename(database_path)}) query'): | |||||
| _, stderr = process.communicate() | |||||
| retcode = process.wait() | |||||
| if retcode: | |||||
| raise RuntimeError('Jackhmmer failed\nstderr:\n%s\n' | |||||
| % stderr.decode('utf-8')) | |||||
| # Get e-values for each target name | |||||
| tbl = '' | |||||
| if self.get_tblout: | |||||
| with open(tblout_path) as f: | |||||
| tbl = f.read() | |||||
| with open(sto_path) as f: | |||||
| sto = f.read() | |||||
| raw_output = dict( | |||||
| sto=sto, | |||||
| tbl=tbl, | |||||
| stderr=stderr, | |||||
| n_iter=self.n_iter, | |||||
| e_value=self.e_value) | |||||
| return raw_output | |||||
| def query(self, input_fasta_path: str) -> Sequence[Mapping[str, Any]]: | |||||
| """Queries the database using Jackhmmer.""" | |||||
| if self.num_streamed_chunks is None: | |||||
| return [self._query_chunk(input_fasta_path, self.database_path)] | |||||
| db_basename = os.path.basename(self.database_path) | |||||
| def db_remote_chunk(db_idx): | |||||
| return f'{self.database_path}.{db_idx}' | |||||
| def db_local_chunk(db_idx): | |||||
| return f'/tmp/ramdisk/{db_basename}.{db_idx}' | |||||
| # db_remote_chunk = lambda db_idx: f'{self.database_path}.{db_idx}' | |||||
| # db_local_chunk = lambda db_idx: f'/tmp/ramdisk/{db_basename}.{db_idx}' | |||||
| # Remove existing files to prevent OOM | |||||
| for f in glob.glob(db_local_chunk('[0-9]*')): | |||||
| try: | |||||
| os.remove(f) | |||||
| except OSError: | |||||
| logging.warning('OSError while deleting %s', f) | |||||
| # Download the (i+1)-th chunk while Jackhmmer is running on the i-th chunk | |||||
| with futures.ThreadPoolExecutor(max_workers=2) as executor: | |||||
| chunked_output = [] | |||||
| for i in range(1, self.num_streamed_chunks + 1): | |||||
| # Copy the chunk locally | |||||
| if i == 1: | |||||
| future = executor.submit(request.urlretrieve, | |||||
| db_remote_chunk(i), | |||||
| db_local_chunk(i)) | |||||
| if i < self.num_streamed_chunks: | |||||
| next_future = executor.submit( | |||||
| request.urlretrieve, | |||||
| db_remote_chunk(i + 1), | |||||
| db_local_chunk(i + 1), | |||||
| ) | |||||
| # Run Jackhmmer with the chunk | |||||
| future.result() | |||||
| chunked_output.append( | |||||
| self._query_chunk(input_fasta_path, db_local_chunk(i))) | |||||
| # Remove the local copy of the chunk | |||||
| os.remove(db_local_chunk(i)) | |||||
| # Do not set next_future for the last chunk so that this works even for | |||||
| # databases with only 1 chunk. | |||||
| if i < self.num_streamed_chunks: | |||||
| future = next_future | |||||
| if self.streaming_callback: | |||||
| self.streaming_callback(i) | |||||
| return chunked_output | |||||
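| # Review note: editorial usage sketch, not part of the original file; the binary | |||||
| # and database paths are hypothetical placeholders. | |||||
| def _example_jackhmmer_query(input_fasta_path: str) -> str: | |||||
|     runner = Jackhmmer( | |||||
|         binary_path='/path/to/jackhmmer', | |||||
|         database_path='/path/to/uniref90.fasta') | |||||
|     # Without streamed chunks, query() returns a single-element list whose | |||||
|     # 'sto' entry holds the Stockholm alignment. | |||||
|     return runner.query(input_fasta_path)[0]['sto'] | |||||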
| @@ -0,0 +1,110 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """A Python wrapper for Kalign.""" | |||||
| import os | |||||
| import subprocess | |||||
| from typing import Sequence | |||||
| from absl import logging | |||||
| from . import utils | |||||
| def _to_a3m(sequences: Sequence[str]) -> str: | |||||
| """Converts sequences to an a3m file.""" | |||||
| names = ['sequence %d' % i for i in range(1, len(sequences) + 1)] | |||||
| a3m = [] | |||||
| for sequence, name in zip(sequences, names): | |||||
| a3m.append('>' + name + '\n') | |||||
| a3m.append(sequence + '\n') | |||||
| return ''.join(a3m) | |||||
| class Kalign: | |||||
| """Python wrapper of the Kalign binary.""" | |||||
| def __init__(self, *, binary_path: str): | |||||
| """Initializes the Python Kalign wrapper. | |||||
| Args: | |||||
| binary_path: The path to the Kalign binary. | |||||
| Raises: | |||||
| RuntimeError: If Kalign binary not found within the path. | |||||
| """ | |||||
| self.binary_path = binary_path | |||||
| def align(self, sequences: Sequence[str]) -> str: | |||||
| """Aligns the sequences and returns the alignment in A3M string. | |||||
| Args: | |||||
| sequences: A list of query sequence strings. The sequences have to be at | |||||
| least 6 residues long (Kalign requires this). Note that the order in | |||||
| which you give the sequences might alter the output slightly as | |||||
| a different alignment tree might get constructed. | |||||
| Returns: | |||||
| A string with the alignment in a3m format. | |||||
| Raises: | |||||
| RuntimeError: If Kalign fails. | |||||
| ValueError: If any of the sequences is less than 6 residues long. | |||||
| """ | |||||
| logging.info('Aligning %d sequences', len(sequences)) | |||||
| for s in sequences: | |||||
| if len(s) < 6: | |||||
| raise ValueError( | |||||
| 'Kalign requires all sequences to be at least 6 ' | |||||
| 'residues long. Got %s (%d residues).' % (s, len(s))) | |||||
| with utils.tmpdir_manager() as query_tmp_dir: | |||||
| input_fasta_path = os.path.join(query_tmp_dir, 'input.fasta') | |||||
| output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m') | |||||
| with open(input_fasta_path, 'w') as f: | |||||
| f.write(_to_a3m(sequences)) | |||||
| cmd = [ | |||||
| self.binary_path, | |||||
| '-i', | |||||
| input_fasta_path, | |||||
| '-o', | |||||
| output_a3m_path, | |||||
| '-format', | |||||
| 'fasta', | |||||
| ] | |||||
| logging.info('Launching subprocess "%s"', ' '.join(cmd)) | |||||
| process = subprocess.Popen( | |||||
| cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) | |||||
| with utils.timing('Kalign query'): | |||||
| stdout, stderr = process.communicate() | |||||
| retcode = process.wait() | |||||
| logging.info( | |||||
| 'Kalign stdout:\n%s\n\nstderr:\n%s\n', | |||||
| stdout.decode('utf-8'), | |||||
| stderr.decode('utf-8'), | |||||
| ) | |||||
| if retcode: | |||||
| raise RuntimeError( | |||||
| 'Kalign failed\nstdout:\n%s\n\nstderr:\n%s\n' % | |||||
| (stdout.decode('utf-8'), stderr.decode('utf-8'))) | |||||
| with open(output_a3m_path) as f: | |||||
| a3m = f.read() | |||||
| return a3m | |||||
| @@ -0,0 +1,40 @@ | |||||
| # Copyright 2021 DeepMind Technologies Limited | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| """Common utilities for data pipeline tools.""" | |||||
| import contextlib | |||||
| import shutil | |||||
| import tempfile | |||||
| import time | |||||
| from typing import Optional | |||||
| from absl import logging | |||||
| @contextlib.contextmanager | |||||
| def tmpdir_manager(base_dir: Optional[str] = None): | |||||
| """Context manager that deletes a temporary directory on exit.""" | |||||
| tmpdir = tempfile.mkdtemp(dir=base_dir) | |||||
| try: | |||||
| yield tmpdir | |||||
| finally: | |||||
| shutil.rmtree(tmpdir, ignore_errors=True) | |||||
| @contextlib.contextmanager | |||||
| def timing(msg: str): | |||||
| logging.info('Started %s', msg) | |||||
| tic = time.time() | |||||
| yield | |||||
| toc = time.time() | |||||
| logging.info('Finished %s in %.3f seconds', msg, toc - tic) | |||||
| @@ -0,0 +1,89 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| import os | |||||
| from typing import Mapping, Sequence | |||||
| import json | |||||
| from absl import logging | |||||
| from modelscope.models.science.unifold.data import protein | |||||
| def get_chain_id_map( | |||||
| sequences: Sequence[str], | |||||
| descriptions: Sequence[str], | |||||
| ): | |||||
| """ | |||||
| Makes a mapping from PDB-format chain ID to sequence and description, | |||||
| and records the chain order of the multimer. | |||||
| """ | |||||
| unique_seqs = [] | |||||
| for seq in sequences: | |||||
| if seq not in unique_seqs: | |||||
| unique_seqs.append(seq) | |||||
| chain_id_map = { | |||||
| chain_id: { | |||||
| 'descriptions': [], | |||||
| 'sequence': seq | |||||
| } | |||||
| for chain_id, seq in zip(protein.PDB_CHAIN_IDS, unique_seqs) | |||||
| } | |||||
| chain_order = [] | |||||
| for seq, des in zip(sequences, descriptions): | |||||
| chain_id = protein.PDB_CHAIN_IDS[unique_seqs.index(seq)] | |||||
| chain_id_map[chain_id]['descriptions'].append(des) | |||||
| chain_order.append(chain_id) | |||||
| return chain_id_map, chain_order | |||||
| def divide_multi_chains( | |||||
| fasta_name: str, | |||||
| output_dir_base: str, | |||||
| sequences: Sequence[str], | |||||
| descriptions: Sequence[str], | |||||
| ): | |||||
| """ | |||||
| Divides a multi-chain FASTA into per-chain FASTA files and | |||||
| records the chain-to-sequence mapping information. | |||||
| """ | |||||
| if len(sequences) != len(descriptions): | |||||
| raise ValueError('sequences and descriptions must have equal length. ' | |||||
| f'Got {len(sequences)} != {len(descriptions)}.') | |||||
| if len(sequences) > protein.PDB_MAX_CHAINS: | |||||
| raise ValueError( | |||||
| 'Cannot process more chains than the PDB format supports. ' | |||||
| f'Got {len(sequences)} chains.') | |||||
| chain_id_map, chain_order = get_chain_id_map(sequences, descriptions) | |||||
| output_dir = os.path.join(output_dir_base, fasta_name) | |||||
| if not os.path.exists(output_dir): | |||||
| os.makedirs(output_dir) | |||||
| chain_id_map_path = os.path.join(output_dir, 'chain_id_map.json') | |||||
| with open(chain_id_map_path, 'w') as f: | |||||
| json.dump(chain_id_map, f, indent=4, sort_keys=True) | |||||
| chain_order_path = os.path.join(output_dir, 'chains.txt') | |||||
| with open(chain_order_path, 'w') as f: | |||||
| f.write(' '.join(chain_order)) | |||||
| logging.info('Mapping multi-chains fasta with chain order: %s', | |||||
| ' '.join(chain_order)) | |||||
| temp_names = [] | |||||
| temp_paths = [] | |||||
| for chain_id in chain_id_map.keys(): | |||||
| temp_name = fasta_name + '_{}'.format(chain_id) | |||||
| temp_path = os.path.join(output_dir, temp_name + '.fasta') | |||||
| des = 'chain_{}'.format(chain_id) | |||||
| seq = chain_id_map[chain_id]['sequence'] | |||||
| with open(temp_path, 'w') as f: | |||||
| f.write('>' + des + '\n' + seq) | |||||
| temp_names.append(temp_name) | |||||
| temp_paths.append(temp_path) | |||||
| return temp_names, temp_paths | |||||
| @@ -0,0 +1,22 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .protein_structure_pipeline import ProteinStructurePipeline | |||||
| else: | |||||
| _import_structure = { | |||||
| 'protein_structure_pipeline': ['ProteinStructurePipeline'] | |||||
| } | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,215 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import os | |||||
| import time | |||||
| from typing import Any, Dict, List, Optional, Union | |||||
| import json | |||||
| import numpy as np | |||||
| import torch | |||||
| from unicore.utils import tensor_tree_map | |||||
| from modelscope.metainfo import Pipelines | |||||
| from modelscope.models.base import Model | |||||
| from modelscope.models.science.unifold.config import model_config | |||||
| from modelscope.models.science.unifold.data import protein, residue_constants | |||||
| from modelscope.models.science.unifold.dataset import (UnifoldDataset, | |||||
| load_and_process) | |||||
| from modelscope.outputs import OutputKeys | |||||
| from modelscope.pipelines.base import Pipeline, Tensor | |||||
| from modelscope.pipelines.builder import PIPELINES | |||||
| from modelscope.preprocessors import Preprocessor, build_preprocessor | |||||
| from modelscope.utils.constant import Fields, Frameworks, Tasks | |||||
| from modelscope.utils.device import device_placement | |||||
| from modelscope.utils.hub import read_config | |||||
| from modelscope.utils.logger import get_logger | |||||
| logger = get_logger() | |||||
| __all__ = ['ProteinStructurePipeline'] | |||||
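| # Heuristic: longer sequences use smaller chunk sizes so that chunked inference fits in GPU memory. | |||||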
| def automatic_chunk_size(seq_len): | |||||
| if seq_len < 512: | |||||
| chunk_size = 256 | |||||
| elif seq_len < 1024: | |||||
| chunk_size = 128 | |||||
| elif seq_len < 2048: | |||||
| chunk_size = 32 | |||||
| elif seq_len < 3072: | |||||
| chunk_size = 16 | |||||
| else: | |||||
| chunk_size = 1 | |||||
| return chunk_size | |||||
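| # Loads preprocessed features for one target from `data_folder`; multimer targets also read the chain order from chains.txt and per-chain UniProt MSAs for pairing. | |||||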
| def load_feature_for_one_target( | |||||
| config, | |||||
| data_folder, | |||||
| seed=0, | |||||
| is_multimer=False, | |||||
| use_uniprot=False, | |||||
| symmetry_group=None, | |||||
| ): | |||||
| if not is_multimer: | |||||
| uniprot_msa_dir = None | |||||
| sequence_ids = ['A'] | |||||
| if use_uniprot: | |||||
| uniprot_msa_dir = data_folder | |||||
| else: | |||||
| uniprot_msa_dir = data_folder | |||||
| sequence_ids = open(os.path.join(data_folder, | |||||
| 'chains.txt')).readline().split() | |||||
| if symmetry_group is None: | |||||
| batch, _ = load_and_process( | |||||
| config=config.data, | |||||
| mode='predict', | |||||
| seed=seed, | |||||
| batch_idx=None, | |||||
| data_idx=0, | |||||
| is_distillation=False, | |||||
| sequence_ids=sequence_ids, | |||||
| monomer_feature_dir=data_folder, | |||||
| uniprot_msa_dir=uniprot_msa_dir, | |||||
| ) | |||||
| else: | |||||
| raise NotImplementedError | |||||
| batch = UnifoldDataset.collater([batch]) | |||||
| return batch | |||||
| @PIPELINES.register_module( | |||||
| Tasks.protein_structure, module_name=Pipelines.protein_structure) | |||||
| class ProteinStructurePipeline(Pipeline): | |||||
| def __init__(self, | |||||
| model: Union[Model, str], | |||||
| preprocessor: Optional[Preprocessor] = None, | |||||
| **kwargs): | |||||
| """Use `model` and `preprocessor` to create a protein structure pipeline for prediction. | |||||
| Args: | |||||
| model (str or Model): Supply either a local model dir which supported the protein structure task, | |||||
| or a model id from the model hub, or a torch model instance. | |||||
| preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for | |||||
| the model if supplied. | |||||
| Example: | |||||
| >>> from modelscope.pipelines import pipeline | |||||
| >>> pipeline_ins = pipeline(task='protein-structure', | |||||
| >>> model='DPTech/uni-fold-monomer') | |||||
| >>> protein = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' | |||||
| >>> print(pipeline_ins(protein)) | |||||
| """ | |||||
| model_path = model if isinstance(model, str) else None | |||||
| cfg = read_config(model_path)  # read_config expects `model` to be a str (local model dir or model id) | |||||
| self.cfg = cfg | |||||
| self.config = model_config( | |||||
| cfg['pipeline']['model_name']) # alphafold config | |||||
| model = model if isinstance( | |||||
| model, Model) else Model.from_pretrained(model_path) | |||||
| self.postprocessor = cfg.pop('postprocessor', None) | |||||
| if preprocessor is None: | |||||
| preprocessor_cfg = cfg.preprocessor | |||||
| preprocessor = build_preprocessor(preprocessor_cfg, Fields.science) | |||||
| model.eval() | |||||
| model.model.inference_mode() | |||||
| model.model_dir = model_path | |||||
| super().__init__(model=model, preprocessor=preprocessor, **kwargs) | |||||
| def _sanitize_parameters(self, **pipeline_parameters): | |||||
| return pipeline_parameters, pipeline_parameters, pipeline_parameters | |||||
| def _process_single(self, input, *args, **kwargs) -> Dict[str, Any]: | |||||
| preprocess_params = kwargs.get('preprocess_params', {}) | |||||
| forward_params = kwargs.get('forward_params', {}) | |||||
| postprocess_params = kwargs.get('postprocess_params', {}) | |||||
| out = self.preprocess(input, **preprocess_params) | |||||
| with device_placement(self.framework, self.device_name): | |||||
| with torch.no_grad(): | |||||
| out = self.forward(out, **forward_params) | |||||
| out = self.postprocess(out, **postprocess_params) | |||||
| return out | |||||
| def forward(self, inputs: Dict[str, Any], | |||||
| **forward_params) -> Dict[str, Any]: | |||||
| plddts = {} | |||||
| ptms = {} | |||||
| output_dir = os.path.join(self.preprocessor.output_dir_base, | |||||
| inputs['target_id']) | |||||
| pdbs = [] | |||||
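| # Run inference `times` times with different data seeds and keep each predicted structure. | |||||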
| for seed in range(self.cfg['pipeline']['times']): | |||||
| cur_seed = hash((42, seed)) % 100000 | |||||
| batch = load_feature_for_one_target( | |||||
| self.config, | |||||
| output_dir, | |||||
| cur_seed, | |||||
| is_multimer=inputs['is_multimer'], | |||||
| use_uniprot=inputs['is_multimer'], | |||||
| symmetry_group=self.preprocessor.symmetry_group, | |||||
| ) | |||||
| seq_len = batch['aatype'].shape[-1] | |||||
| self.model.model.globals.chunk_size = automatic_chunk_size(seq_len) | |||||
| with torch.no_grad(): | |||||
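| # NOTE: inputs are placed on the first CUDA device; the pipeline currently assumes a GPU is available. | |||||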
| batch = { | |||||
| k: torch.as_tensor(v, device='cuda:0') | |||||
| for k, v in batch.items() | |||||
| } | |||||
| out = self.model(batch) | |||||
| def to_float(x): | |||||
| if x.dtype == torch.bfloat16 or x.dtype == torch.half: | |||||
| return x.float() | |||||
| else: | |||||
| return x | |||||
| # Toss out the recycling dimensions --- we don't need them anymore | |||||
| batch = tensor_tree_map(lambda t: t[-1, 0, ...], batch) | |||||
| batch = tensor_tree_map(to_float, batch) | |||||
| out = tensor_tree_map(lambda t: t[0, ...], out[0]) | |||||
| out = tensor_tree_map(to_float, out) | |||||
| batch = tensor_tree_map(lambda x: np.array(x.cpu()), batch) | |||||
| out = tensor_tree_map(lambda x: np.array(x.cpu()), out) | |||||
| plddt = out['plddt'] | |||||
| mean_plddt = np.mean(plddt) | |||||
| plddt_b_factors = np.repeat( | |||||
| plddt[..., None], residue_constants.atom_type_num, axis=-1) | |||||
| # TODO: , may need to reorder chains, based on entity_ids | |||||
| cur_protein = protein.from_prediction( | |||||
| features=batch, result=out, b_factors=plddt_b_factors) | |||||
| cur_save_name = str(cur_seed) | |||||
| plddts[cur_save_name] = str(mean_plddt) | |||||
| if inputs[ | |||||
| 'is_multimer'] and self.preprocessor.symmetry_group is None: | |||||
| ptms[cur_save_name] = str(np.mean(out['iptm+ptm'])) | |||||
| with open(os.path.join(output_dir, cur_save_name + '.pdb'), | |||||
| 'w') as f: | |||||
| f.write(protein.to_pdb(cur_protein)) | |||||
| pdbs.append(protein.to_pdb(cur_protein)) | |||||
| logger.info('plddts:' + str(plddts)) | |||||
| model_name = self.cfg['pipeline']['model_name'] | |||||
| score_name = model_name | |||||
| plddt_fname = score_name + '_plddt.json' | |||||
| with open(os.path.join(output_dir, plddt_fname), 'w') as f: | |||||
| json.dump(plddts, f, indent=4) | |||||
| if ptms: | |||||
| logger.info('ptms:' + str(ptms)) | |||||
| ptm_fname = score_name + '_ptm.json' | |||||
| with open(os.path.join(output_dir, ptm_fname), 'w') as f: | |||||
| json.dump(ptms, f, indent=4) | |||||
| return pdbs | |||||
| def postprocess(self, inputs: Dict[str, Tensor], **postprocess_params): | |||||
| return inputs | |||||
| @@ -0,0 +1,20 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| from typing import TYPE_CHECKING | |||||
| from modelscope.utils.import_utils import LazyImportModule | |||||
| if TYPE_CHECKING: | |||||
| from .unifold import UniFoldPreprocessor | |||||
| else: | |||||
| _import_structure = {'unifold': ['UniFoldPreprocessor']} | |||||
| import sys | |||||
| sys.modules[__name__] = LazyImportModule( | |||||
| __name__, | |||||
| globals()['__file__'], | |||||
| _import_structure, | |||||
| module_spec=__spec__, | |||||
| extra_objects={}, | |||||
| ) | |||||
| @@ -0,0 +1,569 @@ | |||||
| # The Uni-fold implementation is also open-sourced by the authors under Apache-2.0 license, | |||||
| # and is publicly available at https://github.com/dptech-corp/Uni-Fold. | |||||
| import gzip | |||||
| import hashlib | |||||
| import logging | |||||
| import os | |||||
| import pickle | |||||
| import random | |||||
| import re | |||||
| import tarfile | |||||
| import time | |||||
| from pathlib import Path | |||||
| from typing import Any, Dict, List, Optional, Sequence, Tuple, Union | |||||
| import json | |||||
| import numpy as np | |||||
| import requests | |||||
| import torch | |||||
| from tqdm import tqdm | |||||
| from modelscope.metainfo import Preprocessors | |||||
| from modelscope.models.science.unifold.data import protein, residue_constants | |||||
| from modelscope.models.science.unifold.data.protein import PDB_CHAIN_IDS | |||||
| from modelscope.models.science.unifold.data.utils import compress_features | |||||
| from modelscope.models.science.unifold.msa import parsers, pipeline, templates | |||||
| from modelscope.models.science.unifold.msa.tools import hhsearch | |||||
| from modelscope.models.science.unifold.msa.utils import divide_multi_chains | |||||
| from modelscope.preprocessors.base import Preprocessor | |||||
| from modelscope.preprocessors.builder import PREPROCESSORS | |||||
| from modelscope.utils.constant import Fields | |||||
| __all__ = [ | |||||
| 'UniFoldPreprocessor', | |||||
| ] | |||||
| TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]' | |||||
| DEFAULT_API_SERVER = 'https://api.colabfold.com' | |||||
| def run_mmseqs2( | |||||
| x, | |||||
| prefix, | |||||
| use_env=True, | |||||
| use_templates=False, | |||||
| use_pairing=False, | |||||
| host_url='https://api.colabfold.com') -> Tuple[List[str], List[str]]: | |||||
| submission_endpoint = 'ticket/pair' if use_pairing else 'ticket/msa' | |||||
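| # Helpers for the ColabFold MMseqs2 web API: submit a job, poll its status, and download the result archive. | |||||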
| def submit(seqs, mode, N=101): | |||||
| n, query = N, '' | |||||
| for seq in seqs: | |||||
| query += f'>{n}\n{seq}\n' | |||||
| n += 1 | |||||
| res = requests.post( | |||||
| f'{host_url}/{submission_endpoint}', | |||||
| data={ | |||||
| 'q': query, | |||||
| 'mode': mode | |||||
| }) | |||||
| try: | |||||
| out = res.json() | |||||
| except ValueError: | |||||
| out = {'status': 'ERROR'} | |||||
| return out | |||||
| def status(ID): | |||||
| res = requests.get(f'{host_url}/ticket/{ID}') | |||||
| try: | |||||
| out = res.json() | |||||
| except ValueError: | |||||
| out = {'status': 'ERROR'} | |||||
| return out | |||||
| def download(ID, path): | |||||
| res = requests.get(f'{host_url}/result/download/{ID}') | |||||
| with open(path, 'wb') as out: | |||||
| out.write(res.content) | |||||
| # process input x | |||||
| seqs = [x] if isinstance(x, str) else x | |||||
| mode = 'env' | |||||
| if use_pairing: | |||||
| mode = '' | |||||
| use_templates = False | |||||
| use_env = False | |||||
| # define path | |||||
| path = f'{prefix}' | |||||
| if not os.path.isdir(path): | |||||
| os.mkdir(path) | |||||
| # call mmseqs2 api | |||||
| tar_gz_file = f'{path}/out_{mode}.tar.gz' | |||||
| N, REDO = 101, True | |||||
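| # Query sequences are numbered from 101 so ids in the returned a3m files can be mapped back to the inputs. | |||||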
| # deduplicate and keep track of order | |||||
| seqs_unique = [] | |||||
| # TODO this might be slow for large sets | |||||
| for s in seqs: | |||||
| if s not in seqs_unique: | |||||
| seqs_unique.append(s) | |||||
| Ms = [N + seqs_unique.index(seq) for seq in seqs] | |||||
| # lets do it! | |||||
| if not os.path.isfile(tar_gz_file): | |||||
| TIME_ESTIMATE = 150 * len(seqs_unique) | |||||
| with tqdm(total=TIME_ESTIMATE, bar_format=TQDM_BAR_FORMAT) as pbar: | |||||
| while REDO: | |||||
| pbar.set_description('SUBMIT') | |||||
| # Resubmit job until it goes through | |||||
| out = submit(seqs_unique, mode, N) | |||||
| while out['status'] in ['UNKNOWN', 'RATELIMIT']: | |||||
| sleep_time = 5 + random.randint(0, 5) | |||||
| # logger.error(f"Sleeping for {sleep_time}s. Reason: {out['status']}") | |||||
| # resubmit | |||||
| time.sleep(sleep_time) | |||||
| out = submit(seqs_unique, mode, N) | |||||
| if out['status'] == 'ERROR': | |||||
| error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' | |||||
| error = error + ' If the error persists, please try again an hour later.' | |||||
| raise Exception(error) | |||||
| if out['status'] == 'MAINTENANCE': | |||||
| raise Exception( | |||||
| 'MMseqs2 API is undergoing maintenance. Please try again in a few minutes.' | |||||
| ) | |||||
| # wait for job to finish | |||||
| ID, TIME = out['id'], 0 | |||||
| pbar.set_description(out['status']) | |||||
| while out['status'] in ['UNKNOWN', 'RUNNING', 'PENDING']: | |||||
| t = 5 + random.randint(0, 5) | |||||
| # logger.error(f"Sleeping for {t}s. Reason: {out['status']}") | |||||
| time.sleep(t) | |||||
| out = status(ID) | |||||
| pbar.set_description(out['status']) | |||||
| if out['status'] == 'RUNNING': | |||||
| TIME += t | |||||
| pbar.update(n=t) | |||||
| if out['status'] == 'COMPLETE': | |||||
| if TIME < TIME_ESTIMATE: | |||||
| pbar.update(n=(TIME_ESTIMATE - TIME)) | |||||
| REDO = False | |||||
| if out['status'] == 'ERROR': | |||||
| REDO = False | |||||
| error = 'MMseqs2 API is giving errors. Please confirm your input is a valid protein sequence.' | |||||
| error = error + ' If the error persists, please try again an hour later.' | |||||
| raise Exception(error) | |||||
| # Download results | |||||
| download(ID, tar_gz_file) | |||||
| # prep list of a3m files | |||||
| if use_pairing: | |||||
| a3m_files = [f'{path}/pair.a3m'] | |||||
| else: | |||||
| a3m_files = [f'{path}/uniref.a3m'] | |||||
| if use_env: | |||||
| a3m_files.append(f'{path}/bfd.mgnify30.metaeuk30.smag30.a3m') | |||||
| # extract a3m files | |||||
| if any(not os.path.isfile(a3m_file) for a3m_file in a3m_files): | |||||
| with tarfile.open(tar_gz_file) as tar_gz: | |||||
| tar_gz.extractall(path) | |||||
| # templates | |||||
| if use_templates: | |||||
| templates = {} | |||||
| with open(f'{path}/pdb70.m8', 'r') as f: | |||||
| lines = f.readlines() | |||||
| for line in lines: | |||||
| p = line.rstrip().split() | |||||
| M, pdb, _, _ = p[0], p[1], p[2], p[10] # qid, e_value | |||||
| M = int(M) | |||||
| if M not in templates: | |||||
| templates[M] = [] | |||||
| templates[M].append(pdb) | |||||
| template_paths = {} | |||||
| for k, TMPL in templates.items(): | |||||
| TMPL_PATH = f'{prefix}/templates_{k}' | |||||
| if not os.path.isdir(TMPL_PATH): | |||||
| os.mkdir(TMPL_PATH) | |||||
| TMPL_LINE = ','.join(TMPL[:20]) | |||||
| os.system( | |||||
| f'curl -s -L {host_url}/template/{TMPL_LINE} | tar xzf - -C {TMPL_PATH}/' | |||||
| ) | |||||
| os.system( | |||||
| f'cp {TMPL_PATH}/pdb70_a3m.ffindex {TMPL_PATH}/pdb70_cs219.ffindex' | |||||
| ) | |||||
| os.system(f'touch {TMPL_PATH}/pdb70_cs219.ffdata') | |||||
| template_paths[k] = TMPL_PATH | |||||
| # gather a3m lines | |||||
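| # The downloaded a3m files concatenate per-query alignments separated by NUL bytes; split them back out keyed by query id. | |||||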
| a3m_lines = {} | |||||
| for a3m_file in a3m_files: | |||||
| update_M, M = True, None | |||||
| with open(a3m_file, 'r') as f: | |||||
| lines = f.readlines() | |||||
| for line in lines: | |||||
| if len(line) > 0: | |||||
| if '\x00' in line: | |||||
| line = line.replace('\x00', '') | |||||
| update_M = True | |||||
| if line.startswith('>') and update_M: | |||||
| M = int(line[1:].rstrip()) | |||||
| update_M = False | |||||
| if M not in a3m_lines: | |||||
| a3m_lines[M] = [] | |||||
| a3m_lines[M].append(line) | |||||
| # return results | |||||
| a3m_lines = [''.join(a3m_lines[n]) for n in Ms] | |||||
| if use_templates: | |||||
| template_paths_ = [] | |||||
| for n in Ms: | |||||
| if n not in template_paths: | |||||
| template_paths_.append(None) | |||||
| # print(f"{n-N}\tno_templates_found") | |||||
| else: | |||||
| template_paths_.append(template_paths[n]) | |||||
| template_paths = template_paths_ | |||||
| return (a3m_lines, template_paths) if use_templates else a3m_lines | |||||
| def get_null_template(query_sequence: Union[List[str], str], | |||||
| num_temp: int = 1) -> Dict[str, Any]: | |||||
| ln = ( | |||||
| len(query_sequence) if isinstance(query_sequence, str) else sum( | |||||
| len(s) for s in query_sequence)) | |||||
| output_templates_sequence = 'A' * ln | |||||
| # output_confidence_scores = np.full(ln, 1.0) | |||||
| templates_all_atom_positions = np.zeros( | |||||
| (ln, templates.residue_constants.atom_type_num, 3)) | |||||
| templates_all_atom_masks = np.zeros( | |||||
| (ln, templates.residue_constants.atom_type_num)) | |||||
| templates_aatype = templates.residue_constants.sequence_to_onehot( | |||||
| output_templates_sequence, | |||||
| templates.residue_constants.HHBLITS_AA_TO_ID) | |||||
| template_features = { | |||||
| 'template_all_atom_positions': | |||||
| np.tile(templates_all_atom_positions[None], [num_temp, 1, 1, 1]), | |||||
| 'template_all_atom_masks': | |||||
| np.tile(templates_all_atom_masks[None], [num_temp, 1, 1]), | |||||
| 'template_sequence': ['none'.encode()] * num_temp, | |||||
| 'template_aatype': | |||||
| np.tile(np.array(templates_aatype)[None], [num_temp, 1, 1]), | |||||
| 'template_domain_names': ['none'.encode()] * num_temp, | |||||
| 'template_sum_probs': | |||||
| np.zeros([num_temp], dtype=np.float32), | |||||
| } | |||||
| return template_features | |||||
| def get_template(a3m_lines: str, template_path: str, | |||||
| query_sequence: str) -> Dict[str, Any]: | |||||
| template_featurizer = templates.HhsearchHitFeaturizer( | |||||
| mmcif_dir=template_path, | |||||
| max_template_date='2100-01-01', | |||||
| max_hits=20, | |||||
| kalign_binary_path='kalign', | |||||
| release_dates_path=None, | |||||
| obsolete_pdbs_path=None, | |||||
| ) | |||||
| hhsearch_pdb70_runner = hhsearch.HHSearch( | |||||
| binary_path='hhsearch', databases=[f'{template_path}/pdb70']) | |||||
| hhsearch_result = hhsearch_pdb70_runner.query(a3m_lines) | |||||
| hhsearch_hits = pipeline.parsers.parse_hhr(hhsearch_result) | |||||
| templates_result = template_featurizer.get_templates( | |||||
| query_sequence=query_sequence, hits=hhsearch_hits) | |||||
| return dict(templates_result.features) | |||||
| @PREPROCESSORS.register_module( | |||||
| Fields.science, module_name=Preprocessors.unifold_preprocessor) | |||||
| class UniFoldPreprocessor(Preprocessor): | |||||
| def __init__(self, **cfg): | |||||
| self.symmetry_group = cfg.get('symmetry_group', None)  # e.g. 'C1' | |||||
| if not self.symmetry_group: | |||||
| self.symmetry_group = None | |||||
| self.MIN_SINGLE_SEQUENCE_LENGTH = 16 # TODO: change to cfg | |||||
| self.MAX_SINGLE_SEQUENCE_LENGTH = 1000 | |||||
| self.MAX_MULTIMER_LENGTH = 1000 | |||||
| self.jobname = 'unifold' | |||||
| self.output_dir_base = './unifold-predictions' | |||||
| os.makedirs(self.output_dir_base, exist_ok=True) | |||||
| def clean_and_validate_sequence(self, input_sequence: str, min_length: int, | |||||
| max_length: int) -> str: | |||||
| clean_sequence = input_sequence.translate( | |||||
| str.maketrans('', '', ' \n\t')).upper() | |||||
| aatypes = set(residue_constants.restypes) # 20 standard aatypes. | |||||
| if not set(clean_sequence).issubset(aatypes): | |||||
| raise ValueError( | |||||
| f'Input sequence contains non-amino acid letters: ' | |||||
| f'{set(clean_sequence) - aatypes}. AlphaFold only supports 20 standard ' | |||||
| 'amino acids as inputs.') | |||||
| if len(clean_sequence) < min_length: | |||||
| raise ValueError( | |||||
| f'Input sequence is too short: {len(clean_sequence)} amino acids, ' | |||||
| f'while the minimum is {min_length}') | |||||
| if len(clean_sequence) > max_length: | |||||
| raise ValueError( | |||||
| f'Input sequence is too long: {len(clean_sequence)} amino acids, while ' | |||||
| f'the maximum is {max_length}. You may be able to run it with the full ' | |||||
| f'Uni-Fold system depending on your resources (system memory, ' | |||||
| f'GPU memory).') | |||||
| return clean_sequence | |||||
| def validate_input(self, input_sequences: Sequence[str], | |||||
| symmetry_group: str, min_length: int, max_length: int, | |||||
| max_multimer_length: int) -> Tuple[Sequence[str], bool]: | |||||
| """Validates and cleans input sequences and determines which model to use.""" | |||||
| sequences = [] | |||||
| for input_sequence in input_sequences: | |||||
| if input_sequence.strip(): | |||||
| input_sequence = self.clean_and_validate_sequence( | |||||
| input_sequence=input_sequence, | |||||
| min_length=min_length, | |||||
| max_length=max_length) | |||||
| sequences.append(input_sequence) | |||||
| if symmetry_group is not None and symmetry_group != 'C1': | |||||
| if symmetry_group.startswith( | |||||
| 'C') and symmetry_group[1:].isnumeric(): | |||||
| print( | |||||
| f'Using UF-Symmetry with group {symmetry_group}. If you do not ' | |||||
| f'want to use UF-Symmetry, please use `C1` and copy the AU ' | |||||
| f'sequences to the count in the assembly.') | |||||
| is_multimer = (len(sequences) > 1) | |||||
| return sequences, is_multimer, symmetry_group | |||||
| else: | |||||
| raise ValueError( | |||||
| f'UF-Symmetry does not support symmetry group ' | |||||
| f'{symmetry_group} currently. Cyclic groups (Cx) are ' | |||||
| f'supported only.') | |||||
| elif len(sequences) == 1: | |||||
| print('Using the single-chain model.') | |||||
| return sequences, False, None | |||||
| elif len(sequences) > 1: | |||||
| total_multimer_length = sum([len(seq) for seq in sequences]) | |||||
| if total_multimer_length > max_multimer_length: | |||||
| raise ValueError( | |||||
| f'The total length of multimer sequences is too long: ' | |||||
| f'{total_multimer_length}, while the maximum is ' | |||||
| f'{max_multimer_length}. Please use the full AlphaFold ' | |||||
| f'system for long multimers.') | |||||
| print(f'Using the multimer model with {len(sequences)} sequences.') | |||||
| return sequences, True, None | |||||
| else: | |||||
| raise ValueError( | |||||
| 'No input amino acid sequence provided, please provide at ' | |||||
| 'least one sequence.') | |||||
| def add_hash(self, x, y): | |||||
| return x + '_' + hashlib.sha1(y.encode()).hexdigest()[:5] | |||||
| def get_msa_and_templates( | |||||
| self, | |||||
| jobname: str, | |||||
| query_seqs_unique: Union[str, List[str]], | |||||
| result_dir: Path, | |||||
| msa_mode: str, | |||||
| use_templates: bool, | |||||
| homooligomers_num: int = 1, | |||||
| host_url: str = DEFAULT_API_SERVER, | |||||
| ) -> Tuple[List[str], List[str], List[Dict[str, Any]]]: | |||||
| use_env = msa_mode == 'MMseqs2' | |||||
| template_features = [] | |||||
| if use_templates: | |||||
| a3m_lines_mmseqs2, template_paths = run_mmseqs2( | |||||
| query_seqs_unique, | |||||
| str(result_dir.joinpath(jobname)), | |||||
| use_env, | |||||
| use_templates=True, | |||||
| host_url=host_url, | |||||
| ) | |||||
| if template_paths is None: | |||||
| for index in range(0, len(query_seqs_unique)): | |||||
| template_feature = get_null_template( | |||||
| query_seqs_unique[index]) | |||||
| template_features.append(template_feature) | |||||
| else: | |||||
| for index in range(0, len(query_seqs_unique)): | |||||
| if template_paths[index] is not None: | |||||
| template_feature = get_template( | |||||
| a3m_lines_mmseqs2[index], | |||||
| template_paths[index], | |||||
| query_seqs_unique[index], | |||||
| ) | |||||
| if len(template_feature['template_domain_names']) == 0: | |||||
| template_feature = get_null_template( | |||||
| query_seqs_unique[index]) | |||||
| else: | |||||
| template_feature = get_null_template( | |||||
| query_seqs_unique[index]) | |||||
| template_features.append(template_feature) | |||||
| else: | |||||
| for index in range(0, len(query_seqs_unique)): | |||||
| template_feature = get_null_template(query_seqs_unique[index]) | |||||
| template_features.append(template_feature) | |||||
| if msa_mode == 'single_sequence': | |||||
| a3m_lines = [] | |||||
| num = 101 | |||||
| for i, seq in enumerate(query_seqs_unique): | |||||
| a3m_lines.append('>' + str(num + i) + '\n' + seq) | |||||
| else: | |||||
| # find normal a3ms | |||||
| a3m_lines = run_mmseqs2( | |||||
| query_seqs_unique, | |||||
| str(result_dir.joinpath(jobname)), | |||||
| use_env, | |||||
| use_pairing=False, | |||||
| host_url=host_url, | |||||
| ) | |||||
| if len(query_seqs_unique) > 1: | |||||
| # find paired a3m if not a homooligomers | |||||
| paired_a3m_lines = run_mmseqs2( | |||||
| query_seqs_unique, | |||||
| str(result_dir.joinpath(jobname)), | |||||
| use_env, | |||||
| use_pairing=True, | |||||
| host_url=host_url, | |||||
| ) | |||||
| else: | |||||
| num = 101 | |||||
| paired_a3m_lines = [] | |||||
| for i in range(0, homooligomers_num): | |||||
| paired_a3m_lines.append('>' + str(num + i) + '\n' | |||||
| + query_seqs_unique[0] + '\n') | |||||
| return ( | |||||
| a3m_lines, | |||||
| paired_a3m_lines, | |||||
| template_features, | |||||
| ) | |||||
| def __call__(self, data: Union[str, Tuple]): | |||||
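| # Accept either a single sequence string or a tuple of chain sequences; empty entries are dropped during validation. | |||||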
| if isinstance(data, str): | |||||
| data = [data, '', '', ''] | |||||
| basejobname = ''.join(data) | |||||
| basejobname = re.sub(r'\W+', '', basejobname) | |||||
| target_id = self.add_hash(self.jobname, basejobname) | |||||
| sequences, is_multimer, _ = self.validate_input( | |||||
| input_sequences=data, | |||||
| symmetry_group=self.symmetry_group, | |||||
| min_length=self.MIN_SINGLE_SEQUENCE_LENGTH, | |||||
| max_length=self.MAX_SINGLE_SEQUENCE_LENGTH, | |||||
| max_multimer_length=self.MAX_MULTIMER_LENGTH) | |||||
| descriptions = [ | |||||
| '> ' + target_id + ' seq' + str(ii) | |||||
| for ii in range(len(sequences)) | |||||
| ] | |||||
| if is_multimer: | |||||
| divide_multi_chains(target_id, self.output_dir_base, sequences, | |||||
| descriptions) | |||||
| s = [] | |||||
| for des, seq in zip(descriptions, sequences): | |||||
| s += [des, seq] | |||||
| unique_sequences = [] | |||||
| for seq in sequences: | |||||
| if seq not in unique_sequences: | |||||
| unique_sequences.append(seq) | |||||
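| # A single unique sequence shared by all chains is a homo-oligomer; the copy count is used when building the paired MSA placeholders. | |||||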
| if len(unique_sequences) == 1: | |||||
| homooligomers_num = len(sequences) | |||||
| else: | |||||
| homooligomers_num = 1 | |||||
| with open(f'{self.jobname}.fasta', 'w') as f: | |||||
| f.write('\n'.join(s)) | |||||
| result_dir = Path(self.output_dir_base) | |||||
| output_dir = os.path.join(self.output_dir_base, target_id) | |||||
| # msa_mode = 'single_sequence' | |||||
| msa_mode = 'MMseqs2' | |||||
| use_templates = True | |||||
| unpaired_msa, paired_msa, template_results = self.get_msa_and_templates( | |||||
| target_id, | |||||
| unique_sequences, | |||||
| result_dir=result_dir, | |||||
| msa_mode=msa_mode, | |||||
| use_templates=use_templates, | |||||
| homooligomers_num=homooligomers_num) | |||||
| features = [] | |||||
| pair_features = [] | |||||
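| # Build per-chain feature dicts (sequence + MSA + template features), compress them and cache them as gzipped pickles. | |||||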
| for idx, seq in enumerate(unique_sequences): | |||||
| chain_id = PDB_CHAIN_IDS[idx] | |||||
| sequence_features = pipeline.make_sequence_features( | |||||
| sequence=seq, | |||||
| description=f'> {self.jobname} seq {chain_id}', | |||||
| num_res=len(seq)) | |||||
| monomer_msa = parsers.parse_a3m(unpaired_msa[idx]) | |||||
| msa_features = pipeline.make_msa_features([monomer_msa]) | |||||
| template_features = template_results[idx] | |||||
| feature_dict = { | |||||
| **sequence_features, | |||||
| **msa_features, | |||||
| **template_features | |||||
| } | |||||
| feature_dict = compress_features(feature_dict) | |||||
| features_output_path = os.path.join( | |||||
| output_dir, '{}.feature.pkl.gz'.format(chain_id)) | |||||
| pickle.dump( | |||||
| feature_dict, | |||||
| gzip.GzipFile(features_output_path, 'wb'), | |||||
| protocol=4) | |||||
| features.append(feature_dict) | |||||
| if is_multimer: | |||||
| multimer_msa = parsers.parse_a3m(paired_msa[idx]) | |||||
| # use a distinct name so the per-chain `pair_features` list is not overwritten | |||||
| pair_msa_features = pipeline.make_msa_features([multimer_msa]) | |||||
| pair_feature_dict = compress_features(pair_msa_features) | |||||
| uniprot_output_path = os.path.join( | |||||
| output_dir, '{}.uniprot.pkl.gz'.format(chain_id)) | |||||
| pickle.dump( | |||||
| pair_feature_dict, | |||||
| gzip.GzipFile(uniprot_output_path, 'wb'), | |||||
| protocol=4, | |||||
| ) | |||||
| pair_features.append(pair_feature_dict) | |||||
| # return features, pair_features, target_id | |||||
| return { | |||||
| 'features': features, | |||||
| 'pair_features': pair_features, | |||||
| 'target_id': target_id, | |||||
| 'is_multimer': is_multimer, | |||||
| } | |||||
| if __name__ == '__main__': | |||||
| proc = UniFoldPreprocessor(symmetry_group=None) | |||||
| protein_example = 'LILNLRGGAFVSNTQITMADKQKKFINEIQEGDLVRSYSITDETFQQNAVTSIVKHEADQLCQINFGKQHVVC' + \ | |||||
| 'TVNHRFYDPESKLWKSVCPHPGSGISFLKKYDYLLSEEGEKLQITEIKTFTTKQPVFIYHIQVENNHNFFANGVLAHAMQVSI' | |||||
| outputs = proc(protein_example) | |||||
| print(outputs['target_id'], outputs['is_multimer'], len(outputs['features'])) | |||||
| @@ -9,6 +9,7 @@ class Fields(object): | |||||
| nlp = 'nlp' | nlp = 'nlp' | ||||
| audio = 'audio' | audio = 'audio' | ||||
| multi_modal = 'multi-modal' | multi_modal = 'multi-modal' | ||||
| science = 'science' | |||||
| class CVTasks(object): | class CVTasks(object): | ||||
| @@ -151,6 +152,10 @@ class MultiModalTasks(object): | |||||
| image_text_retrieval = 'image-text-retrieval' | image_text_retrieval = 'image-text-retrieval' | ||||
| class ScienceTasks(object): | |||||
| protein_structure = 'protein-structure' | |||||
| class TasksIODescriptions(object): | class TasksIODescriptions(object): | ||||
| image_to_image = 'image_to_image', | image_to_image = 'image_to_image', | ||||
| images_to_image = 'images_to_image', | images_to_image = 'images_to_image', | ||||
| @@ -167,7 +172,7 @@ class TasksIODescriptions(object): | |||||
| generative_multi_modal_embedding = 'generative_multi_modal_embedding' | generative_multi_modal_embedding = 'generative_multi_modal_embedding' | ||||
| class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||||
| class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks, ScienceTasks): | |||||
| """ Names for tasks supported by modelscope. | """ Names for tasks supported by modelscope. | ||||
| Holds the standard task name to use for identifying different tasks. | Holds the standard task name to use for identifying different tasks. | ||||
| @@ -196,6 +201,10 @@ class Tasks(CVTasks, NLPTasks, AudioTasks, MultiModalTasks): | |||||
| getattr(Tasks, attr) for attr in dir(MultiModalTasks) | getattr(Tasks, attr) for attr in dir(MultiModalTasks) | ||||
| if not attr.startswith('__') | if not attr.startswith('__') | ||||
| ], | ], | ||||
| Fields.science: [ | |||||
| getattr(Tasks, attr) for attr in dir(ScienceTasks) | |||||
| if not attr.startswith('__') | |||||
| ], | |||||
| } | } | ||||
| for field, tasks in field_dict.items(): | for field, tasks in field_dict.items(): | ||||
| @@ -0,0 +1,6 @@ | |||||
| iopath | |||||
| lmdb | |||||
| ml_collections | |||||
| scipy | |||||
| tensorboardX | |||||
| tokenizers | |||||
| @@ -0,0 +1,34 @@ | |||||
| # Copyright (c) Alibaba, Inc. and its affiliates. | |||||
| import unittest | |||||
| from modelscope.hub.snapshot_download import snapshot_download | |||||
| from modelscope.pipelines import pipeline | |||||
| from modelscope.utils.constant import Tasks | |||||
| from modelscope.utils.demo_utils import DemoCompatibilityCheck | |||||
| from modelscope.utils.test_utils import test_level | |||||
| class UnifoldProteinStructureTest(unittest.TestCase, DemoCompatibilityCheck): | |||||
| def setUp(self) -> None: | |||||
| self.task = Tasks.protein_structure | |||||
| self.model_id = 'DPTech/uni-fold-monomer' | |||||
| self.model_id_multimer = 'DPTech/uni-fold-multimer' | |||||
| self.protein = 'MGLPKKALKESQLQFLTAGTAVSDSSHQTYKVSFIENGVIKNAFYKKLDPKNHYPELLAKISVAVSLFKRIFQGRRSAEERLVFDD' | |||||
| self.protein_multimer = 'GAMGLPEEPSSPQESTLKALSLYEAHLSSYIMYLQTFLVKTKQKVNNKNYPEFTLFDTSKLKKDQTLKSIKT' + \ | |||||
| 'NIAALKNHIDKIKPIAMQIYKKYSKNIP' | |||||
| @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') | |||||
| def test_run_by_direct_model_download(self): | |||||
| model_dir = snapshot_download(self.model_id) | |||||
| mono_pipeline_ins = pipeline(task=self.task, model=model_dir) | |||||
| _ = mono_pipeline_ins(self.protein) | |||||
| model_dir1 = snapshot_download(self.model_id_multimer) | |||||
| multi_pipeline_ins = pipeline(task=self.task, model=model_dir1) | |||||
| _ = multi_pipeline_ins(self.protein_multimer) | |||||
| if __name__ == '__main__': | |||||
| unittest.main() | |||||