Source code for gratools.Graph

# Standard library imports
import asyncio
import gzip
import logging
import sqlite3
import subprocess
import tempfile
import re
from collections import Counter, OrderedDict, defaultdict, namedtuple
from concurrent.futures import ThreadPoolExecutor, wait, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Set

# Third-party imports
import aiofiles
import aiosqlite
from numpy import mean, median, std, histogram
from pandas import set_option, DataFrame, merge
import uvloop
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pybedtools import BedTool, cleanup
from pysam import AlignedSegment, AlignmentFile, index, qualitystring_to_array, view
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.table import Table
from rich import box
from natsort import natsorted

# Local application/library specific imports
from .logger_config import shared_console  # Assuming shared_console is used for rich printing
from .useful_function import RC_TRANS, reverse_complement_string
# pandas: render floating-point values with two decimal places in printed frames.
set_option("display.precision", 2)

# asyncio: from this import onward, new event loops are created by uvloop's policy.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

LinkInfo = namedtuple(
    "LinkInfo",
    (
        "seg_id_1",
        "orient_seg_1",
        "orient_key_seg_1",
        "seg_id_2",
        "orient_seg_2",
        "orient_key_seg_2",
    ),
)
# A string literal placed after the assignment is only an expression statement and is
# not attached to the class at runtime (help(LinkInfo) would show the auto-generated
# "LinkInfo(seg_id_1, ...)" text), so the docstring is assigned explicitly.
LinkInfo.__doc__ = """Information about a link between two segments.

Attributes
----------
seg_id_1 : str
    Identifier of the first segment.
orient_seg_1 : int
    Orientation of the first segment (+1 or -1).
orient_key_seg_1 : int
    Orientation key for the first segment (often the same as orient_seg_1).
seg_id_2 : str
    Identifier of the second segment.
orient_seg_2 : int
    Orientation of the second segment (+1 or -1).
orient_key_seg_2 : int
    Orientation key for the second segment (often the same as orient_seg_2).
"""


@dataclass
class SubGraph:
    """
    A subgraph of a pangenome graph for one sample (BED/BAM pair) and one or more regions.

    It intersects the sample's BED file with query regions, builds GFA walk (W),
    link (L) and segment (S) lines, and can turn the walks into FASTA records.

    Attributes
    ----------
    bam_path : Path
        The file path to the BAM file holding segment sequences.
    bed_path : Path
        The file path to the BED file (``<sample>.bed`` or ``<sample>.bed.gz``).
    logger : logging.Logger
        Logger instance. Defaults to the logger named "GraTools".
    sample_name : Optional[str]
        The sample name, derived from ``bed_path`` in ``__post_init__``.
    sample_name_query : Optional[str]
        The name of the sample being queried. Defaults to None.
    chromosome_query : Optional[str]
        The chromosome name of the query. Defaults to None.
    start_query : Optional[int]
        The query start position on the chromosome. Defaults to None.
    stop_query : Optional[int]
        The query stop position on the chromosome. Defaults to None.
    offset_first : int
        Bases trimmed from the start of a query walk (recomputed per walk in ``build_fasta``).
    offset_last : int
        Bases trimmed from the end of a query walk (recomputed per walk in ``build_fasta``).
    add_start_bases_first_segment : int
        Not read by any method of this class. Defaults to 0.
    intersect_bed : Optional[BedTool]
        Not read by any method of this class. Defaults to None.
    segment_id_set : Set[str]
        Segment IDs (without orientation) seen while building walks.
    segment_id_first_query / segment_id_last_query : Optional[str]
        First / last segment ID seen for the query sample.
    segment_id_first_strand / segment_id_last_strand : Optional[str]
        Orientation character ('+' or '-') of those segments.
    segment_id_first / segment_id_last : Optional[str]
        Mirrors of the query first/last segment IDs.
    works_path : Optional[Path]
        Working directory for intermediate files. Defaults to None.
    merge : Optional[int]
        Distance passed as ``d`` to the second ``bedtools merge``. Defaults to None.
    build_fasta_flag : Optional[bool]
        Whether FASTA sequences should be built. Defaults to False.
    gfa_walk_list / gfa_link_list / gfa_segment_list : List[str]
        Generated GFA W / L / S lines.
    dict_segments_samples : defaultdict
        Segment ID -> list of sample identifiers (filled by ``build_segments``).
    dict_segments_sequence : defaultdict
        Segment ID -> sequence (filled by ``build_segments``).
    sequences_list : List[SeqRecord]
        FASTA records produced by ``build_fasta``.
    progress_dict : Optional[Dict]
        Shared dictionary used to publish progress. Defaults to None.
    task_id : Optional[TaskID]
        Key under which progress is published in ``progress_dict``. Defaults to None.
    regions : Optional[List[Dict[str, Any]]]
        Regions as dicts with 'chromosome', 'start' and 'stop' keys. Defaults to None.
    intersected_results_by_regions : Optional[BedTool]
        Combined intersection of the BED file with all regions. Defaults to None.
    """

    bam_path: Path
    bed_path: Path
    logger: logging.Logger = field(default_factory=lambda: logging.getLogger("GraTools"))
    sample_name: Optional[str] = None
    sample_name_query: Optional[str] = None
    chromosome_query: Optional[str] = None
    start_query: Optional[int] = None
    stop_query: Optional[int] = None
    offset_first: int = 0
    offset_last: int = 0
    add_start_bases_first_segment: int = 0
    intersect_bed: Optional[BedTool] = None
    segment_id_set: Set[str] = field(default_factory=set, repr=False)
    segment_id_first_query: Optional[str] = None
    segment_id_first_strand: Optional[str] = None
    segment_id_last_query: Optional[str] = None
    segment_id_last_strand: Optional[str] = None
    segment_id_first: Optional[str] = None  # Mirrors segment_id_first_query
    segment_id_last: Optional[str] = None  # Mirrors segment_id_last_query
    works_path: Optional[Path] = None
    merge: Optional[int] = None
    build_fasta_flag: Optional[bool] = False
    gfa_walk_list: List[str] = field(default_factory=list, repr=False)
    gfa_link_list: List[str] = field(default_factory=list, repr=False)
    gfa_segment_list: List[str] = field(default_factory=list, repr=False)
    dict_segments_samples: defaultdict = field(default_factory=lambda: defaultdict(list), repr=False)
    dict_segments_sequence: defaultdict = field(default_factory=lambda: defaultdict(str), repr=False)
    sequences_list: List[SeqRecord] = field(default_factory=list, repr=False)
    progress_dict: Optional[Dict] = None
    task_id: Optional[TaskID] = None
    regions: Optional[List[Dict[str, Any]]] = None
    intersected_results_by_regions: Optional[BedTool] = None

    def __post_init__(self) -> None:
        """Derive ``sample_name`` from the BED file name (``<sample>.bed`` or ``<sample>.bed.gz``)."""
        file_name = self.bed_path.name
        # Match the real ending of the name: a substring test on the joined suffixes would
        # also accept names such as "x.bedgraph" and leave the extension in sample_name.
        if file_name.endswith(".bed.gz"):
            self.sample_name = file_name.removesuffix(".bed.gz")
        elif file_name.endswith(".bed"):
            self.sample_name = file_name.removesuffix(".bed")
        else:
            self.logger.error(f"The file {self.bed_path} is not a recognized BED file type.")

    def _report_progress(self, progress: int, description: str, total: int = 7) -> None:
        """Publish a progress snapshot under ``task_id`` when progress tracking is enabled."""
        if self.progress_dict and self.task_id is not None:
            self.progress_dict[self.task_id] = {
                "progress": progress,
                "total": total,
                "description": description,
            }

    def compute_intersection(self) -> None:
        """
        Intersect the sample BED file with every region in ``self.regions``.

        When ``self.regions`` is empty, a single region is built from the query
        chromosome/start/stop. The result (BED columns followed by
        chrom/start/stop/region-index of the matching region) is stored in
        ``self.intersected_results_by_regions``. Errors are logged, not raised.
        """
        if not self.regions:
            self.regions = [
                {
                    "chromosome": self.chromosome_query,
                    "start": self.start_query,
                    "stop": self.stop_query,
                }
            ]

        required_keys = ("chromosome", "start", "stop")
        regions_are_valid = isinstance(self.regions, list) and all(
            isinstance(region, dict) and all(key in region for key in required_keys)
            for region in self.regions
        )
        if not regions_are_valid:
            self.logger.error(
                "Invalid regions provided. Must be a list of dictionaries, "
                "each with 'chromosome', 'start', and 'stop' keys."
            )
            return

        try:
            all_bed = BedTool(self.bed_path.as_posix())
        except Exception as e:  # pybedtools can raise various exception types
            self.logger.error(f"Failed to load BED file '{self.bed_path}': {e}")
            return

        try:
            # One BED record per region; its index in column 4 lets build_walks()
            # tell regions apart in the combined intersection output.
            regions_bed = BedTool(
                "\n".join(
                    f"{reg['chromosome']}\t{reg['start']}\t{reg['stop']}\t{idx}"
                    for idx, reg in enumerate(self.regions)
                ),
                from_string=True,
            )
            # A single intersect call for all regions, persisted to a temporary file.
            intersected = all_bed.intersect(regions_bed, wa=True, wb=True).saveas()
            self.intersected_results_by_regions = intersected
            if intersected.count() == 0:
                # Not fatal by itself, but nothing can be built for this sample.
                self._report_progress(
                    7, f"[red]'{self.sample_name}': No intersection found for region, aborted."
                )
                self.logger.error(f"'{self.sample_name}': No intersection found for region {self.regions}")
        except Exception as e:
            self.logger.error(f"Error during single intersection: {e}")

    @staticmethod
    def _to_gfa_walk_step(oriented_segment_id: str) -> str:
        """Convert '+seg'/'-seg' into the GFA walk step '>seg'/'<seg'.

        Only the leading orientation character is translated, so '+' or '-'
        characters inside the segment ID itself are preserved.
        """
        orientation, segment_id = oriented_segment_id[:1], oriented_segment_id[1:]
        if orientation == "+":
            return f">{segment_id}"
        if orientation == "-":
            return f"<{segment_id}"
        return oriented_segment_id

    def build_walks(self) -> None:
        """Build GFA walks (W lines) and links (L lines) from the intersected BED regions."""
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA walks and links")
        self.segment_id_set = set()  # Reset for the current build
        self._report_progress(5, f"[cyan]'{self.sample_name}': Building walks and links")

        if self.intersected_results_by_regions is None:
            # compute_intersection() failed or was not run; iterating None would raise.
            self.logger.error(
                f"'{self.sample_name}': No intersection results available, skipping walk building."
            )
            return

        # (chromosome, haplotype, region index) -> fragment id -> walk fragment data.
        # A tuple key is used so chromosome names containing ':' stay intact.
        walk_dict = defaultdict(
            lambda: defaultdict(lambda: {"seg_list": [], "start": float("inf"), "stop": float("-inf")})
        )

        previous_region: Optional[int] = None
        region_start = None
        for record in self.intersected_results_by_regions:
            # Columns: sample BED fields first, then the 4 region columns
            # (chrom, start, stop, region index) added by compute_intersection().
            region_idx = int(record[-1])
            if region_idx != previous_region:
                previous_region = region_idx
                region_start = None

            oriented_segment_id = record[3]  # e.g. "+seg1" / "-seg2"
            chromosome_fragment_id = record[4]
            haplotype_index = record[5]
            if region_start is None:
                region_start = record.start

            strand, segment_id = oriented_segment_id[0], oriented_segment_id[1:]
            self.segment_id_set.add(str(segment_id))

            if self.sample_name == self.sample_name_query:
                if not self.segment_id_first_query:
                    self.segment_id_first_query = segment_id
                    self.segment_id_first_strand = strand
                    self.segment_id_first = self.segment_id_first_query
                # Always overwritten so the last record seen wins.
                self.segment_id_last_query = segment_id
                self.segment_id_last_strand = strand
                self.segment_id_last = self.segment_id_last_query

            fragment = walk_dict[(record.chrom, haplotype_index, region_idx)][chromosome_fragment_id]
            fragment["seg_list"].append(oriented_segment_id)
            fragment["start"] = region_start
            fragment["stop"] = record.stop

        for (chr_name, haplotype_idx, _region_idx), fragments in walk_dict.items():
            for data in fragments.values():
                ordered_segments = data["seg_list"]
                self._build_links(ordered_segments)
                gfa_path_string = "".join(self._to_gfa_walk_step(seg) for seg in ordered_segments)
                walk_line_parts = [
                    "W",
                    self.sample_name,
                    haplotype_idx,
                    chr_name,
                    str(data["start"]),
                    str(data["stop"]),
                    gfa_path_string,
                ]
                self.gfa_walk_list.append("\t".join(walk_line_parts))

    def _build_links(self, segments_order: List[str]) -> None:
        """
        Append GFA link lines (L lines) for each consecutive pair of oriented segments.

        Parameters
        ----------
        segments_order : List[str]
            Ordered segment IDs, each prefixed with its orientation
            (e.g. ["+seg1", "-seg2", "+seg3"]). Fewer than two entries yields no link.
        """
        for left, right in zip(segments_order, segments_order[1:]):
            # L <from> <from_orient> <to> <to_orient> <overlap>; overlap is always 0M here.
            self.gfa_link_list.append(f"L\t{left[1:]}\t{left[0]}\t{right[1:]}\t{right[0]}\t0M")

    def build_segments(self) -> None:
        """Build GFA segments (S lines) from the BAM file for segments in ``self.segment_id_set``."""
        self._report_progress(6, f"[cyan]'{self.sample_name}': Building segments")
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA segments")

        if not self.segment_id_set:
            self.logger.warning(f"'{self.sample_name}': No segment IDs for segment lines building. Skipping.")
            self.gfa_segment_list = []
            self.dict_segments_samples = defaultdict(list)
            self.dict_segments_sequence = defaultdict(str)
            return

        bam_handler = GratoolsBam(bam_path=self.bam_path, logger=self.logger)
        (
            self.gfa_segment_list,
            self.dict_segments_samples,
            self.dict_segments_sequence,
        ) = bam_handler.build_segments(list_segments=self.segment_id_set)

    def filter_bed_with_awk(self) -> Tuple[BedTool, Path, Path]:
        """
        Keep only the BED lines whose segment ID (column 4 without its orientation) is in ``segment_id_set``.

        Requires ``works_path``. Two files are written there and left for the caller to delete.

        Returns
        -------
        Tuple[BedTool, Path, Path]
            The filtered BED, its path, and the path of the temporary segment-ID file.

        Raises
        ------
        subprocess.CalledProcessError
            If ``awk`` (or ``gzip`` for ``.gz`` input) exits with a non-zero status.
        """
        self.logger.info(f"'{self.sample_name}': Starting awk filtering for {len(self.segment_id_set)} IDs.")
        segment_id_file = self.works_path / f"{self.sample_name}_segment_ids.txt"
        filtered_bed_path = self.works_path / f"{self.sample_name}_filtered.bed"

        with open(segment_id_file, "w") as ids_out:
            ids_out.writelines(f"{seg_id}\n" for seg_id in self.segment_id_set)

        # First file (NR==FNR): load IDs. Second input: keep lines whose column 4,
        # minus the leading orientation character, is a known ID.
        awk_program = "NR==FNR {ids[$1]; next} substr($4, 2) in ids"
        # argv lists (no shell) so paths with quotes or spaces cannot break the command.
        with open(filtered_bed_path, "w") as out_f:
            if self.bed_path.name.endswith(".gz"):
                # awk cannot read gzip data; stream-decompress into its stdin ("-").
                with subprocess.Popen(["gzip", "-dc", str(self.bed_path)], stdout=subprocess.PIPE) as gunzip:
                    subprocess.run(
                        ["awk", awk_program, str(segment_id_file), "-"],
                        stdin=gunzip.stdout,
                        stdout=out_f,
                        check=True,
                    )
                if gunzip.returncode != 0:
                    raise subprocess.CalledProcessError(gunzip.returncode, gunzip.args)
            else:
                subprocess.run(
                    ["awk", awk_program, str(segment_id_file), str(self.bed_path)],
                    stdout=out_f,
                    check=True,
                )

        self.logger.info(f"'{self.sample_name}': Filtered file with awk written to: {filtered_bed_path}")
        return BedTool(filtered_bed_path), filtered_bed_path, segment_id_file

    def get_chr_pos(self, progress_dict: Optional[Dict], task_id: Optional[TaskID]) -> None:
        """
        Locate the regions covered by ``self.segment_id_set`` in this sample and build its subgraph.

        The BED file is filtered to the segment IDs and merged twice (abutting
        features, then with distance ``self.merge``); the merged intervals become
        ``self.regions``. Then ``compute_intersection``, ``build_walks`` (which resets
        ``segment_id_set`` to this sample's segments) and ``build_segments`` run.
        Processing stops early, with an error log, when ``segment_id_set`` is empty
        or no BED line matches.

        Parameters
        ----------
        progress_dict : Optional[Dict]
            A dictionary used to publish progress, e.g. across processes.
        task_id : Optional[TaskID]
            The task ID for ``rich.progress`` tracking.
        """
        self.progress_dict = progress_dict
        self.task_id = task_id
        self._report_progress(2, f"[cyan]'{self.sample_name}': Searching corresponding chr pos")

        if not self.segment_id_set:
            self._report_progress(7, f"[red]'{self.sample_name}': the segment_id_set is empty, aborted.")
            self.logger.error(
                f"'{self.sample_name}': segment_id_set is empty. "
                "Cannot determine chromosome positions. Ensure that build_walks (or equivalent setting segment_id_set) was called."
            )
            return

        filtered_bed_all, filtered_bed_path, segment_id_file = self.filter_bed_with_awk()
        try:
            if filtered_bed_all.file_type == "empty":
                self._report_progress(7, f"[red]'{self.sample_name}': No segment in BED.")
                self.logger.error(
                    f"'{self.sample_name}': No matching segment found in BED file '{self.bed_path}' "
                    f"for segment IDs: {list(self.segment_id_set)[:5]}..."  # Log a few example IDs
                )
                # Nothing to merge: continuing would run bedtools on an empty file.
                return

            self._report_progress(3, f"'{self.sample_name}': BED filtering successful, proceeding to merge..")

            # Merge abutting features, then merge again with distance self.merge;
            # columns 4-6 are collapsed / deduplicated.
            merged_all_bed = filtered_bed_all.merge(c=[4, 5, 6], o=["collapse", "distinct", "distinct"])
            merged_collapse_bed = merged_all_bed.merge(
                d=self.merge, c=[4, 5, 6], o=["collapse", "distinct", "distinct"]
            )
            if self.works_path:
                prefix = f"{self.bam_path.stem}_{self.sample_name}"
                merged_all_bed.saveas(self.works_path / f"{prefix}_regions_all.csv")
                merged_collapse_bed.saveas(self.works_path / f"{prefix}_regions_collapsed_{self.merge}.csv")

            self.logger.info(f"'{self.sample_name}': compute regions base on merge dataframe")
            self.regions = [
                {"chromosome": feature.chrom, "start": feature.start, "stop": feature.stop}
                for feature in merged_collapse_bed
            ]
            self.logger.debug(
                f"'{self.sample_name}': Found {len(self.regions)} regions for processing. Example: {self.regions[0] if self.regions else 'None'}."
            )
            self._report_progress(4, f"[cyan]'{self.sample_name}': Found corresponding chr pos")
        finally:
            # Intermediate files are removed on every exit path, including early returns.
            filtered_bed_path.unlink(missing_ok=True)
            segment_id_file.unlink(missing_ok=True)

        self.compute_intersection()
        self.build_walks()
        self.build_segments()

        self._report_progress(7, f"[green]'{self.sample_name}': End processing")

    def build_fasta(self) -> None:
        """
        Build FASTA records from the GFA walks in ``self.gfa_walk_list``.

        Each walk's segments are concatenated, with reverse-complemented segments
        where the orientation is '-'. For the query sample, sequence is trimmed at
        the first/last query segments using offsets derived from ``start_query`` and
        ``stop_query``, so the record matches the query region.
        Records are appended to ``self.sequences_list``; malformed walks and missing
        segment sequences are logged and skipped.
        """
        num_walks = len(self.gfa_walk_list)
        self.logger.info(f"'{self.sample_name}': Building FASTA for {num_walks} walks.")

        if not self.dict_segments_sequence:
            self.logger.warning(f"'{self.sample_name}': Segment sequences dictionary is empty. Cannot build FASTA.")
            return
        if not self.gfa_walk_list:
            self.logger.info(f"'{self.sample_name}': No GFA walks to process for FASTA generation.")
            return

        # Path steps: ">s1<s2" (GFA W-line form) or "s1+s2-" (suffix form).
        arrow_step_re = re.compile(r"([<>])([^<>]+)")
        suffix_step_re = re.compile(r"(.+?)([+-])")
        is_query_sample = self.sample_name == self.sample_name_query
        newly_created_records: List[SeqRecord] = []

        for walk_line in self.gfa_walk_list:
            try:
                (
                    _w_char,
                    current_walk_sample_name,
                    haplotype_index,
                    chromosome_name,
                    walk_genomic_start_str,
                    walk_genomic_stop_str,
                    gfa_path_specification,
                ) = walk_line.split("\t")
                walk_genomic_start = int(walk_genomic_start_str)
                walk_genomic_stop = int(walk_genomic_stop_str)
            except ValueError as e:
                self.logger.error(
                    f"'{self.sample_name}': Malformed walk line during FASTA build: '{walk_line}'. Error: {e}",
                    exc_info=True,
                )
                continue

            # Coordinates reported in the record; trimmed below for the query sample.
            effective_walk_chr_start = walk_genomic_start
            effective_walk_chr_stop = walk_genomic_stop

            if is_query_sample:
                if self.start_query is not None and self.stop_query is not None:
                    # How far the walk extends beyond the query region on each side.
                    self.offset_first = self.start_query - walk_genomic_start
                    self.offset_last = walk_genomic_stop - self.stop_query
                else:
                    self.logger.warning(
                        f"'{self.sample_name}': Query sample set, but start/stop_query not fully set. Offsets may be incorrect."
                    )

            # startswith() also copes with an empty path specification.
            if gfa_path_specification.startswith((">", "<")):
                oriented_segments_in_path: List[Tuple[str, str]] = [
                    ("+" if arrow == ">" else "-", seg_id)
                    for arrow, seg_id in arrow_step_re.findall(gfa_path_specification)
                ]
            else:
                oriented_segments_in_path = [
                    (orient_char, seg_id) for seg_id, orient_char in suffix_step_re.findall(gfa_path_specification)
                ]

            if not oriented_segments_in_path:
                self.logger.warning(
                    f"'{self.sample_name}': Could not parse path specification '{gfa_path_specification}' in walk line: '{walk_line}'. Skipping this walk for FASTA building."
                )
                continue

            num_segments_in_walk = len(oriented_segments_in_path)
            last_segment_idx = num_segments_in_walk - 1
            current_walk_seq_fragments: List[str] = []
            self.logger.debug(
                f"'{self.sample_name}': Processing walk on {chromosome_name} ({walk_genomic_start}-{walk_genomic_stop}) "
                f"with {num_segments_in_walk} segments for FASTA."
            )

            for segment_idx, (segment_orient_char, segment_id) in enumerate(oriented_segments_in_path):
                if segment_id not in self.dict_segments_sequence:
                    # Skipped segments make the resulting sequence shorter/discontinuous.
                    self.logger.warning(
                        f"'{self.sample_name}': Sequence for segment '{segment_id}' in walk not found. Skipping this segment."
                    )
                    continue

                current_segment_seq = self.dict_segments_sequence[segment_id]
                if segment_orient_char == "-":
                    current_segment_seq = reverse_complement_string(current_segment_seq)

                # Query-boundary trimming. It relies on segment_id_first/last_query and
                # their strands having been set by build_walks for the query sample.
                if is_query_sample and segment_id == self.segment_id_first_query:
                    if segment_idx == 0:
                        # Walk starts on the first query segment in the expected orientation.
                        if segment_orient_char == self.segment_id_first_strand and self.offset_first > 0:
                            if self.offset_first < len(current_segment_seq):
                                current_segment_seq = current_segment_seq[self.offset_first:]
                                effective_walk_chr_start = walk_genomic_start + self.offset_first
                            else:
                                self.logger.debug(
                                    f"Offset_first ({self.offset_first}bp) consumes entire first segment '{segment_id}' ({len(current_segment_seq)}bp)."
                                )
                                current_segment_seq = ""
                    elif segment_idx == last_segment_idx:
                        # First query segment found reversed at the end of the walk:
                        # the start offset is trimmed from the end instead.
                        if segment_orient_char != self.segment_id_first_strand and self.offset_first > 0:
                            if self.offset_first < len(current_segment_seq):
                                current_segment_seq = current_segment_seq[: -self.offset_first]
                                effective_walk_chr_stop = walk_genomic_stop - self.offset_first
                            else:
                                current_segment_seq = ""

                if is_query_sample and segment_id == self.segment_id_last_query:
                    if segment_idx == 0:
                        # Last query segment found reversed at the start of the walk:
                        # the end offset is trimmed from the start instead.
                        if segment_orient_char != self.segment_id_last_strand and self.offset_last > 0:
                            if self.offset_last < len(current_segment_seq):
                                current_segment_seq = current_segment_seq[self.offset_last:]
                                effective_walk_chr_start = walk_genomic_start + self.offset_last
                            else:
                                current_segment_seq = ""
                    elif segment_idx == last_segment_idx:
                        # Walk ends on the last query segment in the expected orientation.
                        if segment_orient_char == self.segment_id_last_strand and self.offset_last > 0:
                            if self.offset_last < len(current_segment_seq):
                                current_segment_seq = current_segment_seq[: -self.offset_last]
                                effective_walk_chr_stop = walk_genomic_stop - self.offset_last
                            else:
                                self.logger.error(
                                    f"Offset_last ({self.offset_last}bp) consumes entire last segment '{segment_id}' ({len(current_segment_seq)}bp)."
                                )
                                current_segment_seq = ""

                if current_segment_seq:
                    current_walk_seq_fragments.append(current_segment_seq)

            # Every stored fragment is non-empty, so an empty list is the only "empty sequence" case.
            if not current_walk_seq_fragments:
                self.logger.debug(
                    f"'{self.sample_name}': No sequence fragment generated for walk on {chromosome_name} "
                    f"(ID: {current_walk_sample_name}-{haplotype_index}). Skipping FASTA record for this walk."
                )
                continue
            final_walk_sequence = Seq("".join(current_walk_seq_fragments))

            if is_query_sample:
                q_sample = self.sample_name_query or "query_sample"
                q_chrom = self.chromosome_query or chromosome_name
                fasta_seq_id = f"{q_sample}-{q_chrom}:{effective_walk_chr_start}-{effective_walk_chr_stop}"
                fasta_description = ""
            else:
                fasta_seq_id = (
                    f"{current_walk_sample_name}-{chromosome_name}:{effective_walk_chr_start}-{effective_walk_chr_stop}"
                )
                seq_actual_len = effective_walk_chr_stop - effective_walk_chr_start
                query_region_str = "N/A"
                if (
                    self.sample_name_query
                    and self.chromosome_query
                    and self.start_query is not None
                    and self.stop_query is not None
                ):
                    query_region_str = (
                        f"{self.sample_name_query}-{self.chromosome_query}:{self.start_query}-{self.stop_query}"
                    )
                fasta_description = f"SEQ_LEN={seq_actual_len}|QUERY={query_region_str}"

            fasta_seq_id = fasta_seq_id.strip()
            newly_created_records.append(
                SeqRecord(
                    final_walk_sequence,
                    id=fasta_seq_id,
                    name=fasta_seq_id,
                    description=fasta_description,
                )
            )

        self.sequences_list.extend(newly_created_records)
        self.logger.info(
            f"'{self.sample_name}': Finished FASTA building. Added {len(newly_created_records)} new records. "
            f"Total records in sequences_list: {len(self.sequences_list)}."
        )
