Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
YoelShoshan committed Feb 14, 2024
1 parent fcceb44 commit 2b8fdca
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions fusedrug/data/protein/structure/flexible_align_chains_structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jsonargparse import CLI
from typing import List, Union, Dict, Tuple
from typing import List, Union, Dict, Tuple, Optional
from Bio import Align
from tiny_openfold.utils.superimposition import superimpose

Expand All @@ -19,7 +19,8 @@ def flexible_align_chains_structure(
apply_rigid_transformation_to_dynamic_chain_ids: Union[List[Tuple], str],
static_ordered_chains: Union[List[Tuple], str],
output_pdb_filename_extentionless: str,
minimal_matching_sequence_level_chunk: int = 8,
minimal_matching_sequence_level_chunk: Optional[int] = 8,
backbone_only_based: bool = False,
###chain_id_type:str = "author_assigned",
) -> None:
"""
Expand Down Expand Up @@ -117,9 +118,15 @@ def flexible_align_chains_structure(
static_matching["atom14_gt_exists"].astype(bool),
)
# orig_atom_pos_shape = dynamic_matching["atom14_gt_positions"].shape
use_for_static = static_matching["atom14_gt_positions"]
use_for_dynamic = dynamic_matching["atom14_gt_positions"]
if backbone_only_based:
use_for_static = use_for_static[:, :4, ...]
use_for_dynamic = use_for_dynamic[:, :4, ...]

_, rmsd, rot_matrix, trans_matrix = superimpose(
static_matching["atom14_gt_positions"].reshape(-1, 3),
dynamic_matching["atom14_gt_positions"].reshape(-1, 3),
use_for_static.reshape(-1, 3),
use_for_dynamic.reshape(-1, 3),
combined_mask.reshape(-1),
verbose=True,
)
Expand Down Expand Up @@ -199,7 +206,9 @@ def _apply_indices(x: Dict, indices: np.ndarray) -> Tuple[str, np.ndarray]:


def get_alignment_indices(
target: str, query: str, minimal_matching_sequence_level_chunk: int
target: str,
query: str,
minimal_matching_sequence_level_chunk: Optional[int],
) -> Tuple[np.ndarray, np.ndarray]:
aligner = Align.PairwiseAligner()

Expand All @@ -214,7 +223,9 @@ def get_alignment_indices(
query_indices = []

for (target_start, target_end), (query_start, query_end) in zip(*alignment.aligned):
if target_end - target_start >= minimal_matching_sequence_level_chunk:
if (minimal_matching_sequence_level_chunk is None) or (
target_end - target_start >= minimal_matching_sequence_level_chunk
):
target_indices.extend(list(range(target_start, target_end)))
query_indices.extend(list(range(query_start, query_end)))

Expand Down

0 comments on commit 2b8fdca

Please sign in to comment.