Skip to content

Commit

Permalink
fix issue with re-adding unpaired seqs
Browse files Browse the repository at this point in the history
  • Loading branch information
psathyrella committed Jan 28, 2024
1 parent 21b2cfe commit 2f3f1c8
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 14 deletions.
4 changes: 2 additions & 2 deletions bin/partis
Original file line number Diff line number Diff line change
Expand Up @@ -805,14 +805,14 @@ def run_all_loci(args, ig_or_tr='ig'):
hloc = utils.heavy_locus(ig_or_tr)
if len(lpfos['glfos']) == 0:
return lpfos
tmp_ptn = utils.get_deduplicated_partitions([lpfos['cpaths'][hloc].best()], antn_list=lpfos['antn_lists'][hloc] if not args.dont_calculate_annotations else None, glfo=lpfos['glfos'][hloc], debug=args.debug)[0] # have to remove duplicates from heavy partitions and annotations (since seqs that we don't have good pairing info for get put in both light chain dirs, so appear twice in concat'd heavy chain)
tmp_ptn = utils.get_deduplicated_partitions([lpfos['cpaths'][hloc].best()], antn_list=lpfos['antn_lists'][hloc] if not args.dont_calculate_annotations else None, glfo=lpfos['glfos'][hloc], debug=args.debug_paired_clustering)[0] # have to remove duplicates from heavy partitions and annotations (since seqs that we don't have good pairing info for get put in both light chain dirs, so appear twice in concat'd heavy chain)
lpfos['cpaths'][hloc] = ClusterPath(partition=tmp_ptn, seed_unique_id=seedid(hloc))
return lpfos
# ----------------------------------------------------------------------------------------
def combine_simu_chains():
lp_infos = {}
for lpair in spairs():
lpfos = paircluster.read_locus_output_files(lpair, getofn, lpair=lpair, dont_add_implicit_info=not args.debug_paired_clustering, dbgstr='simulation', debug=args.debug)
lpfos = paircluster.read_locus_output_files(lpair, getofn, lpair=lpair, dont_add_implicit_info=not args.debug_paired_clustering, dbgstr='simulation', debug=args.debug_paired_clustering)
lp_infos[tuple(lpair)] = lpfos
if None in list(lpfos.values()):
continue
Expand Down
35 changes: 23 additions & 12 deletions python/paircluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,43 +864,48 @@ def print_cf_ham(iun, sorted_hdists): # this is just to print some dbg to compa
sys.stdout.flush()
if debug:
print(' removing bad/un-paired seqs')
print(' N N no other non-')
print(' before removed info light recip')
print(' N N no other non- original')
print(' before removed info light recip cluster')
for tch in sorted(ploci):
new_partition, new_antn_list = [], []
for iclust, cluster in enumerate(cpaths[ploci[tch]].best()):
cline = antn_dicts[ploci[tch]][':'.join(cluster)]
paired_iseqs = [i for i, pds in enumerate(cline['paired-uids']) if len(pds) > 0]
iseqs_to_remove = []
n_no_info, n_other_light, n_non_reciprocal = 0, 0, 0 # just for dbg NOTE n_other_light are the only ones we *really* want to remove, since they're h seqs paired with the other light chain, whereas the other two categories we eventually want to re-add since we're not sure who they're paired with
unpaired_to_add = []
for iseq, uid in enumerate(cline['unique_ids']):
pids = cline['paired-uids'][iseq]
if len(pids) == 0: # no pairing info
iseqs_to_remove.append(iseq)
add_unpaired(cline, iseq, uid, paired_iseqs)
unpaired_to_add.append((iseq, uid))
n_no_info += 1
elif len(pids) > 1: # should've all been removed by pair info cleaning
raise Exception('multiple paired uids for \'%s\': %s' % (uid, pids))
else:
# print(' ', uid, tch, pids, all_loci[utils.get_single_entry(pids)], ploci['l'])
if tch == 'h' and all_loci[utils.get_single_entry(pids)] != ploci['l']: # if it's the other light chain
iseqs_to_remove.append(iseq)
n_other_light += 1
else: # also remove any non-reciprocal pairings (I think this will still miss any whose partner was removed) NOTE it would be nice to enforce reciprocal pairings in pair info cleaning, but atm i think we can't look at both chains at once in that fcn UPDATE i think we do this now
if all_pids[uid] not in all_pids or all_pids[all_pids[uid]] != uid: # if uid's pid isn't in all_pids, or if it is but it's a different uid
iseqs_to_remove.append(iseq)
add_unpaired(cline, iseq, uid, paired_iseqs)
unpaired_to_add.append((iseq, uid))
n_non_reciprocal += 1
if n_no_info + n_other_light < len(cluster): # if there's seqs in the cluster that *aren't* either unpaired (n_no_info) or paired with the other light chain (n_other_light), we want to add the unpaired ones as unpaired seqs (i.e. if n_no_info + n_other_light *equals* len(cluster), we *don't* want to add them, since they should get added in/during the other light chain, and if they're added to both, they'll get deduplicated poorly later)
for upi, upid in unpaired_to_add:
add_unpaired(cline, upi, upid, paired_iseqs)
iseqs_to_keep = [i for i in range(len(cline['unique_ids'])) if i not in iseqs_to_remove]
if len(cline) > 1:
process_unpaired(cline, ploci[tch], iseqs_to_keep, paired_iseqs) # have to go back after finishing cluster since only now do we know who we ended up keeping
process_unpaired(cline, ploci[tch], iseqs_to_keep, paired_iseqs) # have to go back after finishing cluster since only now do we know who we ended up keeping
if len(iseqs_to_keep) > 0:
new_partition.append([cluster[i] for i in iseqs_to_keep])
new_cline = utils.get_non_implicit_copy(cline)
utils.restrict_to_iseqs(new_cline, iseqs_to_keep, glfos[ploci[tch]]) # note that this calls utils.add_implicit_info()
new_antn_list.append(new_cline)
if debug:
def fstr(v): return '' if v==0 else '%d'%v
print(' %s %3d %3s %3s %3s %3s' % (utils.locstr(ploci[tch]) if iclust==0 else ' ', len(cline['unique_ids']), fstr(len(iseqs_to_remove)), fstr(n_no_info), fstr(n_other_light), fstr(n_non_reciprocal)))
nrstr = utils.color('red' if len(iseqs_to_keep)==0 else None, fstr(len(iseqs_to_remove)), width=3)
print(' %s %3d %s %3s %3s %3s %s' % (utils.locstr(ploci[tch]) if iclust==0 else ' ', len(cline['unique_ids']), nrstr, fstr(n_no_info), fstr(n_other_light), fstr(n_non_reciprocal), ' '.join(cluster)))
lp_cpaths[ploci[tch]] = ClusterPath(seed_unique_id=cpaths[ploci[tch]].seed_unique_id, partition=new_partition)
lp_antn_lists[ploci[tch]] = new_antn_list