class AsyncGfaDatabase:
    """
    Manage an asynchronous SQLite database for storing and querying GFA link data.

    It uses `aiosqlite` for non-blocking operations within an asyncio event loop.
    Writes are serialized through an internal FIFO queue, consumed by a single
    background writer task, to avoid SQLite lock contention.

    Attributes
    ----------
    db_file : Path
        Path to the SQLite database file.
    timeout : float
        Maximum timeout (in seconds) for SQLite lock acquisition.
    logger : logging.Logger
        Logger instance (the shared "GraTools" logger).
    _conn : Optional[aiosqlite.Connection]
        Shared SQLite connection (None until `connect()` succeeds).
    _write_queue : asyncio.Queue
        Bounded queue of link batches to insert (max size `queue_maxsize`, default 1000).
        A ``None`` item tells the writer task to stop.
    _sql_task : Optional[asyncio.Task]
        Background task consuming the queue and writing to the database.
    _shutdown : bool
        When True, the writer also stops once the queue is empty and a 0.5 s wait times out.
    """

    def __init__(self, db_file: Path, timeout: float = 30.0, *, queue_maxsize: int = 1000):
        """
        Initialize the instance without opening the connection.

        The schema is created on the first call to `connect()`.

        Args:
            db_file (Path): Path to the SQLite file used as the backend.
            timeout (float, optional): Maximum wait time for SQLite locks, in seconds (default 30.0).
            queue_maxsize (int, optional): Capacity of the write queue (default 1000).
                Producers wait when it is full.
        """
        self.db_file: Path = db_file
        self.timeout: float = timeout  # In seconds
        self._shutdown: bool = False
        self._conn: Optional[aiosqlite.Connection] = None
        self._write_queue: asyncio.Queue[Optional[List[Tuple[str, int, int, str, int, int]]]] = asyncio.Queue(
            maxsize=queue_maxsize
        )
        self._sql_task: Optional[asyncio.Task] = None
        self.logger = logging.getLogger("GraTools")  # Shared package-wide logger

    async def connect(self) -> None:
        """
        Open and configure the SQLite connection, create the schema and start the writer task.

        The connection is configured with PRAGMAs, and the 'links' table is created
        if missing. The method is idempotent: if already connected, it does nothing.
        If configuration fails, the connection is closed and the error re-raised,
        leaving the instance disconnected so a later call can retry.
        """
        if self._conn is not None:
            return

        self.logger.debug(f"Connecting to the SQLite database: {self.db_file.name}")
        conn = await aiosqlite.connect(self.db_file.as_posix(), timeout=self.timeout)
        try:
            await conn.execute("PRAGMA synchronous = OFF;")  # Faster, but higher corruption risk on crash
            await conn.execute("PRAGMA journal_mode = WAL;")  # Write-Ahead Logging for better concurrency
            await conn.execute(f"PRAGMA busy_timeout = {int(self.timeout * 1000)};")  # Milliseconds
            await conn.execute("""
                CREATE TABLE IF NOT EXISTS links (
                    seg_id_1 TEXT,
                    orient_seg_1 INTEGER,     -- +1 for '+', -1 for '-'
                    orient_key_seg_1 INTEGER,
                    seg_id_2 TEXT,
                    orient_seg_2 INTEGER,
                    orient_key_seg_2 INTEGER
                );
            """)
            await conn.commit()
        except BaseException:
            # Publishing a half-configured connection would turn later connect() calls into no-ops.
            await conn.close()
            raise

        self._conn = conn
        if self._sql_task is None or self._sql_task.done():
            self._sql_task = asyncio.create_task(self._sql_writer())
        self.logger.debug("AsyncGfaDatabase connected and writer task started/confirmed.")

    async def _rollback_open_transaction(self, reason: str) -> None:
        """Roll back the connection's pending transaction after a failed batch, logging (not raising) failures."""
        if not self._conn.in_transaction:
            return
        try:
            await self._conn.rollback()
            self.logger.warning(f"SQL writer rolled back transaction due to {reason}.")
        except Exception as rb_err:
            # Best effort: a failing rollback must not kill the writer loop.
            self.logger.error(f"Error during SQLite rollback: {rb_err}")

    async def _sql_writer(self) -> None:
        """
        Consume link batches from the queue and insert each one in its own transaction.

        The loop stops on a ``None`` sentinel, or when `_shutdown` is set and the queue is
        empty after a 0.5 s wait. Empty batches are skipped. A batch whose insert fails is
        logged, its transaction is rolled back, and it is dropped (not re-queued).
        `task_done()` is called for every item taken from the queue.
        """
        assert self._conn is not None, "Connection must be initialized before writer starts"
        self.logger.debug("SQL writer task started.")
        processed_data_batches = 0

        while True:
            item_was_retrieved = False
            try:
                # Bounded wait so the _shutdown flag is re-checked periodically.
                batch = await asyncio.wait_for(self._write_queue.get(), timeout=0.5)
                item_was_retrieved = True

                if batch is None:  # Sentinel
                    self.logger.debug("SQL writer received end. Preparing to stop.")
                    break
                if not batch:  # Empty list: nothing to insert
                    continue

                processed_data_batches += 1
                self.logger.debug(
                    f"SQL writer processing data batch {processed_data_batches} of size {len(batch)}. "
                    f"Queue size: {self._write_queue.qsize()}"
                )
                await self._conn.executemany(
                    """INSERT INTO links (
                        seg_id_1, orient_seg_1, orient_key_seg_1,
                        seg_id_2, orient_seg_2, orient_key_seg_2
                    ) VALUES (?, ?, ?, ?, ?, ?);""",
                    batch,
                )
                await self._conn.commit()
            except asyncio.TimeoutError:
                if self._shutdown and self._write_queue.empty():
                    self.logger.info("SQL writer: Timeout on get() with _shutdown True and queue empty. Stopping.")
                    break
            except sqlite3.Error as e:
                # The counter is incremented before the insert, so it already names the failing batch.
                self.logger.error(
                    f"SQLite error during processing of batch #{processed_data_batches}: {e}", exc_info=True
                )
                await self._rollback_open_transaction("SQLite error")
            except Exception as e:
                self.logger.error(
                    f"Unexpected error in SQL writer (batch #{processed_data_batches}): {e}", exc_info=True
                )
                # Without a rollback, the failed batch's open transaction would leak into the next commit.
                await self._rollback_open_transaction("unexpected error")
            finally:
                if item_was_retrieved:
                    self._write_queue.task_done()

        self.logger.debug("SQL writer task has stopped.")
[docs] async def create_indexes(self) -> None: """ Create indexes on seg_id_1 and seg_id_2 to accelerate queries. Should ideally be called after all data insertions are complete. """ if not self._conn: await self.connect() self.logger.info("Creating indexes on links table.") try: await self._conn.execute( "CREATE INDEX IF NOT EXISTS idx_seg_id_1 ON links(seg_id_1);" ) await self._conn.execute( "CREATE INDEX IF NOT EXISTS idx_seg_id_2 ON links(seg_id_2);" ) await self._conn.commit() except sqlite3.Error as e: self.logger.error(f"Failed to create indexes: {e}", exc_info=True)
[docs] async def find_children_and_grandchildren( self, node_id: str ) -> Dict[str, List[str]]: """ Find direct successors (children) and second-degree successors (grandchildren) of a segment. A child is `seg_id_2` where `node_id` is `seg_id_1`. A grandchild is a child of a child. Args: node_id : The starting segment ID. Returns: A dictionary: {"children": [IDs], "grandchildren": [IDs]}. """ if not self._conn: await self.connect() self.logger.debug(f"Finding children and grandchildren for node: {node_id}") children: List[str] = [] grandchildren: List[str] = [] try: # Direct children async with self._conn.execute( "SELECT seg_id_2 FROM links WHERE seg_id_1 = ?;", (node_id,) ) as cur_children: children = [r[0] for r in await cur_children.fetchall()] # r is a tuple e.g. ('seg_X',) self.logger.debug(f"Found {len(children)} children for {node_id}.") # Grandchildren (children of children) if children: # Create placeholders for the IN clause: (?, ?, ...) placeholders = ",".join("?" for _ in children) sql_grandchildren = f"SELECT seg_id_2 FROM links WHERE seg_id_1 IN ({placeholders});" async with self._conn.execute(sql_grandchildren, children) as cur_gc: grandchildren_tuples = await cur_gc.fetchall() # Deduplicate grandchildren, as multiple children might lead to the same grandchild. grandchildren = sorted(list(set(r[0] for r in grandchildren_tuples))) self.logger.debug( f"Found {len(grandchildren)} unique grandchildren for {node_id} (from {len(grandchildren_tuples)} raw)." ) except sqlite3.Error as e: self.logger.error(f"Error finding children/grandchildren for {node_id}: {e}", exc_info=True) # Return lists as they are up to the point of error. return {"children": children, "grandchildren": grandchildren}
[docs] async def close(self) -> None: """ Properly shut down the database: - Signal the SQL writer task to stop. - Wait for the writer task to finish processing its queue (with timeout). - Cancel the task if it doesn't finish in time. - Create indexes (important to do this after all writes). - Close the SQLite connection. """ self.logger.debug("Initiating AsyncGfaDatabase shutdown...") self._shutdown = True # Signal writer to stop after draining queue if self._write_queue: self.logger.debug("Putting sentinel on SQL write queue.") await self._write_queue.put(None) # Sentinel to stop the writer loop self.logger.debug("Waiting for SQL queue to drain...") try: # Wait for all items in queue to be processed. # Timeout for queue.join() should be generous if large batches or slow I/O expected. await self._write_queue.join() self.logger.debug("SQL queue drained successfully (all items processed).") except asyncio.TimeoutError: self.logger.warning("Timeout waiting for SQL queue to drain during close.") except Exception as e: self.logger.error(f"Error waiting for SQL queue to drain: {e}", exc_info=True) if self._sql_task and not self._sql_task.done(): self.logger.debug("SQL writer task still running, attempting to cancel.") self._sql_task.cancel() try: await self._sql_task # Wait for task to acknowledge cancellation self.logger.debug("SQL writer task cancelled successfully.") except asyncio.CancelledError: self.logger.info("SQL writer task was cancelled as expected during close.") except Exception as e: self.logger.error(f"Error awaiting SQL writer task cancellation: {e}", exc_info=True) elif self._sql_task and self._sql_task.done(): self.logger.debug("SQL writer task was already done.") else: self.logger.debug("No SQL writer task to shut down or task was never started.") # Create indexes after all data is written and writer task is stopped if self._conn: # Only if connection was ever established await self.create_indexes() try: await self._conn.close() except Exception as e: 
# Catch potential errors during aiosqlite close self.logger.error( f"Error closing SQLite connection: {e}", exc_info=True ) finally: self._conn = None # Ensure _conn is None after attempting to close
class AsyncBedWriter:
    """
    Asynchronous BED file writer using `aiofiles`.

    It buffers lines per sample and writes them in batches to separate BED files
    (``<bed_dir>/<sample>.bed``, opened in append mode).

    Attributes
    ----------
    bed_dir : Path
        Output directory for .bed files.
    batch_size : int
        Number of lines to buffer per sample before an automatic flush.
    progress : Optional[Progress]
        Rich Progress instance for displaying progress (if provided).
    logger : logging.Logger
        Logger instance.
    _queue : asyncio.Queue[Tuple[str, List[str]]]
        Internal queue for (sample_name, lines_to_write) tuples. Max size 1000.
    _shutdown : bool
        Flag to signal shutdown to the writer loop.
    _task : Optional[asyncio.Task]
        The background asyncio task running the `_writer_loop`.
    """

    def __init__(self, bed_dir: Path, batch_size: int = 10000, progress: Optional[Progress] = None):
        """
        Args:
            bed_dir: Directory receiving one ``<sample>.bed`` file per sample.
            batch_size: Lines buffered per sample before that sample is flushed.
            progress: Optional Rich progress instance; stored on the instance (the
                writer methods of this class do not use it themselves).
        """
        self.bed_dir: Path = bed_dir
        self.batch_size: int = batch_size
        self._queue: asyncio.Queue[Tuple[str, List[str]]] = asyncio.Queue(maxsize=1000)
        self._shutdown: bool = False
        self._task: Optional[asyncio.Task] = None
        self.progress: Optional[Progress] = progress  # Can be None
        self.logger = logging.getLogger("GraTools")

    def start(self) -> None:
        """Start the writer loop as a background task if not already running."""
        if self._task is None or self._task.done():
            self._shutdown = False  # Reset shutdown flag if restarting
            self.logger.debug("Starting AsyncBedWriter loop task.")
            self._task = asyncio.create_task(self._writer_loop())
            self.logger.debug("AsyncBedWriter task launched.")
        else:
            self.logger.debug("AsyncBedWriter task is already running.")

    def enqueue(self, sample: str, lines: List[str]) -> None:
        """
        Add a sample and its corresponding lines to the write queue (non-blocking).

        Uses `put_nowait`: if the queue is full, the batch is NOT written. An error is
        logged and the lines are dropped. Empty `lines` are ignored.

        Args:
            sample (str): The sample name (used for filename).
            lines (List[str]): Lines to write to the BED file. Each string should
                include its own newline character.
        """
        if not lines:
            self.logger.debug(f"AsyncBedWriter: enqueue called with empty lines for sample '{sample}'. Skipping.")
            return
        try:
            self.logger.debug(
                f"BED writer enqueuing {len(lines)} lines for sample '{sample}'. Current queue size: {self._queue.qsize()}"
            )
            self._queue.put_nowait((sample, lines))
        except asyncio.QueueFull:
            self.logger.error(
                f"AsyncBedWriter queue is full. Failed to enqueue {len(lines)} lines for sample '{sample}'. "
                "Consider increasing queue maxsize or slowing down producers."
            )

    async def _writer_loop(self) -> None:
        """
        Consume the queue and write lines to sample-specific BED files in batches.

        Each sample's buffer is flushed when it reaches `batch_size` lines, and every
        remaining buffer is flushed once the loop ends (normally or by cancellation).
        """
        self.logger.debug("BED writer loop started.")
        buffers: Dict[str, List[str]] = defaultdict(list)
        processed_items_count = 0

        try:
            self.bed_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            self.logger.error(f"Failed to create BED output directory {self.bed_dir}: {e}. Writer loop cannot proceed.")
            return

        while not self._shutdown or not self._queue.empty():
            sample_name_for_log = ""
            got_item = False  # True only after a successful get(); drives task_done()
            try:
                # The timeout lets the loop periodically re-check the shutdown flag.
                current_sample, new_lines = await asyncio.wait_for(self._queue.get(), timeout=0.5)
                got_item = True
                sample_name_for_log = current_sample
                processed_items_count += 1
                self.logger.debug(
                    f"BED writer received item {processed_items_count} for sample '{current_sample}' "
                    f"with {len(new_lines)} lines. Queue size: {self._queue.qsize()}"
                )
                buffers[current_sample].extend(new_lines)

                if len(buffers[current_sample]) >= self.batch_size:
                    self.logger.debug(
                        f"BED writer flushing BED buffer for sample '{current_sample}' (size: {len(buffers[current_sample])})."
                    )
                    # _flush_sample clears the buffer it is given on success.
                    await self._flush_sample(current_sample, buffers[current_sample])
            except asyncio.TimeoutError:
                if self._shutdown and self._queue.empty():
                    break
                continue
            except asyncio.CancelledError:
                self.logger.info("BED writer loop was cancelled.")
                await self._flush_all(buffers)  # Do not lose buffered lines on cancellation
                raise
            except Exception as e:
                self.logger.error(f"Error in BED writer loop processing item for '{sample_name_for_log}': {e}", exc_info=True)
            finally:
                if got_item:
                    self._queue.task_done()
            # No per-item flush of every buffer during shutdown: buffers stay bounded
            # by batch_size and the final _flush_all below writes whatever remains.

        self.logger.debug("BED writer loop finished. Performing final flush of all buffers.")
        await self._flush_all(buffers)
        self.logger.debug("BED writer loop has stopped.")

    async def _flush_sample(self, sample: str, buffer_lines: List[str]) -> None:
        """
        Append `buffer_lines` to ``<bed_dir>/<sample>.bed`` and clear `buffer_lines`.

        Lines are written as-is (they must already end with a newline). On a write
        error the failure is logged and the buffer is kept for a later flush.
        """
        if not buffer_lines:
            self.logger.debug(f"No lines to flush for sample '{sample}'.")
            return
        path = self.bed_dir / f"{sample}.bed"
        try:
            async with aiofiles.open(path, mode="a", encoding="utf-8") as f:
                await f.writelines(buffer_lines)
            buffer_lines.clear()  # Only after a successful write
        except IOError as e:
            self.logger.error(f"IOError flushing sample '{sample}' to '{path}': {e}", exc_info=True)
        except Exception as e:
            self.logger.error(f"Unexpected error flushing sample '{sample}' to '{path}': {e}", exc_info=True)

    async def _flush_all(self, buffers: Dict[str, List[str]]) -> None:
        """Write all non-empty buffers and drop entries whose buffer ended up empty."""
        self.logger.debug(f"Performing flush of all remaining buffers ({len(buffers)} total).")
        # Iterate over a snapshot because entries are deleted from `buffers` below.
        for sample_name, buf_list in list(buffers.items()):
            if buf_list:
                await self._flush_sample(sample_name, buf_list)
            if not buf_list and sample_name in buffers:
                del buffers[sample_name]
        self.logger.debug("Flush of all buffers complete.")

    async def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Signal the writer loop to shut down and wait for it to complete.

        The loop drains the queue and flushes every buffer before exiting.

        Args:
            timeout: Seconds to wait for the writer to finish before cancelling it
                (cancellation still triggers a final flush inside the loop).
                ``None`` (default) waits indefinitely.
        """
        self._shutdown = True
        if self._task is not None and not self._task.done():
            self.logger.debug("Waiting for BED writer task to complete...")
            try:
                await asyncio.wait_for(self._task, timeout=timeout)
                self.logger.debug("BED writer task completed successfully after shutdown signal.")
            except asyncio.TimeoutError:
                self.logger.warning(
                    "Timeout waiting for BED writer task to complete during shutdown. Attempting cancellation.")
                self._task.cancel()
                try:
                    await self._task
                except asyncio.CancelledError:
                    self.logger.info("BED writer task was cancelled during shutdown due to timeout.")
                except Exception as e_inner:
                    self.logger.error(f"Exception after cancelling BED writer task: {e_inner}", exc_info=True)
            except asyncio.CancelledError:
                # shutdown() was called from (or is) a cancelled context.
                self.logger.info("BED writer task was externally cancelled during its shutdown process.")
            except Exception as e:
                self.logger.error(f"Exception while waiting for BED writer task during shutdown: {e}", exc_info=True)
        elif self._task and self._task.done():
            self.logger.debug("BED writer task was already done prior to shutdown call.")
        else:
            self.logger.debug("No BED writer task to shut down (was never started or already None).")
@dataclass
class GFA:
    """
    Manage parsing of a GFA (Graphical Fragment Assembly) file, compute statistics,
    and generate related files such as a BAM file containing the segments and
    per-sample BED files of paths.

    Optionally fill an asynchronous database with links and feed an asynchronous
    BED writer.

    Attributes
    ----------
    gfa_path : Path
        Path to the input GFA file (can be .gfa or .gfa.gz).
    threads : int, optional
        Number of worker threads of the shared ThreadPoolExecutor (used for operations
        like BAM file processing); values <= 0 use the executor's default. Default is 1.
    logger : logging.Logger
        Logger object. Default is a logger named "GraTools".
    gfa_name : Optional[str]
        Name of the GFA file derived from `gfa_path` (without extensions). Auto-initialized.
    version : Optional[str]
        GFA version extracted from the header (e.g., "1.0"). Auto-initialized.
    header_gfa : List[str]
        List of header lines (H lines) from the GFA file. Auto-initialized.
    sample_reference : Optional[str]
        Reference sample name, potentially from GFA header (RS tag). Auto-initialized.
    bam_segments_file : Optional[Path]
        Path to the BAM file where segments (S lines) will be written. Auto-initialized.
    dict_samples_chrom : defaultdict[str, OrderedDict[str, List[str]]]
        Map sample names to an OrderedDict of chromosome names, which in turn map to a
        list of "start\\tstop" fragment strings derived from Walk (W) lines. Auto-initialized.
    dict_segments_size : defaultdict[str, int]
        Map segment IDs to their length (in base pairs). Auto-initialized.
    dict_segments_samples : defaultdict[str, List[str]]
        Map segment IDs to a list of sample identifiers ("sample;chromosome;haplotype")
        that contain the segment. Auto-initialized.
    dict_samples_bed : defaultdict[str, OrderedDict[str, Path]]
        Not populated by the parsing logic visible here; presumably intended to track
        generated BED file paths (the `AsyncBedWriter` manages its own paths). Auto-initialized.
    works_path : Optional[Path]
        Path to the working directory (".../{gfa_name}_GraTools_INDEX"). Auto-initialized.
    bed_path : Optional[Path]
        Path to the subdirectory for BED files within `works_path`. Auto-initialized.
    bam_path : Optional[Path]
        Path to the subdirectory for BAM files within `works_path`. Auto-initialized.
    found_minigraph : bool
        Flag indicating if a sample named 'MINIGRAPH' (case-insensitive) was found in
        Walk lines. Defaults to False.
    index_links : bool, optional
        If True, GFA links (L lines) are stored in an SQLite database. Defaults to False.
    db_links : Optional[AsyncGfaDatabase]
        Asynchronous database handler for GFA links. Created in `__post_init__` only
        when `index_links` is True.
    segment_count : int
        Total number of segments (S lines) processed. Defaults to 0.
    total_segment_length : int
        Sum of lengths of all segments. Defaults to 0.
    link_count : int
        Total number of links (L lines) processed. Defaults to 0.
    degrees : defaultdict[str, int]
        Map segment IDs to their degree (number of links connected). Defaults to empty.
    walks_count : int
        Total number of walks (W lines) processed. Defaults to 0.
    max_walk_rank : int
        Maximum number of segments in any single walk. Defaults to 0.
    sum_rank0_length : int
        Sum of lengths of the first segments of all walks. Defaults to 0.
    input_genome_size : int
        Cumulative size of all paths (sum of segment lengths along each walk). Defaults to 0.
    walks_info : List[Dict[str, Any]]
        One dictionary per walk (Path name, Sequence length, Num Segments). Defaults to empty.
    inverted_links_count : int
        Count of links where orientations differ (e.g., S1+ -> S2-). Defaults to 0.
    negative_links_count : int
        Count of links where both segments have negative orientation (S1- -> S2-). Defaults to 0.
    self_links_count : int
        Count of links where a segment links to itself (S1 -> S1). Defaults to 0.
    isolated_segments : Set[str]
        Segment IDs with no connected link: initialized with all segments, then linked
        ones are removed. Defaults to an empty set.
    shared_executor : Optional[ThreadPoolExecutor]
        Executor for running synchronous tasks in threads. Auto-initialized.
    progress : Optional[Progress]
        Rich Progress instance for displaying progress. Auto-initialized.
    line_type_counts : Counter
        Counts of each GFA line type (H, S, L, W, P, C, E, U). Auto-initialized.
    header_gfa_file : Optional[Path]
        Path where the GFA header is saved. Auto-initialized.
    stats_file : Optional[Path]
        Path where GFA statistics are saved. Auto-initialized.
    db_file_path : Optional[Path]
        Path to the SQLite database file for links. Auto-initialized.
    RE_ORIENTED_SEG_GT_LT : re.Pattern
        Compiled regex for oriented segments written with '>' / '<'. Auto-initialized.
    RE_ORIENTED_SEG_PLUS_MINUS : re.Pattern
        Compiled regex for oriented segments written with '+' / '-'. Auto-initialized.
    """

    gfa_path: Path
    threads: int = 1
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    gfa_name: Optional[str] = None
    version: Optional[str] = None
    header_gfa: List[str] = field(default_factory=list, repr=True)
    sample_reference: Optional[str] = None
    bam_segments_file: Optional[Path] = None

    # Dictionaries populated during parsing
    # Key: sample, Value: OrderedDict[chrom, list_of_fragments]
    dict_samples_chrom: defaultdict = field(
        default_factory=lambda: defaultdict(OrderedDict), repr=False
    )
    # Key: seg_id, Value: length
    dict_segments_size: defaultdict = field(
        default_factory=lambda: defaultdict(int), repr=False
    )
    # Key: seg_id, Value: list of "sample;chrom;haplo"
    dict_segments_samples: defaultdict = field(
        default_factory=lambda: defaultdict(list), repr=False
    )
    # Key: sample, Value: OrderedDict[chrom, path_to_bed_for_chrom_walk]
    # (not populated by the parsing code visible here).
    dict_samples_bed: defaultdict = field(
        default_factory=lambda: defaultdict(OrderedDict), repr=False
    )

    # Paths
    works_path: Optional[Path] = None
    bed_path: Optional[Path] = None
    bam_path: Optional[Path] = None

    found_minigraph: bool = False
    index_links: bool = False  # Flag to control link indexing
    db_links: Optional[AsyncGfaDatabase] = None  # Instantiated in __post_init__

    # Statistics counters
    segment_count: int = 0
    total_segment_length: int = 0
    link_count: int = 0
    degrees: defaultdict = field(default_factory=lambda: defaultdict(int), repr=False)  # segment_id -> degree
    walks_count: int = 0
    max_walk_rank: int = 0  # Max number of segments in a walk
    sum_rank0_length: int = 0  # Sum of lengths of first segments of walks
    input_genome_size: int = 0
    walks_info: List[Dict[str, Any]] = field(default_factory=list, repr=False)
    inverted_links_count: int = 0
    negative_links_count: int = 0
    self_links_count: int = 0
    isolated_segments: set = field(default_factory=set, repr=False)  # Store segment IDs

    # Internal operational attributes (initialized in __post_init__)
    shared_executor: Optional[ThreadPoolExecutor] = None
    progress: Optional[Progress] = None
    line_type_counts: Counter = field(default_factory=Counter)
    header_gfa_file: Optional[Path] = None
    stats_file: Optional[Path] = None
    db_file_path: Optional[Path] = None
    RE_ORIENTED_SEG_GT_LT: re.Pattern = field(init=False, repr=False)
    RE_ORIENTED_SEG_PLUS_MINUS: re.Pattern = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """
        Initialize paths, directories, regexes, the thread pool, the optional link
        database and the progress bar, then run GFA processing and post-processing.

        Raises:
            FileNotFoundError: If `gfa_path` does not exist or is not a regular file.
        """
        # Captured once, before any work, so the logged duration covers the whole run.
        start_time = datetime.now()

        if not self.gfa_path.exists() or not self.gfa_path.is_file():
            self.logger.error(f"GFA file '{self.gfa_path}' does not exist or is not a file.")
            raise FileNotFoundError(f"GFA file not found: {self.gfa_path}")  # Fail fast

        # Determine GFA name by stripping the known extensions.
        name_to_process = self.gfa_path.name
        if name_to_process.endswith(".gfa.gz"):
            self.gfa_name = name_to_process.removesuffix(".gfa.gz")
        elif name_to_process.endswith(".gfa"):
            self.gfa_name = name_to_process.removesuffix(".gfa")
        else:
            self.gfa_name = self.gfa_path.stem  # Fallback, might include partial extensions
            self.logger.warning(
                f"GFA file '{self.gfa_path.name}' has an unusual extension. Using stem: '{self.gfa_name}'.")

        # Setup working paths
        self.works_path = self.gfa_path.parent / f"{self.gfa_name}_GraTools_INDEX"
        self.works_path.mkdir(exist_ok=True, parents=True)
        self.bed_path = self.works_path / "bed_files"
        self.bed_path.mkdir(exist_ok=True, parents=True)
        self.bam_path = self.works_path / "bam_files"
        self.bam_path.mkdir(exist_ok=True, parents=True)
        self.bam_segments_file = self.bam_path / f"{self.gfa_name}.bam"
        self.header_gfa_file = self.works_path / f"header_{self.gfa_name}.txt"
        self.stats_file = self.works_path / f"stats_{self.gfa_name}.txt"
        self.db_file_path = self.works_path / f"links_{self.gfa_name}.db"

        # Precompile regexes used during parsing
        self.RE_ORIENTED_SEG_GT_LT = re.compile(r"([<>])([^<>]+)")
        self.RE_ORIENTED_SEG_PLUS_MINUS = re.compile(r"(.+?)([+-])")

        # None lets ThreadPoolExecutor pick its default worker count.
        self.shared_executor = ThreadPoolExecutor(
            max_workers=self.threads if self.threads > 0 else None)

        if self.index_links:
            # AsyncGfaDatabase.timeout is in seconds (5000 ms).
            self.db_links = AsyncGfaDatabase(db_file=self.db_file_path, timeout=5000.0 / 1000.0)

        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "{task.completed}/{task.total} lines",
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            refresh_per_second=1,  # Higher refresh can be costly
            transient=True,  # Progress bar disappears after completion
            console=shared_console,
        )

        # Main parsing, then post-processing steps that rely on the parsed data.
        self.run()
        self.save_header()
        self.tag_bam()
        self.save_samples_chrom()
        self.compute_statistics()

        elapsed_time = datetime.now() - start_time
        self.logger.info(f"GFA processing and indexing for '{self.gfa_name}' completed in: {elapsed_time}")
def save_header(self) -> None:
    """
    Write the collected GFA header (H) lines to ``self.header_gfa_file``.

    The lines are joined with newline separators (no trailing newline). When the
    destination path is unset, an error is logged and nothing is written; I/O
    failures are logged rather than raised.
    """
    target = self.header_gfa_file
    if target:
        self.logger.info(f"Saving GFA header to '{target.name}'")
        try:
            with open(target, "w", encoding="utf-8") as handle:
                handle.write("\n".join(self.header_gfa))
        except IOError as err:
            self.logger.error(f"Failed to save header file '{target}': {err}")
    else:
        self.logger.error("Header GFA file path not set. Cannot save header.")
def save_samples_chrom(self) -> None:
    """
    Dump the per-sample, per-chromosome fragments (from Walk lines) to ``samples_chrom.txt``.

    Each output line is ``sample<TAB>chrom<TAB>fragment`` where ``fragment`` is the
    stored "start\\tstop" string. Samples and chromosomes are written in natural sort
    order so the file is reproducible. I/O failures are logged rather than raised.
    """
    destination = self.works_path / "samples_chrom.txt"
    self.logger.info(f"Saving sample-chromosome fragment data to '{destination.name}'")
    per_sample = self.dict_samples_chrom
    try:
        with open(destination, "w", encoding="utf-8") as out:
            for sample_name in natsorted(per_sample.keys()):
                chrom_map = per_sample[sample_name]
                for chrom_name in natsorted(chrom_map.keys()):
                    out.writelines(
                        f"{sample_name}\t{chrom_name}\t{fragment}\n"
                        for fragment in chrom_map[chrom_name]
                    )
    except IOError as err:
        self.logger.error(f"Failed to save samples_chrom.txt file '{destination}': {err}")
def _process_single_bed(self, bed_path: Path, task_id: TaskID) -> None:
    """
    Sort a single BED file in place with ``BedTool.sort``.

    The sorted output is written next to the input as ``<stem>.sorted.bed``
    and then moved over the original with ``Path.replace``. Intended to run in
    a worker thread; progress is reported through ``self.progress``.

    Parameters
    ----------
    bed_path : Path
        BED file to sort in place.
    task_id : TaskID
        Rich progress task. It is advanced by one on success; on failure its
        description is switched to an error message instead.
    """
    self.logger.debug(f"Sorting BED file: {bed_path.name}")
    tmp_path = bed_path.with_suffix(".sorted.bed")
    try:
        BedTool(bed_path.as_posix()).sort(output=tmp_path.as_posix())
        tmp_path.replace(bed_path)
    except Exception as e:  # bedtools / pybedtools can raise many error types
        self.logger.error(f"Error sorting BED file '{bed_path.name}': {e}", exc_info=True)
        # Do not leave a partially written temporary file behind.
        try:
            tmp_path.unlink(missing_ok=True)
        except OSError as cleanup_error:
            self.logger.debug(f"Could not remove temporary file '{tmp_path.name}': {cleanup_error}")
        if self.progress:
            self.progress.update(task_id, description=f"[red]Error sorting {bed_path.name}")
        return
    self.logger.debug(f"Finished sorting BED file: {bed_path.name}")
    if self.progress:
        self.progress.update(task_id, advance=1)


async def _read_gfa_lines_async(self, chunk_size_hint: int = 1 << 20):
    """
    Asynchronously yield the lines of ``self.gfa_path`` (plain text or gzip).

    Opening, reading and closing are all delegated to worker threads through
    ``asyncio.to_thread``. Lines are fetched in batches with
    ``readlines(chunk_size_hint)``, so decompression and disk I/O never run on
    the event loop. (Previously the gzip file was iterated synchronously on
    the loop, and plain files paid one thread hop per line.)

    Parameters
    ----------
    chunk_size_hint : int, optional
        Size hint passed to ``readlines`` for each worker-thread call; reading
        stops once the lines gathered exceed it. Defaults to 1 MiB.

    Yields
    ------
    str
        Each line as read, including its trailing newline when present.
        Nothing is yielded if the file cannot be opened (the error is logged).
    """
    # A ".gz" name selects gzip decompression; anything else is read as UTF-8 text.
    is_gzipped = self.gfa_path.name.endswith(".gz")
    opener = gzip.open if is_gzipped else open
    file_kind = "gzipped" if is_gzipped else "plain text"
    self.logger.debug(f"Reading {file_kind} GFA file: {self.gfa_path.name} in a worker thread.")
    try:
        handle = await asyncio.to_thread(opener, self.gfa_path, "rt", encoding="utf-8")
    except Exception as e:
        self.logger.error(f"Failed to open {file_kind} file {self.gfa_path.name}: {e}")
        return  # Stop generation
    try:
        while True:
            # The read itself happens off the loop; iterating the batch is cheap.
            batch = await asyncio.to_thread(handle.readlines, chunk_size_hint)
            if not batch:
                break
            for line in batch:
                yield line
    finally:
        self.logger.debug(f"Closing GFA file: {self.gfa_path.name}")
        await asyncio.to_thread(handle.close)
[docs] async def parse_gfa(self) -> None: """ Parses the GFA file line by line, processing headers, segments, links, and walks. Segments are written to a BAM file. Links are (optionally) stored in an async SQLite DB. Walk information is used to generate data for BED files, written by AsyncBedWriter. """ # Standard GFA header for output BAM if input GFA has no SQ lines. # Pysam's AlignmentFile requires a header. If GFA S lines imply @SQ, it would be better. # For now, a minimal header. bam_header = { 'HD': {'VN': '1.8', 'SO': 'unsorted'}, # SO: Coordinate or unsorted, depending on S lines 'SQ': [] # Placeholder for sequence dictionary. Ideally, populate from S lines if reference context is known. } # Pre-calculate total lines for progress bar (can be slow for very large files) total_lines = 0 try: open_fn = gzip.open if self.gfa_path.name.endswith(".gz") else open with open_fn(self.gfa_path, "rt", encoding="utf-8") as f_count: total_lines = sum(1 for _ in f_count) self.logger.info(f"Starting GFA parsing for: '{self.gfa_name}' with {total_lines} lines") except Exception as e: self.logger.warning(f"Could not count lines in GFA file for progress bar: {e}. Progress may be inaccurate.") total_lines = 1 # Avoid division by zero if progress bar uses it # Initialize AsyncBedWriter bed_writer = AsyncBedWriter( bed_dir=self.bed_path, batch_size=100000, # Tunable parameter progress=self.progress # Pass Rich progress instance if needed by BedWriter ) bed_writer.start() # Start its writer loop # Connect to AsyncGfaDatabase if indexing links if self.index_links and self.db_links: await self.db_links.connect() links_batch: List[LinkInfo] = [] # Use List for type hinting, will store LinkInfo link_batch_size = 100000 # How many links to batch before writing to DB # Add task to Rich Progress parse_task_id = self.progress.add_task( "GFA parsing...", total=total_lines, visible=True ) # Open BAM file for writing segments # pysam.AlignmentFile is synchronous. 
Writes happen in the main async thread. # If BAM writing is slow, it can block the event loop. # OPTIMIZATION_SUGGESTION: For very high throughput GFA S lines, # consider batching AlignedSegment objects and writing them in a separate thread # using `await asyncio.to_thread(bam_file_out.write, batched_segment)`. try: with AlignmentFile( self.bam_segments_file.as_posix(), "wb", header=bam_header, threads=1 # Use self.threads for multi-threading if pysam supports for write ) as bam_file_out, self.progress: # `with self.progress` starts/stops the Rich Progress context current_line_type = None lines_processed_count = 0 async for line_content in self._read_gfa_lines_async(): lines_processed_count += 1 self.progress.advance(parse_task_id) # Advance progress for each line line_content = line_content.strip() # Strip newline and surrounding whitespace if not line_content or line_content.startswith("#"): # Skip empty or comment lines continue fields = line_content.split("\t") line_type = fields[0] if line_type != current_line_type: current_line_type = line_type self.progress.update( parse_task_id, description=f"GFA parsing: {current_line_type}" ) self.line_type_counts[line_type] += 1 match line_type: case "H": # Header line self.header_gfa.append(line_content) # Basic header parsing for version and reference sample (if present) # Example H line: H VN:Z:1.0 RS:Z:ref_sample for field in fields[1:]: if field.startswith("VN:Z:"): self.version = field[5:] elif field.startswith( "RS:Z:"): # Reference Sample tag (custom or specific GFA variant?) 
self.sample_reference = field[5:] case "S": # Segment line self._parse_segment(fields, bam_file_out) # Pass fields list case "L": # Links line self.link_count += 1 if self.index_links and self.db_links: link_info = self._parse_link_line(fields) # Pass fields list if link_info: links_batch.append(link_info) if len(links_batch) >= link_batch_size: await self.db_links.batch_insert_links(list(links_batch)) # Pass a copy links_batch.clear() await asyncio.sleep(0) # Yield control after a DB batch case "W": # Walk line # W <sample_name> <haplotype_index> <chr_name> <chr_start> <chr_stop> <segment_path> # For GFA2 OrderedGroup (OG lines), parsing would be different. Assuming GFA1 W. parse_result = self._parse_walk_line(fields) # Pass fields list if parse_result: sample_name, bed_lines_for_sample = parse_result if bed_lines_for_sample: # Ensure there are lines to enqueue bed_writer.enqueue(sample_name, bed_lines_for_sample) # Yield control if many walks are processed rapidly if self.walks_count % 10 == 0: # Every 2 walks await asyncio.sleep(0) # Other GFA line types (P - Path, C - Containment, E - Edge (GFA2), U - Gap (GFA2), etc.) # Add parsing logic here if needed. For now, they are counted by line_type_counts. # Periodically yield control to event loop, especially if line processing is fast if lines_processed_count % 100 == 0: # Every 100 lines await asyncio.sleep(0) except Exception as e: self.logger.error(f"Error during GFA parsing main loop: {e}", exc_info=True) # Ensure progress bar is stopped or marked as errored if possible self.progress.update(parse_task_id, description=f"[red]Error during GFA parsing: {e}", completed=lines_processed_count) # Proceed to shutdown, resources might be in an inconsistent state. finally: # Finalize operations after loop (or if error occurred) self.progress.remove_task(parse_task_id) # Or stop if not transient self.logger.info(f"Finished reading GFA file. 
Processed {lines_processed_count} lines.") if self.found_minigraph: self.logger.warning( "Sample 'MINIGRAPH' (or similar) detected and ignored. " "See Cactus issue https://github.com/ComparativeGenomicsToolkit/cactus/pull/1050 for context." ) # Flush any remaining links to the database if self.index_links and self.db_links and links_batch: self.logger.debug(f"Inserting final batch of {len(links_batch)} links into database.") await self.db_links.batch_insert_links(list(links_batch)) # Shutdown AsyncGfaDatabase (waits for queue, creates indexes, closes connection) if self.db_links: self.logger.info( f"Closing links database and saved to '{self.db_file_path.name if self.db_file_path else 'N/A'}'.") await self.db_links.close() # Shutdown AsyncBedWriter (waits for queue, flushes all buffers, closes files) self.logger.info("Shutting down BED writer...") await bed_writer.shutdown()
def run(self):
    """
    Synchronous driver: parse the GFA asynchronously, then sort the BED files.

    A fresh event loop runs ``parse_gfa``. Whether parsing succeeds or fails,
    pending async generators are then finalised, ``self.shared_executor`` (if
    any) is shut down and the loop is closed. If parsing succeeded, every
    non-empty ``<bed_path>/<sample>.bed`` for the samples in
    ``self.dict_samples_chrom`` is sorted in parallel with
    ``_process_single_bed``. Errors are logged rather than raised.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    if self.shared_executor:
        # asyncio.to_thread / run_in_executor(None, ...) will use this pool.
        loop.set_default_executor(self.shared_executor)

    try:
        try:
            loop.run_until_complete(self.parse_gfa())
        finally:
            # Release loop resources even when parsing raised (they used to leak).
            loop.run_until_complete(loop.shutdown_asyncgens())
            if self.shared_executor:
                self.logger.debug("Shutting down shared ThreadPoolExecutor.")
                self.shared_executor.shutdown(wait=True)
            loop.close()
            asyncio.set_event_loop(None)
        self.logger.info("Async GFA parsing complete. Proceeding to sort BED files.")

        # BED files are named after the samples collected from W lines.
        beds_to_sort = []
        if self.bed_path and self.bed_path.exists():
            for sample_name in self.dict_samples_chrom.keys():
                candidate = self.bed_path / f"{sample_name}.bed"
                if candidate.exists() and candidate.stat().st_size > 0:
                    beds_to_sort.append(candidate)

        if not beds_to_sort:
            self.logger.info("No BED files found or specified for sorting.")
        else:
            self.logger.info(f"Found {len(beds_to_sort)} BED files to sort.")
            sort_task_id = self.progress.add_task(
                f"Sorting {len(beds_to_sort)} BED files",
                total=len(beds_to_sort),
                visible=True,
            )
            # threads <= 0 means "library default" (as for shared_executor);
            # ThreadPoolExecutor rejects max_workers=0.
            bed_sort_workers = min(self.threads, len(beds_to_sort)) if self.threads > 0 else None
            with ThreadPoolExecutor(max_workers=bed_sort_workers) as bed_sort_executor, self.progress:
                futures = [
                    bed_sort_executor.submit(self._process_single_bed, bed_file_path, sort_task_id)
                    for bed_file_path in beds_to_sort
                ]
                # _process_single_bed logs its own failures; this only surfaces
                # unexpected errors that escaped it.
                for future in as_completed(futures):
                    task_error = future.exception()
                    if task_error is not None:
                        self.logger.error(f"Unexpected error in BED sorting task: {task_error}",
                                          exc_info=task_error)
                self.progress.remove_task(sort_task_id)
            self.logger.info("All BED file sorting tasks complete.")
    except Exception as e:
        self.logger.error(f"Error during GFA.run execution: {e}", exc_info=True)
    finally:
        self.logger.debug("GFA.run() finished.")