Expand Down Expand Up @@ -1640,9 +1645,10 @@ def get_rc_indices(): # separate fcn just to make profiling easier (should mayb
for fid in final_partition[iftmp]:
fclust_indices[fid] = iftmp
# ----------------------------------------------------------------------------------------
def re_add_unpaired(joint_partitions, unpaired_seqs):
if huge_dbg:
def re_add_unpaired(joint_partitions, unpaired_seqs, tdbg=False):
if huge_dbg or tdbg:
print(' re-adding %d unpaired' % sum(len(lseqs) for lseqs in unpaired_seqs.values()))
if huge_dbg:
print(' finished:', end=' ')
sys.stdout.flush()
ihuge = 0
Expand Down Expand Up @@ -1672,14 +1678,19 @@ def re_add_unpaired(joint_partitions, unpaired_seqs):
joint_partitions[tch][ijclusts[0]].append(upid)
jp_indices[upid] = ijclusts[0]
n_added[tch]['existing-cluster'] += 1
if tdbg:
print(' %s joint clusters with re-added unpaired seqs in %s:' % (utils.locstr(ltmp), utils.color('red', 'red')))
for ijoint, uipairs in itertools.groupby(sorted(jp_indices.items(), key=operator.itemgetter(1)), key=operator.itemgetter(1)):
print(' %s' % ' '.join(utils.color('red' if u in unpaired_seqs[ltmp] else None, u) for u in joint_partitions[tch][ijoint]))
if huge_dbg:
print('')
totstr = ' '.join('%s %d'%(utils.locstr(ploci[tch]), sum(len(c) for c in joint_partitions[tch])) for tch in sorted(ploci))
print(' re-added unpaired seqs (%s) to give total seqs in joint partitions: %s' % (', '.join('%s %d'%(utils.locstr(ploci[tch]), sum(nfo.values())) for tch, nfo in n_added.items()), totstr))
sys.stdout.flush()
# print ' singleton new cluster existing cluster'
# print(' new existing')
# print(' singleton cluster cluster')
# for tch in 'hl':
# print ' %s %4d %4d %4d' % (utils.locstr(ploci[tch]), n_added[tch]['singleton'], n_added[tch]['new-cluster'], n_added[tch]['existing-cluster'])
# print(' %s %4d %4d %4d' % (utils.locstr(ploci[tch]), n_added[tch]['singleton'], n_added[tch]['new-cluster'], n_added[tch]['existing-cluster']))
# ----------------------------------------------------------------------------------------
print(' merging %s partitions' % '+'.join(list(ploci.values())))
sys.stdout.flush()
Expand Down Expand Up @@ -1781,7 +1792,7 @@ def getcstr(clist): return ' '.join(str(len(c)) for c in clist)
untranslate_pids(ploci, init_partitions, antn_lists, l_translations, joint_partitions, antn_dict) # NOTE code after here (at least randomly_pair_unpaired_seqs()) assumes that corresponding h/l clusters are in same order in each partition

if unpaired_seqs is not None: # it might be cleaner to have this elsewhere, but I want it to happen before we evaluate, and it's also nice to have evaluation in here
re_add_unpaired(joint_partitions, unpaired_seqs)
re_add_unpaired(joint_partitions, unpaired_seqs, tdbg=debug)

ccfs = None
if true_outfos is not None:
Expand Down

0 comments on commit 2f3f1c8

Please sign in to comment.