def _parse_walk_line(self, fields: List[str]) -> Optional[Tuple[str, List[str]]]:
    """
    Parse a GFA Walk (W) line into BED-like records.

    Expected layout: ``W <sample> <hap_idx> <seq_id> <start> <end> <walk>``,
    where ``<walk>`` is either ``>s1<s2...`` or ``s1+s2-...``.

    Side effects:

    * Once the sample is accepted and ``<start>`` is an integer, the
      tab-separated "start end" fragment is appended to
      ``self.dict_samples_chrom[sample][seq_id]``. This happens before the
      walk itself is parsed.
    * For a parseable walk, ``"sample;seq_id;hap_idx"`` is appended to
      ``self.dict_segments_samples`` for every visited segment.
    * The walk statistics are updated: ``walks_count``, ``max_walk_rank``,
      ``sum_rank0_length``, ``input_genome_size`` and ``walks_info``.

    Parameters
    ----------
    fields : List[str]
        Tab-split fields of the W line.

    Returns
    -------
    Optional[Tuple[str, List[str]]]
        ``(sample_name, bed_lines)`` on success. Each BED line holds
        chrom, start, end, ``<orient><seg_id>``, ``"<start>:<end>"`` and
        hap_idx, tab-separated and newline-terminated. Segment coordinates are
        accumulated from ``<start>`` using ``self.dict_segments_size`` (unknown
        segments count as length 0). Returns ``None`` for a MINIGRAPH sample,
        missing fields, a non-integer start coordinate or an unparseable walk.
    """
    try:
        sample_name = fields[1]
        haplotype_index_str = fields[2]
        chromosome_name = fields[3]
        chromosome_start_str = fields[4]
        chromosome_stop_str = fields[5]
        gfa_path_specification = fields[6]
    except IndexError:
        self.logger.warning(f"Malformed W line (not enough fields): '{fields}'. Skipping.")
        return None

    if "MINIGRAPH" in sample_name.upper():  # Case-insensitive check
        self.found_minigraph = True
        self.logger.debug(f"Skipping W line for 'MINIGRAPH' sample: {fields}")
        return None

    # Validate the start coordinate before touching shared state: a
    # non-integer value (GFA 1.1 allows "*") used to raise ValueError and
    # abort the whole GFA parse.
    try:
        current_segment_start_on_chr = int(chromosome_start_str)
    except ValueError:
        self.logger.warning(
            f"Non-integer start coordinate '{chromosome_start_str}' in W line for sample "
            f"'{sample_name}' ({chromosome_name}). Skipping."
        )
        return None

    self.dict_samples_chrom[sample_name].setdefault(chromosome_name, []).append(
        f"{chromosome_start_str}\t{chromosome_stop_str}"
    )

    # Value recorded per segment for the SW BAM tag.
    sample_chromosome_haplotype_id = f"{sample_name};{chromosome_name};{haplotype_index_str}"

    if ">" in gfa_path_specification or "<" in gfa_path_specification:
        oriented_segments_in_path = [
            ("+" if orient_char == ">" else "-", seg_id)
            for orient_char, seg_id in self.RE_ORIENTED_SEG_GT_LT.findall(gfa_path_specification)
        ]
    else:
        oriented_segments_in_path = [
            (orient_char, seg_id)
            for seg_id, orient_char in self.RE_ORIENTED_SEG_PLUS_MINUS.findall(gfa_path_specification)
        ]

    if not oriented_segments_in_path:
        self.logger.warning(f"Could not parse path spec in W line: '{gfa_path_specification}'. Skipping.")
        return None

    num_segments_in_walk = len(oriented_segments_in_path)
    self.walks_count += 1
    self.max_walk_rank = max(self.max_walk_rank, num_segments_in_walk)
    # Length of the first segment of each walk is accumulated for statistics.
    self.sum_rank0_length += self.dict_segments_size.get(oriented_segments_in_path[0][1], 0)

    bed_field_5_fragment_id = f"{chromosome_start_str}:{chromosome_stop_str}"
    segments_samples = self.dict_segments_samples
    segments_size = self.dict_segments_size
    bed_lines_for_this_walk: List[str] = []
    current_walk_total_length_bp = 0

    for segment_orientation_char, segment_id in oriented_segments_in_path:
        segments_samples[segment_id].append(sample_chromosome_haplotype_id)
        segment_length = segments_size.get(segment_id, 0)
        if segment_length == 0:
            self.logger.warning(
                f"Segment '{segment_id}' in walk for sample '{sample_name}' has zero length "
                f"or size not found. BED entry may be inaccurate."
            )
        current_walk_total_length_bp += segment_length
        segment_end_on_chr = current_segment_start_on_chr + segment_length
        bed_lines_for_this_walk.append(
            f"{chromosome_name}\t"
            f"{current_segment_start_on_chr}\t{segment_end_on_chr}\t"
            f"{segment_orientation_char}{segment_id}\t"
            f"{bed_field_5_fragment_id}\t"
            f"{haplotype_index_str}\n"
        )
        # The next segment starts where this one ends.
        current_segment_start_on_chr = segment_end_on_chr

    self.input_genome_size += current_walk_total_length_bp
    self.walks_info.append({
        "PathName": f"{sample_name}_{chromosome_name}_{haplotype_index_str}",
        "SequenceLength_bp": current_walk_total_length_bp,
        "NumberOfSegments": num_segments_in_walk,
    })
    return sample_name, bed_lines_for_this_walk


def _parse_link_line(self, fields: List[str]) -> Optional[LinkInfo]:
    """
    Parse a GFA Link (L) line and update link-topology statistics.

    Expected layout: ``L <seg1> <orient1> <seg2> <orient2> <overlap>``.
    Orientation "+" maps to 1 and any other character maps to -1.

    Updates ``self_links_count``, ``inverted_links_count``,
    ``negative_links_count`` and ``degrees``, and removes both endpoints from
    ``isolated_segments``. ``self.link_count`` is NOT touched here:
    ``parse_gfa`` already counts every L line. Incrementing it here as well
    double-counted links whenever link indexing was enabled.

    Parameters
    ----------
    fields : List[str]
        Tab-split fields of the L line.

    Returns
    -------
    Optional[LinkInfo]
        Parsed link; the ``orient_key`` fields repeat the numeric orientation.
        Returns ``None`` if fields are missing.
    """
    try:
        seg_id_1 = fields[1]
        orient_char_1 = fields[2]
        seg_id_2 = fields[3]
        orient_char_2 = fields[4]
    except IndexError:
        self.logger.warning(f"Malformed L line (not enough fields): '{fields}'. Skipping.")
        return None

    orient_seg_1_numeric = 1 if orient_char_1 == "+" else -1
    orient_seg_2_numeric = 1 if orient_char_2 == "+" else -1

    if seg_id_1 == seg_id_2:
        self.self_links_count += 1
    if orient_seg_1_numeric * orient_seg_2_numeric < 0:  # Orientations differ
        self.inverted_links_count += 1
    if orient_seg_1_numeric == -1 and orient_seg_2_numeric == -1:
        self.negative_links_count += 1

    self.degrees[seg_id_1] += 1
    self.degrees[seg_id_2] += 1

    # Linked segments are no longer isolated.
    self.isolated_segments.discard(seg_id_1)
    self.isolated_segments.discard(seg_id_2)

    return LinkInfo(
        seg_id_1, orient_seg_1_numeric, orient_seg_1_numeric,
        seg_id_2, orient_seg_2_numeric, orient_seg_2_numeric,
    )


def _parse_segment(self, fields: List[str], bam_file_out: AlignmentFile) -> None:
    """
    Parse a GFA Segment (S) line and write it to the BAM as an unmapped record.

    Expected layout: ``S <seg_id> <sequence> [TAG:TYPE:VALUE ...]``.
    The record's query name is the segment id and its sequence is the S-line
    sequence. Optional GFA tags are joined with "," into the ``SU:Z`` BAM tag.
    The method updates ``dict_segments_size``, ``segment_count`` and
    ``total_segment_length``, and adds the segment to ``isolated_segments``.
    Links later remove connected segments from that set.

    Parameters
    ----------
    fields : List[str]
        Tab-split fields of the S line.
    bam_file_out : AlignmentFile
        Open pysam BAM file for writing. Write errors are logged.
    """
    try:
        seg_id = fields[1]
        sequence = fields[2]
    except IndexError:
        self.logger.warning(f"Malformed S line (not enough fields): '{fields}'. Skipping.")
        return
    optional_tags_list = fields[3:]

    # NOTE(review): a "*" placeholder sequence (GFA: sequence not stored) is
    # taken literally here, giving length 1; LN:i: tags are not consulted.
    seg_length = len(sequence)

    self.dict_segments_size[seg_id] = seg_length
    self.segment_count += 1
    self.total_segment_length += seg_length
    self.isolated_segments.add(seg_id)

    segment_for_bam = AlignedSegment()
    segment_for_bam.query_name = seg_id
    segment_for_bam.query_sequence = sequence
    # NOTE(review): "*" (ASCII 42) decodes to Phred 9 for every base rather than
    # "missing quality"; kept as-is in case downstream readers rely on it.
    segment_for_bam.query_qualities = qualitystring_to_array("*" * seg_length)
    segment_for_bam.flag = 4  # Unmapped
    segment_for_bam.reference_id = -1
    segment_for_bam.reference_start = 0
    segment_for_bam.mapping_quality = 0
    # One match operation (op 0) spanning the whole sequence; no CIGAR for empty sequences.
    segment_for_bam.cigartuples = [(0, seg_length)] if seg_length > 0 else None
    if optional_tags_list:
        segment_for_bam.set_tag("SU", ",".join(optional_tags_list), value_type='Z')

    try:
        bam_file_out.write(segment_for_bam)
    except Exception as e:  # pysam write errors (header issues, malformed record)
        self.logger.error(f"Error writing segment '{seg_id}' to BAM: {e}", exc_info=True)
def tag_bam(self) -> None:
    """
    Add sample-walk (SW) tags to the segments BAM using ``GratoolsBam``.

    ``self.dict_segments_samples`` (segment id -> list of
    ``"sample;chrom;haplotype"`` strings collected from W lines) is passed to
    ``GratoolsBam.tag``. The tagged file is expected to overwrite
    ``self.bam_segments_file``, and a warning is logged if ``tag`` returns a
    different path. Errors raised by ``tag`` are logged, not propagated.
    Returns early, logging an error, when the BAM path is unset or missing.
    """
    # Check the path before dereferencing it: it may be None, and the log
    # message below needs `.name`.
    if not self.bam_segments_file or not self.bam_segments_file.exists():
        self.logger.error("BAM segments file not found. Cannot perform tagging.")
        return
    self.logger.info(f"Initiating BAM tagging for '{self.bam_segments_file.name}'.")

    if not self.dict_segments_samples:
        # Continue anyway: an untagged BAM is still processed by GratoolsBam.
        self.logger.warning("dict_segments_samples is empty. No SW tags will be added.")

    bam_manager = GratoolsBam(
        bam_path=self.bam_segments_file,
        threads=self.threads,
        logger=self.logger,
        tagging=False,  # NOTE(review): semantics defined by GratoolsBam — confirm
    )
    try:
        updated_bam_path = bam_manager.tag(dict_segments_samples=self.dict_segments_samples)
        if updated_bam_path != self.bam_segments_file:
            self.logger.warning(f"BAM tagging returned a new path '{updated_bam_path}', "
                                f"but expected overwrite of '{self.bam_segments_file.name}'. Check GratoolsBam.tag logic.")
    except Exception as e:
        self.logger.error(f"Error during BAM tagging process: {e}", exc_info=True)
def _get_connected_component(self, start_seg: str, cursor: sqlite3.Cursor) -> Set[str]: """ Executes a recursive SQL query via sqlite3 to retrieve all segments connected to `start_seg`, using an existing database cursor. This assumes bi-directional links (if A links to B, B effectively links to A for component finding). Parameters ---------- start_seg : str The starting segment ID for component traversal. cursor : sqlite3.Cursor An active SQLite database cursor to use for the query. Returns ------- Set[str] A set of segment IDs belonging to the same connected component as `start_seg`. """ # The query finds all segments reachable from start_seg. # It considers links in both directions (seg_id_1 -> seg_id_2 and seg_id_2 -> seg_id_1) # by UNIONing them in the recursive part. query = """ WITH RECURSIVE connected_component(segment_id) AS ( VALUES(?) -- Base case: the starting segment UNION -- Recursively find segments linked from current component segments SELECT links.seg_id_2 FROM links JOIN connected_component ON links.seg_id_1 = connected_component.segment_id UNION -- Also consider links in the other direction SELECT links.seg_id_1 FROM links JOIN connected_component ON links.seg_id_2 = connected_component.segment_id ) SELECT segment_id FROM connected_component; """ # OPTIMIZATION_SUGGESTION: For very large graphs, ensure SQLite's query planner handles # this recursive CTE efficiently. Indexing on (seg_id_1, seg_id_2) and (seg_id_2, seg_id_1) # or individual indexes on seg_id_1 and seg_id_2 (as created by AsyncGfaDatabase) is crucial. try: cursor.execute(query, (start_seg,)) rows = cursor.fetchall() return {row[0] for row in rows} # Convert list of tuples to set of IDs except sqlite3.Error as e: self.logger.error(f"SQLite error getting connected component for '{start_seg}': {e}", exc_info=True) return {start_seg} # Return at least the start segment on error
def compute_statistics(self) -> Dict[str, Any]:
    """
    Compute summary statistics of the parsed GFA graph and save them as TSV.

    Statistics are grouped by category:
    - graph overview;
    - segment lengths;
    - link properties;
    - walk (path) totals;
    - per-segment sample sharing and depth;
    - graph structure: density, dead ends, isolated segments and connected
      components.

    Connected components are only computed when ``self.index_links`` is set
    and the links database file exists; they use ``_get_connected_component``.

    The flattened ``Category / Metric / Value`` table is written
    tab-separated to ``self.stats_file`` when that attribute is set.

    Returns
    -------
    Dict[str, Any]
        Mapping ``category -> {metric name -> value}``. If a step raises, the
        error is logged, and the categories completed before the failure are
        still returned (and saved).
    """
    self.logger.info("Computing GFA graph statistics...")
    current_progress = self.progress

    # One advance for each of steps 1-6 below plus one for "Formatting & Saving".
    total_steps_stats = 7
    stats_task_id = current_progress.add_task(
        "[bold blue]Computing graph statistics...", total=total_steps_stats
    )

    # Categories are stored as soon as they are complete (in final output
    # order) so that an exception in a later step keeps earlier results.
    stats_results: Dict[str, Any] = {}

    def _mean_median_std(values: List[int]) -> Tuple[float, float, float]:
        """Mean, median and standard deviation of ``values`` (zeros when empty)."""
        if not values:
            return 0.0, 0.0, 0.0
        return float(mean(values)), float(median(values)), float(std(values))

    def _connectivity_stats() -> Tuple[int, int, int, int]:
        """(number of CCs, largest CC bp, number of other CCs, bp in other CCs); zeros when skipped."""
        if not (
            self.index_links
            and self.db_links
            and self.db_links.db_file
            and self.db_links.db_file.exists()
        ):
            self.logger.info(
                "Link indexing not enabled or DB not available; "
                "skipping database-dependent connectivity statistics."
            )
            return 0, 0, 0, 0

        num_components = largest_bp = other_components = other_bp = 0
        conn_sqlite = None
        try:
            # Plain synchronous connection: only read queries are issued here.
            conn_sqlite = sqlite3.connect(self.db_links.db_file.as_posix(), timeout=0.5)
            cursor = conn_sqlite.cursor()
            # NOTE(review): journal_mode is a persistent, database-wide setting;
            # the writer is expected to have set WAL already, making this a no-op.
            cursor.execute("PRAGMA journal_mode=WAL;")
            # No fsync needed for this read-only workload.
            cursor.execute("PRAGMA synchronous=OFF;")
            # Negative cache_size is in KiB: -8000000 caps the page cache at ~8 GB.
            cursor.execute("PRAGMA cache_size = -8000000;")
            cursor.execute("PRAGMA temp_store = MEMORY;")

            components: List[Set[str]] = []
            visited: Set[str] = set()
            for seg_id in self.dict_segments_size:
                if seg_id in visited:
                    continue
                component = self._get_connected_component(seg_id, cursor)
                if component:
                    components.append(component)
                    visited.update(component)

            if components:
                lengths_bp = [
                    sum(self.dict_segments_size.get(s, 0) for s in comp)
                    for comp in components
                ]
                num_components = len(components)
                largest_bp = max(lengths_bp)
                # Every component other than (one) largest counts as "disconnected".
                other_components = num_components - 1
                other_bp = sum(lengths_bp) - largest_bp
        except sqlite3.Error as e:
            self.logger.error(f"SQLite error during connectivity stats: {e}", exc_info=True)
        finally:
            if conn_sqlite:
                conn_sqlite.close()
        return num_components, largest_bp, other_components, other_bp

    with current_progress:
        try:
            # Step 1: basic graph metrics
            current_progress.update(
                stats_task_id, description="[cyan]Stats: Basic Graph Metrics", advance=0
            )
            avg_segment_length = (
                self.total_segment_length / self.segment_count if self.segment_count else 0.0
            )
            # Average degree of an undirected graph: 2 * links / segments.
            avg_degree = (2 * self.link_count) / self.segment_count if self.segment_count else 0.0
            max_degree = max(self.degrees.values()) if self.degrees else 0
            ne_ratio = self.segment_count / self.link_count if self.link_count else 0.0
            # Graph size counts each unique segment once; input_genome_size is
            # reported below as the total length of all paths.
            graph_size_bp = self.total_segment_length
            compress_ratio = self.input_genome_size / graph_size_bp if graph_size_bp else 0.0
            stats_results["Graph Overview"] = {
                "GFA File Name": self.gfa_name,
                "GFA Version": self.version,
                "Total Segments (S lines)": self.segment_count,
                "Total Links (L lines)": self.link_count,
                "Total Walks (W lines)": self.walks_count,
                "Unique Samples in Walks": len(self.dict_samples_chrom),
            }
            current_progress.update(stats_task_id, advance=1)

            # Step 2: segment lengths
            current_progress.update(stats_task_id, description="[cyan]Stats: Segment Lengths")
            node_lengths = list(self.dict_segments_size.values())
            median_node_length = float(median(node_lengths)) if node_lengths else 0.0
            avg_top5_percent_length = 0.0
            median_top5_percent_length = 0.0
            if node_lengths:
                # Longest 5% of the segments, at least one.
                num_top = max(1, int(0.05 * len(node_lengths)))
                top_lengths = sorted(node_lengths, reverse=True)[:num_top]
                avg_top5_percent_length = float(mean(top_lengths))
                median_top5_percent_length = float(median(top_lengths))
            # Bins: [0, 500), [500, 1000), [1000, 2000), [2000, inf]
            length_bins = [0, 500, 1000, 2000, float("inf")]
            bin_labels = ["0-499bp", "500-999bp", "1000-1999bp", "2000+bp"]
            hist_counts, _ = histogram(node_lengths, bins=length_bins)
            segment_length_bin_counts = dict(zip(bin_labels, map(int, hist_counts)))
            stats_results["Segment Statistics"] = {
                "Total Segment Length (bp)": self.total_segment_length,
                "Average Segment Length (bp)": avg_segment_length,
                "Median Segment Length (bp)": median_node_length,
                "Avg Length of Top 5% Longest Segments (bp)": avg_top5_percent_length,
                "Median Length of Top 5% Longest Segments (bp)": median_top5_percent_length,
                "Segment Length Distribution": segment_length_bin_counts,
            }
            # Link and path categories only need step-1 values; they are
            # inserted here to keep the original category order.
            stats_results["Link Statistics"] = {
                "Max Segment Degree": max_degree,
                "Average Segment Degree": avg_degree,
                "Self-Links (S1 -> S1)": self.self_links_count,
                "Inverted Links (S1+ -> S2-)": self.inverted_links_count,
                "Both Negative Links (S1- -> S2-)": self.negative_links_count,
            }
            stats_results["Path (Walk) Statistics"] = {
                "Total Length of All Paths (bp, input_genome_size)": self.input_genome_size,
                "Graph Compression Ratio (input_genome_size / graph_size_bp)": compress_ratio,
                "Max Segments in a Single Walk": self.max_walk_rank,
                "Sum of First Segment Lengths in Walks": self.sum_rank0_length,
            }
            current_progress.update(stats_task_id, advance=1)

            # Step 3: similarity (distinct samples per segment) and depth (occurrences per segment)
            current_progress.update(stats_task_id, description="[cyan]Stats: Segment Similarity/Depth")
            similarities: List[int] = []
            depths: List[int] = []
            for occurrences in self.dict_segments_samples.values():
                # Each occurrence is a "sample;chrom;haplo" string.
                similarities.append(len({occ.split(";")[0] for occ in occurrences}))
                depths.append(len(occurrences))
            sim_mean, sim_median, sim_std = _mean_median_std(similarities)
            depth_mean, depth_median, depth_std = _mean_median_std(depths)
            stats_results["Segment Sharing & Depth"] = {
                "Avg Unique Samples per Segment (Similarity Mean)": sim_mean,
                "Median Unique Samples per Segment (Similarity Median)": sim_median,
                "StdDev Unique Samples per Segment (Similarity Std)": sim_std,
                "Avg Occurrences per Segment (Depth Mean)": depth_mean,
                "Median Occurrences per Segment (Depth Median)": depth_median,
                "StdDev Occurrences per Segment (Depth Std)": depth_std,
            }
            current_progress.update(stats_task_id, advance=1)

            # Step 4: density of an undirected graph, 2*L / (N*(N-1))
            current_progress.update(stats_task_id, description="[cyan]Stats: Graph Density")
            if self.segment_count > 1:
                graph_density = (2 * self.link_count) / (
                    self.segment_count * (self.segment_count - 1)
                )
            else:
                graph_density = 0.0  # undefined for 0 or 1 segment
            current_progress.update(stats_task_id, advance=1)

            # Step 5: connected components (requires the SQLite links DB)
            current_progress.update(stats_task_id, description="[cyan]Stats: Connectivity (DB query)")
            (
                num_connected_components,
                largest_cc_size_bp,
                num_disconnected_components_count,
                total_length_disconnected_bp,
            ) = _connectivity_stats()
            current_progress.update(stats_task_id, advance=1)

            # Step 6: dead ends (degree 1) and isolated segments
            current_progress.update(stats_task_id, description="[cyan]Stats: Finalizing")
            dead_ends_count = sum(1 for degree_val in self.degrees.values() if degree_val == 1)
            num_isolated_segments_count = len(self.isolated_segments)
            current_progress.update(stats_task_id, advance=1)

            stats_results["Graph Structure"] = {
                "Graph Density": graph_density,
                "Segments/Links Ratio": ne_ratio,
                "Dead-End Segments (degree 1)": dead_ends_count,
                "Isolated Segments (degree 0)": num_isolated_segments_count,
                "Number of Connected Components (CCs)": num_connected_components,
                "Largest CC Size (bp)": largest_cc_size_bp,
                "Number of Disconnected CCs (excluding largest)": num_disconnected_components_count,
                "Total Length of Disconnected CCs (bp)": total_length_disconnected_bp,
            }
            current_progress.update(stats_task_id, description="[green]Stats: Formatting & Saving", advance=1)
        except Exception as e:
            self.logger.error(f"Error during statistics computation: {e}", exc_info=True)
            current_progress.update(stats_task_id, description=f"[red]Error computing stats: {e}")
            # stats_results keeps the categories completed before the error.

    # Flatten to (category, metric, value) rows for the TSV output.
    flat_stats_for_df = []
    for category, metrics in stats_results.items():
        if isinstance(metrics, dict):
            flat_stats_for_df.extend((category, name, value) for name, value in metrics.items())
        else:
            flat_stats_for_df.append((category, "Value", metrics))
    df_stats = DataFrame(flat_stats_for_df, columns=["Category", "Metric", "Value"])

    def format_value(x: Any) -> str:
        """Human-readable rendering of one statistic value."""
        if isinstance(x, dict):  # binned counts
            return ", ".join(f"{k}: {v}" for k, v in x.items())
        if isinstance(x, float):
            if 0 < abs(x) < 1e-3 or abs(x) > 1e6:
                return f"{x:.2e}"
            return f"{x:.2f}"
        if isinstance(x, int):
            return f"{x:,}"
        return str(x)

    df_stats["Value"] = df_stats["Value"].apply(format_value)

    if self.stats_file:
        try:
            df_stats.to_csv(self.stats_file, sep="\t", index=False)
            self.logger.info(f"Saved GFA statistics to '{self.stats_file.name}'")
        except OSError as e:
            self.logger.error(f"Failed to save statistics CSV '{self.stats_file}': {e}")

    return stats_results
@dataclass
class GratoolsBam:
    """
    Handles BAM-file operations in the GraTools context.

    Covers indexing, extracting segment information, tagging segments with
    sample walks and analyses (core/dispensable ratio, depth statistics, ...).

    Attributes
    ----------
    bam_path : Path
        Path to the BAM file (a ``str`` is accepted and converted).
    threads : int, optional
        Number of threads for BAM operations (reading, indexing). Defaults to 1.
    logger : logging.Logger
        Logger instance. Defaults to the logger named "GraTools".
    suffix : Optional[str], optional
        Suffix appended to output file names. ``None`` is normalised to ``""``.
    works_path : Optional[Path], optional
        Directory for output files. ``None`` defaults to the BAM's directory.
    gfa_name : Optional[str], optional
        Base name for output files. ``None`` defaults to the BAM file name
        without its ``.bam`` suffix.
    tagging : bool, optional
        When True, ``index_bam`` creates a missing index or refreshes one
        older than the BAM. When False, ``index_bam`` does not index.
        Defaults to False.
    progress : Optional[Progress]
        Rich progress bar, created in ``__post_init__``.
    """

    bam_path: Path
    threads: int = 1
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    suffix: Optional[str] = None  # appended to output file names
    works_path: Optional[Path] = None  # directory for outputs
    gfa_name: Optional[str] = None  # base name for outputs
    tagging: bool = False  # lets index_bam() create/refresh the index
    # Rich progress bar, built in __post_init__ (not an init argument).
    progress: Optional[Progress] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """Normalise paths, fill output-naming defaults, build the progress bar and index if needed."""
        # Accept plain strings for paths; Path() is a no-op on Path objects.
        self.bam_path = Path(self.bam_path)
        if self.works_path is not None:
            self.works_path = Path(self.works_path)

        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            TextColumn("{task.completed}/{task.total} segs"),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            refresh_per_second=1,  # low refresh rate for long BAM operations
            transient=True,
        )

        # Defaults are filled BEFORE the existence check, so methods that
        # build output paths never see works_path/gfa_name/suffix as None.
        if self.works_path is None:
            self.works_path = self.bam_path.parent
        if self.gfa_name is None:
            # e.g. "mygfa.bam" -> "mygfa"
            self.gfa_name = self.bam_path.name.removesuffix(".bam")
        if self.suffix is None:
            self.suffix = ""

        if not self.bam_path.exists():
            self.logger.error(f"BAM file does not exist: {self.bam_path}")
            return

        self.index_bam()  # only indexes when self.tagging is True
        self.works_path.mkdir(parents=True, exist_ok=True)
def index_bam(self) -> None:
    """
    Create or refresh the BAM index with ``pysam.index``, only when ``self.tagging`` is True.

    Behaviour depends on ``self.tagging``:
    - True, and neither ``<bam>.bai`` nor ``<bam>.crai`` exists: the file is
      indexed.
    - True, and the BAM's mtime is newer than the existing index: the file is
      re-indexed.
    - False: nothing is done, even if no index exists.

    Uses ``self.threads`` (via samtools' ``-@``). Indexing errors are logged,
    not raised.
    """
    if not self.bam_path.exists():
        self.logger.warning(f"Cannot index BAM file: '{self.bam_path}' does not exist.")
        return

    # Index files live next to the BAM: file.bam.bai (BAM) / file.cram.crai (CRAM).
    bai_path = self.bam_path.with_name(self.bam_path.name + ".bai")
    crai_path = self.bam_path.with_name(self.bam_path.name + ".crai")
    if bai_path.exists():
        existing_index = bai_path
    elif crai_path.exists():
        existing_index = crai_path
    else:
        existing_index = None

    needs_indexing = False
    if self.tagging:
        if existing_index is None:
            self.logger.info(
                f"Indexing BAM file '{self.bam_path.name}': No existing index found."
            )
            needs_indexing = True
        elif self.bam_path.stat().st_mtime > existing_index.stat().st_mtime:
            self.logger.info(
                f"Re-indexing BAM file '{self.bam_path.name}': BAM file is newer than its index."
            )
            needs_indexing = True

    if not needs_indexing:
        self.logger.debug(f"BAM file '{self.bam_path.name}' index is up-to-date or tagging is false.")
        return

    # pysam.index forwards positional arguments to `samtools index`; keyword
    # arguments other than the dispatcher's own options (e.g. `threads=`) are
    # ignored, so threads must be requested with samtools' `-@` flag
    # (number of threads in addition to the main one).
    index_args = ["-@", str(self.threads - 1)] if self.threads > 1 else []
    try:
        index(*index_args, self.bam_path.as_posix())
        self.logger.info(f"Successfully indexed BAM file: '{self.bam_path.name}'.")
    except Exception as e:  # pysam raises various errors (SamtoolsError, OSError, ...)
        self.logger.error(f"Error indexing BAM file '{self.bam_path.name}': {e}", exc_info=True)
def _get_total_segments_in_bam(self) -> int: """ Counts the total number of alignments (segments) in the BAM file. Returns ------- int The total number of segments. Returns 0 if file cannot be read or is empty. """ if not self.bam_path.exists(): return 0 try: with AlignmentFile( self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads ) as bam_file: total_segments = bam_file.count(until_eof=True) # More direct pysam way return total_segments except Exception as e: self.logger.error(f"Error counting segments in BAM '{self.bam_path.name}': {e}") return 0
[docs] def build_segments(self, list_segments: Optional[List[str]] = None) -> Tuple[List[str], defaultdict, defaultdict]: """ Extracts specified segments from a BAM file and reconstructs their GFA S-line representation. Also populates dictionaries for segment samples and sequences. Parameters ---------- list_segments : Optional[List[str]], optional A list of segment IDs (query_name in BAM) to extract. If None or empty, this method might process all segments or return empty results, depending on intended behavior (current pysam.view call implies it needs a list). Returns ------- Tuple[List[str], defaultdict[str, List[str]], defaultdict[str, str]] - gfa_s_lines_list: List of strings, each a GFA S-line. - dict_seg_samples: defaultdict mapping segment ID to list of "sample;chrom;haplo" strings from SW tag. - dict_seg_sequence: defaultdict mapping segment ID to its sequence. """ gfa_s_lines_list: List[str] = [] dict_seg_samples: defaultdict[str, List[str]] = defaultdict(list) dict_seg_sequence: defaultdict[str, str] = defaultdict(str) if not list_segments: self.logger.warning("Build_segments called with no segment IDs. Returning empty results.") return gfa_s_lines_list, dict_seg_samples, dict_seg_sequence # Temporary files are used to pass the list of segment IDs to `pysam view`. # OPTIMIZATION_SUGGESTION: If `list_segments` is very large, writing to temp file is fine. # For smaller lists, iterating the BAM in Python and filtering by `segment.query_name in set(list_segments)` # might avoid `pysam view` subprocess overhead, but `pysam view` is highly optimized. # The choice depends on typical `list_segments` size and BAM size. 
# Ensure bam_path's parent exists for temp files if works_path isn't explicitly set for temp dir temp_dir = self.works_path if self.works_path else self.bam_path.parent try: with tempfile.NamedTemporaryFile( mode="w+", dir=temp_dir, suffix=".txt", delete=False ) as list_seg_file, tempfile.NamedTemporaryFile( mode="wb", dir=temp_dir, suffix=".bam", delete=False ) as bam_temp_file: temp_seg_list_path = Path(list_seg_file.name) temp_bam_path = Path(bam_temp_file.name) # Write segment IDs to the temporary text file for seg_id in list_segments: list_seg_file.write(f"{seg_id}\n") list_seg_file.flush() # Ensure data is written before pysam view reads it self.logger.debug( f"Extracting {len(list_segments)} segments using pysam view. List: {temp_seg_list_path}, Temp BAM: {temp_bam_path}") # Use pysam.view to extract these segments into a temporary BAM file. # -N option: read names in file (the list_seg_file.name) # -hb option: output BAM, include header # The input BAM is self.bam_path # Output is stdout, captured as bytes # REFACTOR_SUGGESTION: pysam.view can output directly to a filename: # pysam.view("-hbN", temp_seg_list_path.as_posix(), "-o", temp_bam_path.as_posix(), self.bam_path.as_posix(), catch_stdout=False) # This avoids reading all output into memory if it's large. bam_view_output_bytes = view("-hbN", temp_seg_list_path.as_posix(), self.bam_path.as_posix()) bam_temp_file.write(bam_view_output_bytes) bam_temp_file.flush() # Read the temporary BAM file containing only the desired segments with AlignmentFile( temp_bam_path.as_posix(), "rb", check_sq=False, threads=self.threads ) as bam_file_filtered: for segment in bam_file_filtered.fetch(until_eof=True): seg_id = segment.query_name sequence = segment.query_sequence # Reconstruct GFA S-line. SU tag contains GFA optional tags. 
su_tag_val = segment.get_tag('SU') if segment.has_tag('SU') else "" gfa_s_line = f"S\t{seg_id}\t{sequence}\t{su_tag_val}" gfa_s_lines_list.append(gfa_s_line) # Populate segment samples dictionary from SW tag if segment.has_tag('SW'): sw_tag_val = segment.get_tag('SW') # SW tag format: "sample1;chrom1;hap1,sample2;chrom2;hap2,..." # The split by ';'[:-1] seems to intend removing a trailing empty string if # the tag ends with a comma/semicolon. Robust split: dict_seg_samples[seg_id] = [s.strip() for s in sw_tag_val.split(',') if s.strip()] # Populate segment sequence dictionary dict_seg_sequence[seg_id] = sequence return gfa_s_lines_list, dict_seg_samples, dict_seg_sequence except Exception as e: self.logger.error(f"Error in build_segments: {e}", exc_info=True) return [], defaultdict(list), defaultdict(str) # Return empty on error finally: # Clean up temporary files if 'temp_seg_list_path' in locals() and temp_seg_list_path.exists(): temp_seg_list_path.unlink() if 'temp_bam_path' in locals() and temp_bam_path.exists(): temp_bam_path.unlink()
def tag(self, dict_segments_samples: Dict[str, List[str]]) -> Path:
    """
    Add or replace the ``SW`` (sample walks) tag on the segments of ``self.bam_path``.

    Each record whose ``query_name`` is a key of ``dict_segments_samples``
    gets ``SW`` set to the comma-joined list of its
    ``"sample;chromosome;haplotype"`` strings; other records are copied
    unchanged. The tagged copy then replaces the original file,
    ``self.tagging`` is set to True and the file is (re-)indexed via
    ``index_bam``.

    Parameters
    ----------
    dict_segments_samples : Dict[str, List[str]]
        Segment ID -> ``"sample;chromosome;haplotype"`` strings.

    Returns
    -------
    Path
        ``self.bam_path``, now tagged.

    Raises
    ------
    FileNotFoundError
        If ``self.bam_path`` does not exist.
    Exception
        Any error while reading, writing or replacing the BAM. It is logged,
        the temporary file is removed, and the error is re-raised.
    """
    if not self.bam_path.exists():
        self.logger.error(f"Input BAM file for tagging not found: {self.bam_path}")
        raise FileNotFoundError(f"Input BAM file for tagging not found: {self.bam_path}")

    temp_tagged_bam_path = self.bam_path.with_name(f"{self.bam_path.stem}_tagged_temp.bam")
    self.logger.info(
        f"Tagging {len(dict_segments_samples)} segments in BAM file: '{self.bam_path.name}'. "
    )

    try:
        # Separate full pass over the input, done before opening it for
        # tagging; it only sizes the progress bar.
        total_segments_in_bam = self._get_total_segments_in_bam()
        tag_task_id = self.progress.add_task(
            f"Tagging BAM '{self.bam_path.name}'", total=total_segments_in_bam
        )
        with AlignmentFile(
            self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
        ) as bam_file_in, AlignmentFile(
            temp_tagged_bam_path.as_posix(),
            "wb",
            template=bam_file_in,  # reuse the input header
            threads=self.threads,
        ) as bam_file_out, self.progress:
            for segment in bam_file_in.fetch(until_eof=True):
                self.progress.advance(tag_task_id)
                segment_id = segment.query_name
                if segment_id in dict_segments_samples:
                    sw_tag_value = ",".join(dict_segments_samples[segment_id])
                    segment.set_tag("SW", sw_tag_value, value_type='Z', replace=True)
                bam_file_out.write(segment)

        # Both files are closed at this point, so the tagged BAM is fully
        # written. Replacing or indexing it while the writer is still open
        # would act on an incomplete file.
        temp_tagged_bam_path.replace(self.bam_path)
        self.logger.info(f"BAM file '{self.bam_path.name}' successfully tagged.")

        # Switch on tagging so that index_bam() (re-)indexes the new file.
        self.tagging = True
        self.index_bam()
    except Exception as e:
        self.logger.error(f"Error during BAM tagging for '{self.bam_path.name}': {e}", exc_info=True)
        temp_tagged_bam_path.unlink(missing_ok=True)
        raise

    return self.bam_path
def core_dispensable_ratio(
    self,
    nb_samples_gfa: int,
    input_as_number: bool,
    shared_min_cutoff: int = 1,
    specific_max_cutoff: Optional[int] = None,  # can be None
    filter_min_len: int = 1,
) -> None:
    """
    Classify BAM segments as shared (core) or specific (dispensable) and report ratios.

    For each record, the number of distinct samples is read from its ``SW``
    tag: the first ``;``-separated field of each comma-separated entry.
    - A segment is *shared* when that number is ``>= shared_min``.
    - It is *specific* when it is ``<= specific_max``.

    Both counts are also reported restricted to segments whose
    ``query_length`` is ``>= filter_min_len``. Results are written to
    ``<works_path>/<gfa_name>_core_disp_ratio_...<suffix>.csv`` and printed as
    a summary table.

    Parameters
    ----------
    nb_samples_gfa : int
        Total number of samples in the GFA. Used to convert percentage
        cutoffs and to validate the cutoffs.
    input_as_number : bool
        If True, the cutoffs are absolute sample counts. Otherwise they are
        percentages of ``nb_samples_gfa``, converted with ``round``.
    shared_min_cutoff : int, optional
        Minimum sample count/percentage to be "shared". Defaults to 1.
    specific_max_cutoff : Optional[int], optional
        Maximum sample count/percentage to be "specific". When None, an
        absolute maximum of 1 sample is used in both modes. Defaults to None.
    filter_min_len : int, optional
        Minimum segment length (bp) for the "Filtered" rows. Defaults to 1.

    Notes
    -----
    Returns early, after logging, when a cutoff exceeds ``nb_samples_gfa`` or
    the BAM is empty or unreadable. Records without an ``SW`` tag are skipped
    and reported in their own row.
    """
    if input_as_number:
        shared_min_abs = shared_min_cutoff
        specific_max_abs = specific_max_cutoff if specific_max_cutoff is not None else 1
        self.logger.info(
            f"Cutoffs (absolute number input): Shared_min={shared_min_abs} samples, "
            f"Specific_max={specific_max_abs} samples if specified."
        )
    else:
        shared_min_abs = round((shared_min_cutoff * nb_samples_gfa) / 100)
        # Falls back to an absolute maximum of 1 sample when not specified.
        specific_max_abs = (
            round((specific_max_cutoff * nb_samples_gfa) / 100)
            if specific_max_cutoff is not None
            else 1
        )
        self.logger.info(
            f"Cutoffs (percentage input): Shared_min={shared_min_cutoff}% ({shared_min_abs} samples), "
            f"Specific_max={specific_max_cutoff}% ({specific_max_abs} samples if specified)."
        )

    # A sample count n is both shared and specific when
    # shared_min <= n <= specific_max, so the classes overlap as soon as
    # shared_min <= specific_max (not only when it is strictly smaller).
    if shared_min_abs <= specific_max_abs:
        self.logger.warning(  # warning, not error: overlap may be intentional
            f"shared_min_cutoff ({shared_min_abs}) is less than or equal to specific_max_cutoff ({specific_max_abs}). "
            "This means some segments could be classified as both shared and specific. "
            "Ensure this is the intended definition."
        )

    if shared_min_abs > nb_samples_gfa or specific_max_abs > nb_samples_gfa:
        self.logger.error(
            f"Cutoff values exceed total GFA samples ({nb_samples_gfa}). "
            f"Shared_min={shared_min_abs}, Specific_max={specific_max_abs}. Adjust parameters."
        )
        return

    self.logger.info(
        f"Analyzing core/dispensable ratio for {nb_samples_gfa} GFA samples. "
        f"Effective cutoffs: Shared_min >= {shared_min_abs}, Specific_max <= {specific_max_abs}, Min_Length_Filter = {filter_min_len}bp."
    )

    # Keys: shared_raw, specific_raw, shared_len_filtered, specific_len_filtered,
    # length_filtered_out, missing_sw_tag
    category_counts: Counter = Counter()

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file is empty or unreadable. Cannot perform core/dispensable analysis.")
        return

    core_disp_task_id = self.progress.add_task(
        "Core/Dispensable Ratio Analysis", total=total_segments_in_bam
    )
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(core_disp_task_id)

            if not segment.has_tag("SW"):
                # Per-segment detail at debug level; one summary warning follows the loop.
                self.logger.debug(
                    f"Segment '{segment.query_name}' missing SW tag. Cannot determine sample count. Skipping.")
                category_counts["missing_sw_tag"] += 1
                continue

            # Entries are "sample_id;chrom;haplotype": count distinct sample_ids.
            sw_entries = segment.get_tag("SW").split(",")
            num_samples = len({entry.split(";")[0] for entry in sw_entries if entry})

            is_shared = num_samples >= shared_min_abs
            is_specific = num_samples <= specific_max_abs
            if is_shared:
                category_counts["shared_raw"] += 1
            if is_specific:  # may also be shared when the cutoffs overlap
                category_counts["specific_raw"] += 1

            if segment.query_length >= filter_min_len:
                if is_shared:
                    category_counts["shared_len_filtered"] += 1
                if is_specific:
                    category_counts["specific_len_filtered"] += 1
            else:
                category_counts["length_filtered_out"] += 1

    missing_sw = category_counts["missing_sw_tag"]
    if missing_sw:
        self.logger.warning(f"{missing_sw} segment(s) without SW tag were skipped (sample count unknown).")

    total_segments_processed = total_segments_in_bam - missing_sw
    segments_passing_length_filter = total_segments_processed - category_counts["length_filtered_out"]

    def _row(label: str, count: int, total: int) -> tuple:
        """(label, count, total, percentage), percentage 0 when total is 0."""
        return label, count, total, (count / total) * 100 if total > 0 else 0

    results_data = []
    if total_segments_processed > 0:
        results_data.append(_row("Shared (Core) - Raw", category_counts["shared_raw"], total_segments_processed))
        results_data.append(_row("Specific (Dispensable) - Raw", category_counts["specific_raw"], total_segments_processed))
    if segments_passing_length_filter > 0:
        results_data.append(_row(f"Shared (Core) - Filtered (Length >= {filter_min_len}bp)",
                                 category_counts["shared_len_filtered"], segments_passing_length_filter))
        results_data.append(_row(f"Specific (Dispensable) - Filtered (Length >= {filter_min_len}bp)",
                                 category_counts["specific_len_filtered"], segments_passing_length_filter))
    results_data.append(_row("Segments Filtered Out by Length",
                             category_counts["length_filtered_out"], total_segments_processed))
    if missing_sw > 0:
        results_data.append(_row("Segments Missing SW Tag (Skipped)", missing_sw, total_segments_in_bam))

    df_results = DataFrame(results_data, columns=["Category", "Count", "Total Relevant Segments", "Percentage"])

    # File name records the cutoffs as given (percentages marked "pct").
    spec_max_str = str(specific_max_cutoff) if specific_max_cutoff is not None else "N_A"
    csv_filename = (
        f"{self.gfa_name}_core_disp_ratio"
        f"_shared{shared_min_cutoff}{'pct' if not input_as_number else ''}"
        f"_spec{spec_max_str}{'pct' if not input_as_number else ''}"
        f"_len{filter_min_len}{self.suffix}.csv"
    )
    output_csv_path = self.works_path / csv_filename

    try:
        df_results.to_csv(output_csv_path, index=False, float_format="%.2f")
        self.logger.info(f"Core/Dispensable ratio analysis saved to: {output_csv_path}")

        # Guard the percentage: every record may lack an SW tag.
        passing_pct = (
            (segments_passing_length_filter / total_segments_processed) * 100
            if total_segments_processed > 0
            else 0.0
        )
        shared_console.rule("[bold green]Summary[/bold green]")
        shared_console.print(
            f"[cyan]Total segments in GFA:[/cyan] {total_segments_in_bam:,}"
        )
        shared_console.print(
            f"[cyan]Total segments analyzed:[/cyan] {total_segments_processed:,}"
        )
        shared_console.print(
            f"[cyan]Total segments passing length filter (≥ {filter_min_len}bp):[/cyan] {segments_passing_length_filter:,} "
            f"({passing_pct:.2f}%)"
        )

        table = Table(
            title=f"[bold magenta]Core vs. Dispensable Segments[/bold magenta] — [cyan]{self.gfa_name}[/cyan]",
            box=box.ROUNDED,
            show_lines=False,
        )
        table.add_column("Category", style="bold cyan", justify="left")
        table.add_column("Count", style="green", justify="right")
        table.add_column("Total", style="yellow", justify="right")
        table.add_column("Percentage", style="bold green", justify="right")
        for _, row in df_results.iterrows():
            table.add_row(
                str(row["Category"]),
                f"{int(row['Count']):,}",
                f"{int(row['Total Relevant Segments']):,}",
                f"{row['Percentage']:.2f}%",
            )
        shared_console.print(table)
    except OSError as e:
        self.logger.error(f"Failed to save Core/Dispensable ratio CSV to '{output_csv_path}': {e}")
def depth_nodes_stat(self, nb_samples_gfa: int, filter_min_len: int = 1) -> None:
    """
    Calculate and display statistics about segment depth.

    The depth of a segment is the number of unique samples listed in its
    ``SW`` tag (the sample name is the text before the first ``;`` of each
    comma-separated entry). Segments without an ``SW`` tag are skipped and
    reported in a warning. Results are written to
    ``<works_path>/<gfa_name>_node_depth_stats_len<filter_min_len><suffix>.csv``
    and rendered as a Rich table on the shared console.

    Parameters
    ----------
    nb_samples_gfa : int
        Total number of unique samples in the GFA. Only reported in the log
        message; it is not used in the calculations.
    filter_min_len : int, optional
        Minimum segment length in bp (``query_length`` of the BAM record) for a
        segment to be counted in the ``Filtered_*`` columns. Defaults to 1.
    """
    self.logger.info(
        f"Calculating node depth statistics for GFA with {nb_samples_gfa} total samples. "
        f"Length filter for 'Filtered_Counts': >= {filter_min_len}bp."
    )
    # Segment depth (number of unique samples) -> number of segments with that depth.
    depth_counts_raw: Dict[int, int] = defaultdict(int)
    # Same mapping, restricted to segments passing the length filter.
    depth_counts_filtered: Dict[int, int] = defaultdict(int)

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file is empty or unreadable. Cannot perform node depth analysis.")
        return

    depth_stat_task_id = self.progress.add_task(
        "Node Depth Statistics Analysis", total=total_segments_in_bam
    )
    segments_missing_sw_tag = 0
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(depth_stat_task_id)
            if not segment.has_tag("SW"):
                segments_missing_sw_tag += 1
                continue  # No sample information: depth cannot be computed.
            sw_tag_value = segment.get_tag("SW")
            unique_samples = {s.split(";")[0] for s in sw_tag_value.split(",") if s}
            depth = len(unique_samples)
            depth_counts_raw[depth] += 1
            if segment.query_length >= filter_min_len:
                depth_counts_filtered[depth] += 1

    if segments_missing_sw_tag > 0:
        self.logger.warning(f"{segments_missing_sw_tag} segments were skipped due to missing SW tag.")

    # Column names depend on the length filter; build them once.
    filtered_count_col = f"Filtered_Count_Len>={filter_min_len}"
    filtered_pct_col = f"Filtered_Percentage_Len>={filter_min_len}"

    df_raw = DataFrame(list(depth_counts_raw.items()), columns=["Segment_Depth", "Raw_Count"])
    df_filtered = DataFrame(
        list(depth_counts_filtered.items()), columns=["Segment_Depth", filtered_count_col]
    )
    # Depths with no segment passing the length filter get NaN from the outer
    # merge; fillna(0) replaces them with 0 but leaves the column as float.
    df_stats = merge(df_raw, df_filtered, on="Segment_Depth", how="outer").fillna(0)

    # Convert the totals to plain ints: the filtered column may be float after
    # fillna(), and a float total would be printed as e.g. "1,234.0" below.
    total_raw_segments_counted = int(df_stats["Raw_Count"].sum())
    total_filtered_segments_counted = int(df_stats[filtered_count_col].sum())

    if total_raw_segments_counted > 0:
        df_stats["Raw_Percentage"] = (df_stats["Raw_Count"] / total_raw_segments_counted) * 100
    else:
        df_stats["Raw_Percentage"] = 0.0
    if total_filtered_segments_counted > 0:
        df_stats[filtered_pct_col] = (
            df_stats[filtered_count_col] / total_filtered_segments_counted
        ) * 100
    else:
        df_stats[filtered_pct_col] = 0.0

    df_stats_sorted = df_stats.sort_values(by="Segment_Depth").reset_index(drop=True)
    # Integer counts for a cleaner CSV / display.
    df_stats_sorted["Raw_Count"] = df_stats_sorted["Raw_Count"].astype(int)
    df_stats_sorted[filtered_count_col] = df_stats_sorted[filtered_count_col].astype(int)

    csv_filename = f"{self.gfa_name}_node_depth_stats_len{filter_min_len}{self.suffix}.csv"
    output_csv_path = self.works_path / csv_filename
    try:
        df_stats_sorted.to_csv(output_csv_path, index=False, float_format="%.2f")
        self.logger.info(f"Node depth statistics saved to: {output_csv_path}")
    except IOError as e:
        self.logger.error(f"Failed to save node depth statistics CSV to '{output_csv_path}': {e}")

    if total_raw_segments_counted > 0:
        filtered_pct = (total_filtered_segments_counted / total_raw_segments_counted) * 100
    else:
        filtered_pct = 0.0

    try:
        table = Table(
            title=f"[bold magenta]Node Depth Statistics[/bold magenta] — [cyan]{self.gfa_name}[/cyan] (Len ≥ {filter_min_len}bp)",
            box=box.ROUNDED,
        )
        table.add_column("Depth", style="bold cyan", justify="center")
        table.add_column("Segments", style="green", justify="center")
        table.add_column("Percentage", style="green", justify="center")
        table.add_column("Filtered Segments", style="yellow", justify="center")
        table.add_column("Filtered %", style="yellow", justify="center")

        # Summary above the table.
        shared_console.rule("[green]Summary[/green]")
        shared_console.print(
            f"[cyan]Total segments analyzed:[/cyan] {total_raw_segments_counted:,}"
        )
        shared_console.print(
            f"[cyan]Total segments passing length filter:[/cyan] {total_filtered_segments_counted:,} "
            f"({filtered_pct:.2f}%)\n"
        )

        # iterrows() upcasts each row to a common (float) dtype, hence int().
        for _, row in df_stats_sorted.iterrows():
            table.add_row(
                f"{int(row['Segment_Depth'])}",
                f"{int(row['Raw_Count']):,}",
                f"{row['Raw_Percentage']:.2f}%",
                f"{int(row[filtered_count_col]):,}",
                f"{row[filtered_pct_col]:.2f}%",
            )
        shared_console.print(table)
    except Exception as e:
        # Display is best-effort: the CSV has already been written above.
        self.logger.error(f"Failed to display node depth statistics: {e}", exc_info=True)
def get_specific_and_shared_segments(
    self,
    samples_list_A: List[str],
    samples_list_B: Optional[List[str]] = None,
    filter_min_len: Optional[int] = None,
    output_csv: Optional[bool] = None,
) -> Tuple[Set[str], Set[str]]:
    """
    Identify segments shared by all samples of one group and, optionally,
    specific to that group relative to a second group.

    Each BAM record is treated as one segment. The samples it belongs to are
    the unique names before the first ``;`` of each comma-separated entry of
    its ``SW`` tag. Records without an ``SW`` tag are skipped and reported in
    a warning. A summary panel with counts, percentages and total lengths is
    printed to the shared console.

    Parameters
    ----------
    samples_list_A : List[str]
        Sample names. A segment is "shared" if ALL of these samples appear in
        its ``SW`` tag. Must not be empty.
    samples_list_B : Optional[List[str]], optional
        Second list of sample names. When non-empty, a shared segment is also
        "specific" if NONE of these samples appear in its ``SW`` tag.
    filter_min_len : Optional[int], optional
        If set, segments whose ``query_length`` is below this value are
        ignored entirely (they are not part of any count or total).
    output_csv : Optional[bool], optional
        If truthy, the IDs (``query_name``) of shared and specific segments
        are collected and returned; otherwise the returned sets are empty.

    Returns
    -------
    Tuple[Set[str], Set[str]]
        - IDs of segments shared by all samples in ``samples_list_A``.
        - IDs of segments specific to ``samples_list_A`` relative to
          ``samples_list_B`` (a subset of the first set).
        Both sets are empty on invalid input, an empty BAM, or when
        ``output_csv`` is falsy.

    Notes
    -----
    Percentages are relative to the segments actually analysed, i.e. those
    carrying an ``SW`` tag and passing the length filter.
    """
    if not samples_list_A:
        self.logger.error("samples_list_A cannot be empty for specific/shared segment analysis.")
        return set(), set()

    set_A = set(samples_list_A)
    set_B = set(samples_list_B) if samples_list_B else set()
    segment_list_shared: Set[str] = set()
    segment_list_specific: Set[str] = set()

    self.logger.info(f"Analyzing segments: Shared by ALL in {samples_list_A}.")
    if set_B:
        self.logger.info(
            f"Also finding segments specific to list A (and absent from ALL in {samples_list_B})."
        )
    if filter_min_len is not None:
        self.logger.info(f"Applying length filter: >= {filter_min_len}bp.")

    # Keys: shared_A_count, shared_A_length, specific_to_A_count, specific_to_A_length.
    counts: Counter = Counter()

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.error("BAM file empty/unreadable. Cannot perform specific/shared analysis.")
        return set(), set()

    spec_shared_task_id = self.progress.add_task(
        "Specific/Shared Segment Analysis", total=total_segments_in_bam
    )
    segments_missing_sw_tag = 0
    # Segments with an SW tag that pass the length filter: the denominator for
    # the percentages (length-filtered segments were never analysed).
    segments_analyzed = 0
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(spec_shared_task_id)
            if not segment.has_tag("SW"):
                segments_missing_sw_tag += 1
                continue
            seg_length = segment.query_length
            if filter_min_len is not None and seg_length < filter_min_len:
                continue
            segments_analyzed += 1

            sw_tag_value = segment.get_tag("SW")
            samples_in_segment = {s.split(";")[0] for s in sw_tag_value.split(",") if s}
            if not set_A.issubset(samples_in_segment):
                continue

            counts["shared_A_count"] += 1
            counts["shared_A_length"] += seg_length
            if output_csv:
                segment_list_shared.add(segment.query_name)

            # Specific = shared by all of A and absent from every sample of B.
            if set_B and set_B.isdisjoint(samples_in_segment):
                counts["specific_to_A_count"] += 1
                counts["specific_to_A_length"] += seg_length
                if output_csv:
                    segment_list_specific.add(segment.query_name)

    if segments_missing_sw_tag > 0:
        self.logger.warning(f"{segments_missing_sw_tag} segments were skipped due to missing SW tag.")

    def _percent(count: int) -> float:
        # Guard the division without faking the displayed denominator.
        return (count / segments_analyzed) * 100 if segments_analyzed else 0.0

    # Borderless, headerless table used as an aligned key/value list.
    results_table = Table(box=None, show_header=False, expand=True)
    results_table.add_column("Metric", style="cyan", no_wrap=True)
    results_table.add_column("Value", justify="right")

    results_table.add_row(f"[bold]Segments shared by {len(samples_list_A)} in list {samples_list_A}[/bold]", "")
    results_table.add_row(" ├─ Count", f"{counts['shared_A_count']:,} / {segments_analyzed:,}")
    results_table.add_row(" ├─ Percentage", f"[bold green]{_percent(counts['shared_A_count']):.2f}%[/bold green]")
    results_table.add_row(" └─ Total Length", f"{counts['shared_A_length']:,} bp")

    if set_B:
        results_table.add_row()  # visual spacing
        results_table.add_row(
            f"[bold]Segments specific to {len(samples_list_A)} in list {samples_list_A}[/bold] "
            f"and absent in {len(samples_list_B)} in list {samples_list_B}",
            "",
        )
        results_table.add_row(" ├─ Count", f"{counts['specific_to_A_count']:,} / {segments_analyzed:,}")
        results_table.add_row(
            " ├─ Percentage",
            f"[bold yellow]{_percent(counts['specific_to_A_count']):.2f}%[/bold yellow]",
        )
        results_table.add_row(" └─ Total Length", f"{counts['specific_to_A_length']:,} bp")

    summary_panel = Panel(
        results_table,
        title="[bold blue]📊 Shared & Specific Segment Analysis[/bold blue]",
        border_style="blue",
        expand=False,  # fit the panel to its content
    )
    shared_console.print(summary_panel)

    return segment_list_shared, segment_list_specific
def get_segments_and_positions_by_depth(
    self,
    total_gfa_samples: int,
    input_as_number: bool,
    lower_bound_depth: int,
    upper_bound_depth: int,
    filter_min_len: int,
    bed_path: Optional[Path] = None,
) -> Tuple[Dict[str, int], Dict[str, Dict[str, List[Tuple[str, int, int]]]]]:
    """
    Find segments within a depth range and retrieve their positions from BED files.

    Two steps:

    1. Scan the BAM file for segments carrying an ``SW`` tag, with
       ``query_length >= filter_min_len`` and a depth (number of unique
       samples in the ``SW`` tag) inside the requested range.
    2. Query the BED files of the samples involved in those segments for the
       segments' genomic coordinates (via ``_find_positions_for_segments``).

    Parameters
    ----------
    total_gfa_samples : int
        Total number of unique samples in the GFA, used to convert percentage
        bounds into sample counts.
    input_as_number : bool
        If True, the bounds are absolute sample counts. If False, they are
        percentages of ``total_gfa_samples``, converted with Python's
        ``round()`` (round-half-to-even).
    lower_bound_depth : int
        Minimum depth (count or percentage), inclusive.
    upper_bound_depth : int
        Maximum depth (count or percentage), inclusive.
    filter_min_len : int
        Minimum segment length in bp.
    bed_path : Optional[Path], optional
        Directory containing the ``<sample>.bed`` files. When given, it is
        stored in ``self.bed_path`` (persisting after the call); otherwise
        the existing ``self.bed_path`` is used.

    Returns
    -------
    Tuple[Dict[str, int], Dict[str, Dict[str, List[Tuple[str, int, int]]]]]
        1. ``{segment_id: depth}`` for all segments matching the criteria.
        2. ``{segment_id: {sample_name: [(chrom, start, end), ...]}}``.
        Both are empty if the bounds are inverted, the BAM is empty, or no
        segment matches.
    """
    # --- PART 1: find matching segments in the BAM file ---
    if bed_path:
        self.bed_path = bed_path

    if input_as_number:
        abs_lower_bound = lower_bound_depth
        abs_upper_bound = upper_bound_depth
        self.logger.info(f"Depth range (absolute): [{abs_lower_bound}, {abs_upper_bound}] samples.")
    else:
        abs_lower_bound = round((lower_bound_depth * total_gfa_samples) / 100)
        abs_upper_bound = round((upper_bound_depth * total_gfa_samples) / 100)
        self.logger.info(
            f"Depth range (percentage): {lower_bound_depth}%-{upper_bound_depth}% -> "
            f"Effective count range: [{abs_lower_bound}, {abs_upper_bound}]."
        )

    if abs_lower_bound > abs_upper_bound:
        self.logger.error(f"Lower bound ({abs_lower_bound}) > upper bound ({abs_upper_bound}). Aborting.")
        return {}, {}

    self.logger.info(f"Applying minimum length filter: >= {filter_min_len} bp.")

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file empty/unreadable. Cannot search segments by depth.")
        return {}, {}

    segments_with_depth: Dict[str, int] = {}
    # Only the samples carrying at least one matching segment need their BED scanned.
    samples_to_scan: Set[str] = set()

    bam_task_id = self.progress.add_task(
        f"Scanning BAM for segments in depth range [{abs_lower_bound}-{abs_upper_bound}]",
        total=total_segments_in_bam,
    )
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(bam_task_id)
            if segment.query_length < filter_min_len or not segment.has_tag("SW"):
                continue
            unique_samples = {s.split(";")[0] for s in segment.get_tag("SW").split(",") if s}
            depth = len(unique_samples)
            if abs_lower_bound <= depth <= abs_upper_bound:
                segments_with_depth[segment.query_name] = depth
                samples_to_scan.update(unique_samples)

    self.logger.info(f"Found {len(segments_with_depth)} segments meeting criteria. Now finding their positions.")

    # --- PART 2: find positions for the matching segments ---
    if not segments_with_depth:
        return {}, {}

    segment_locations = self._find_positions_for_segments(
        segment_ids=set(segments_with_depth),
        sample_list=list(samples_to_scan),
    )
    return segments_with_depth, segment_locations
def _find_positions_for_segments(
    self, segment_ids: Set[str], sample_list: List[str]
) -> Dict[str, Dict[str, List[Tuple[str, int, int]]]]:
    """
    Find all genomic positions of a set of segments across several samples.

    The segment IDs are written to a temporary file. For each sample,
    ``awk`` streams ``<self.bed_path>/<sample>.bed`` and keeps only lines
    whose 4th column, after removing one leading ``+``/``-``, is one of those
    IDs. Samples are processed in a thread pool of ``self.threads`` workers.
    Missing BED files are skipped silently. awk failures are logged and the
    sample contributes whatever lines were parsed before the failure.

    Parameters
    ----------
    segment_ids : Set[str]
        Segment identifiers (without orientation sign) to look up.
    sample_list : List[str]
        Sample names whose BED files should be queried.

    Returns
    -------
    Dict[str, Dict[str, List[Tuple[str, int, int]]]]
        ``{segment_id: {sample_name: [(chrom, start, end), ...]}}``, built
        from BED columns 1-3.
    """
    locations: Dict[str, Dict[str, List[Tuple[str, int, int]]]] = defaultdict(lambda: defaultdict(list))

    # First file (NR==FNR): load IDs into an array. Second file: print lines
    # whose 4th field, minus one leading strand sign, is a known ID.
    awk_script = 'NR==FNR{ids[$1]; next} {id=$4; sub(/^[+-]/,"",id); if(id in ids) print $0}'

    # The temporary ID file must outlive every awk process, so the whole
    # parallel section runs inside this context manager.
    with tempfile.NamedTemporaryFile(mode="w+", delete=True, suffix="_segment_ids.txt") as tmp_file:
        tmp_file.write("\n".join(segment_ids))
        tmp_file.flush()
        segment_id_filepath = tmp_file.name

        def _process_sample_bed(sample_name: str) -> Dict[str, Dict[str, List[Tuple[str, int, int]]]]:
            # Query one sample's BED file; returns the same nested mapping shape.
            sample_bed_path = self.bed_path / f"{sample_name}.bed"
            if not sample_bed_path.exists():
                return {}

            cmd = ["awk", awk_script, segment_id_filepath, str(sample_bed_path)]
            sample_results: Dict[str, Dict[str, List[Tuple[str, int, int]]]] = defaultdict(lambda: defaultdict(list))
            try:
                with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as process:
                    for line in process.stdout:
                        parts = line.strip().split("\t")
                        if len(parts) < 4:
                            continue
                        # Remove exactly one leading sign, mirroring the awk
                        # filter, so the key is the ID awk actually matched.
                        seg_id = parts[3]
                        if seg_id.startswith(("+", "-")):
                            seg_id = seg_id[1:]
                        pos_tuple = (parts[0], int(parts[1]), int(parts[2]))
                        sample_results[seg_id][sample_name].append(pos_tuple)
                    # returncode stays None until the process is waited on.
                    returncode = process.wait()
                if returncode != 0:
                    self.logger.error(f"Awk failed for '{sample_name}' (code {returncode}).")
            except Exception as e:
                # Best effort per sample: one broken BED must not abort the others.
                self.logger.error(f"Awk execution failed for '{sample_name}': {e}", extra={"markup": False})
            return sample_results

        with ThreadPoolExecutor(max_workers=self.threads) as executor, self.progress:
            bed_task_id = self.progress.add_task("Querying BED files for positions", total=len(sample_list))
            futures = [executor.submit(_process_sample_bed, sample) for sample in sample_list]
            for future in as_completed(futures):
                for seg_id, sample_map in future.result().items():
                    for sample_name, pos_list in sample_map.items():
                        locations[seg_id][sample_name].extend(pos_list)
                self.progress.advance(bed_task_id)

    return locations
def export_nodes_to_csv(
    self, output_csv_path: Path, long_node_length_threshold: int = 1000
) -> None:
    """
    Write one CSV row per BAM record (segment/node).

    Columns:

    - ``Node_Name``: the record's ``query_name``.
    - ``Length_bp``: its ``query_length``.
    - ``Sample_Info_SW_Tag``: the raw ``SW`` tag value, or ``"N/A"`` when the
      tag is absent.
    - ``Inferred_Direction``: ``"+"``/``"-"`` when the name ends with that
      character, else ``"N/A"``. This is a naming heuristic only; GFA S lines
      carry no orientation.
    - ``Is_Long_Node``: ``"Yes"`` when the length is strictly greater than
      ``long_node_length_threshold``, else ``"No"``.

    Parameters
    ----------
    output_csv_path : Path
        Destination CSV path.
    long_node_length_threshold : int, optional
        Length (bp) above which a node is flagged as long. Defaults to 1000.
    """
    self.logger.info(f"Exporting node information to CSV: {output_csv_path}")

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file empty/unreadable. Cannot export nodes.")
        return

    export_task = self.progress.add_task("Exporting Nodes to CSV", total=total_segments_in_bam)
    rows: List[Dict[str, Any]] = []
    without_sw = 0

    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam, self.progress:
        for record in bam.fetch(until_eof=True):
            self.progress.advance(export_task)
            name = record.query_name
            length = record.query_length

            if record.has_tag("SW"):
                sample_info = record.get_tag("SW")  # raw "sample;chr;hap,..." value
            else:
                without_sw += 1
                sample_info = "N/A"

            # Heuristic: a trailing sign on the name is read as the orientation.
            direction = name[-1] if name.endswith(("+", "-")) else "N/A"

            rows.append(
                {
                    "Node_Name": name,
                    "Length_bp": length,
                    "Sample_Info_SW_Tag": sample_info,
                    "Inferred_Direction": direction,
                    "Is_Long_Node": "Yes" if length > long_node_length_threshold else "No",
                }
            )

    if without_sw > 0:
        self.logger.warning(f"{without_sw} segments were processed without an SW tag.")

    if not rows:
        self.logger.info("No node data to export.")
        return

    nodes_frame = DataFrame(rows)
    try:
        nodes_frame.to_csv(output_csv_path, index=False)
        self.logger.info(f"Node information successfully exported to '{output_csv_path}'.")
        print(f"Node information successfully exported to {output_csv_path}")
    except IOError as err:
        self.logger.error(f"Failed to write nodes CSV to '{output_csv_path}': {err}")