# Source code for gratools.Graph

# Standard library imports
import asyncio
import gzip
import logging
import sqlite3
import subprocess
import tempfile
import re
from collections import Counter, OrderedDict, defaultdict, namedtuple
from concurrent.futures import ThreadPoolExecutor, wait, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Set, TextIO
import heapq
from statistics import mean, median
from itertools import chain
from copy import copy

# Third-party imports
import aiofiles
import aiosqlite
from numpy import histogram
import numpy as np
from pandas import set_option, DataFrame, merge
import uvloop
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pybedtools import BedTool, cleanup
from pysam import AlignedSegment, AlignmentFile, index, qualitystring_to_array, view
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    Progress,
    TaskID,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.table import Table
from rich import box
from natsort import natsorted

# Local application/library specific imports
from .logger_config import shared_console  # Assuming shared_console is used for rich printing
from .useful_function import RC_TRANS, reverse_complement_string
# Event loop: install uvloop's policy so that asyncio event loops created
# through the default policy after this import are uvloop loops.
# This is an import-time, process-wide side effect.
asyncio.set_event_loop_policy(policy=uvloop.EventLoopPolicy())

# pandas display: render floats with 2 digits of precision (process-wide option).
set_option("display.precision", 2)

LinkInfo = namedtuple(
    "LinkInfo",
    (
        "seg_id_1",
        "orient_seg_1",
        "orient_key_seg_1",
        "seg_id_2",
        "orient_seg_2",
        "orient_key_seg_2",
    ),
)
# Attach the documentation to the class itself. A bare string placed after an
# assignment is only a no-op expression, so ``LinkInfo.__doc__`` would otherwise
# remain namedtuple's auto-generated signature and help()/introspection would
# never show these field descriptions.
LinkInfo.__doc__ = """
Information about a link between two segments.

Attributes
----------
seg_id_1 : str
    Identifier of the first segment.
orient_seg_1 : int
    Orientation of the first segment (+1 or -1).
orient_key_seg_1 : int
    Orientation key for the first segment (often the same as orient_seg_1).
seg_id_2 : str
    Identifier of the second segment.
orient_seg_2 : int
    Orientation of the second segment (+1 or -1).
orient_key_seg_2 : int
    Orientation key for the second segment (often the same as orient_seg_2).
"""


@dataclass
class SubGraph:
    """
    A subgraph of one sample in a genomic analysis pipeline.

    It handles BED-region intersection, GFA walk (W line), link (L line) and
    segment (S line) building, and FASTA sequence generation for a specific
    sample and set of regions.

    Attributes
    ----------
    bam_path : Path
        Path to the BAM file from which segment records/sequences are built.
    bed_path : Path
        Path to the sample BED file (``<sample>.bed`` or ``<sample>.bed.gz``).
    logger : logging.Logger
        Logger instance. Defaults to the logger named "GraTools".
    sample_name : Optional[str]
        Sample name derived from ``bed_path`` in ``__post_init__``. Defaults to None.
    sample_name_query : Optional[str]
        Name of the sample being queried. Defaults to None.
    chromosome_query : Optional[str]
        Chromosome name of the query. Defaults to None.
    start_query : Optional[int]
        Query start position on the chromosome. Defaults to None.
    stop_query : Optional[int]
        Query stop position on the chromosome. Defaults to None.
    offset_first : int
        Number of bases trimmed from the query side of the first query segment
        (recomputed per walk in `build_fasta`). Defaults to 0.
    offset_last : int
        Number of bases trimmed from the query side of the last query segment
        (recomputed per walk in `build_fasta`). Defaults to 0.
    add_start_bases_first_segment : int
        Not read by any method of this class (TODO: confirm usage elsewhere). Defaults to 0.
    intersect_bed : Optional[BedTool]
        Not read by any method of this class (TODO: confirm usage elsewhere). Defaults to None.
    segment_id_set : Set[str]
        Segment IDs without orientation. Input of `get_chr_pos`/`filter_bed_with_awk`,
        rebuilt by `build_walks` and consumed by `build_segments`.
    segment_id_first_query : Optional[str]
        ID of the first segment seen for the query sample. Defaults to None.
    segment_id_first_strand : Optional[str]
        Orientation ('+' or '-') of ``segment_id_first_query``. Defaults to None.
    segment_id_last_query : Optional[str]
        ID of the last segment seen for the query sample. Defaults to None.
    segment_id_last_strand : Optional[str]
        Orientation ('+' or '-') of ``segment_id_last_query``. Defaults to None.
    segment_id_first : Optional[str]
        Mirrors ``segment_id_first_query`` when set by `build_walks`. Defaults to None.
    segment_id_last : Optional[str]
        Mirrors ``segment_id_last_query`` when set by `build_walks`. Defaults to None.
    works_path : Optional[Path]
        Working directory for temporary and output files; required by `get_chr_pos`.
        Defaults to None.
    merge : Optional[int]
        Distance passed as ``-d`` to the second ``bedtools merge``. Defaults to None.
    build_fasta_flag : Optional[bool]
        Whether FASTA sequences should be built. Defaults to False.
    gfa_walk_list : List[str]
        GFA walk lines (W lines).
    gfa_link_list : List[str]
        GFA link lines (L lines).
    gfa_segment_list : List[str]
        GFA segment lines (S lines).
    dict_segments_samples : defaultdict[str, List[str]]
        Segment ID -> list of sample identifiers (filled by `build_segments`).
    dict_segments_sequence : defaultdict[str, str]
        Segment ID -> sequence (filled by `build_segments`).
    sequences_list : List[SeqRecord]
        FASTA records produced by `build_fasta`.
    progress_dict : Optional[Dict]
        Dictionary receiving progress updates keyed by ``task_id``. Defaults to None.
    task_id : Optional[TaskID]
        Key of this sample's entry in ``progress_dict``. Defaults to None.
    regions : Optional[List[Dict[str, Any]]]
        Regions as dicts with 'chromosome', 'start' and 'stop' keys. Defaults to None.
    intersected_results_by_regions : Optional[BedTool]
        Result of the BED/regions intersection (set by `compute_intersection`). Defaults to None.
    """

    bam_path: Path
    bed_path: Path
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    sample_name: Optional[str] = None
    sample_name_query: Optional[str] = None
    chromosome_query: Optional[str] = None
    start_query: Optional[int] = None
    stop_query: Optional[int] = None
    offset_first: int = 0
    offset_last: int = 0
    add_start_bases_first_segment: int = 0
    intersect_bed: Optional[BedTool] = None
    segment_id_set: Set[str] = field(default_factory=set, repr=False)
    segment_id_first_query: Optional[str] = None
    segment_id_first_strand: Optional[str] = None
    segment_id_last_query: Optional[str] = None
    segment_id_last_strand: Optional[str] = None
    segment_id_first: Optional[str] = None  # First segment ID tracked for this instance
    segment_id_last: Optional[str] = None  # Last segment ID tracked for this instance
    works_path: Optional[Path] = None
    merge: Optional[int] = None
    build_fasta_flag: Optional[bool] = False
    gfa_walk_list: List[str] = field(default_factory=list, repr=False)
    gfa_link_list: List[str] = field(default_factory=list, repr=False)
    gfa_segment_list: List[str] = field(default_factory=list, repr=False)
    dict_segments_samples: defaultdict = field(
        default_factory=lambda: defaultdict(list), repr=False
    )
    dict_segments_sequence: defaultdict = field(
        default_factory=lambda: defaultdict(str), repr=False
    )
    sequences_list: List[SeqRecord] = field(default_factory=list, repr=False)
    progress_dict: Optional[Dict] = None
    task_id: Optional[TaskID] = None
    regions: Optional[List[Dict[str, Any]]] = None
    intersected_results_by_regions: Optional[BedTool] = None

    def __post_init__(self) -> None:
        """Derive ``sample_name`` from ``bed_path`` (``<sample>.bed.gz`` or ``<sample>.bed``)."""
        file_name = self.bed_path.name
        # Check the longer suffix first so "x.bed.gz" yields "x", not "x.bed.gz".
        # endswith() (rather than a substring test on the joined suffixes) rejects
        # names such as "x.bedgraph" or "x.bed.bak".
        for bed_suffix in (".bed.gz", ".bed"):
            if file_name.endswith(bed_suffix):
                self.sample_name = file_name.removesuffix(bed_suffix)
                return
        self.logger.error(f"The file {self.bed_path} is not a recognized BED file type.")

    def _report_progress(self, progress: int, description: str) -> None:
        """Publish a progress entry (out of 7 steps) for this task, if tracking is enabled."""
        # `is not None` rather than truthiness: an empty shared dict is still a valid tracker.
        if self.progress_dict is not None and self.task_id is not None:
            self.progress_dict[self.task_id] = {
                "progress": progress,
                "total": 7,
                "description": description,
            }

    def compute_intersection(self) -> None:
        """
        Intersect the sample BED file with every region of ``self.regions``.

        When ``self.regions`` is empty, a single region built from
        ``chromosome_query``/``start_query``/``stop_query`` is used. All regions
        are intersected in one ``intersect -wa -wb`` call against a regions BED
        whose 4th column holds the region index; the result is saved to a
        temporary file and stored in ``self.intersected_results_by_regions``.
        On invalid regions or any failure an error is logged and the attribute
        is left as None (any previous result is discarded first).
        """
        # Never let a result from a previous call survive a failed one.
        self.intersected_results_by_regions = None

        if not self.regions:
            # Default region if none provided
            self.regions = [
                {
                    "chromosome": self.chromosome_query,
                    "start": self.start_query,
                    "stop": self.stop_query,
                }
            ]

        required_keys = ("chromosome", "start", "stop")
        if not isinstance(self.regions, list) or not all(
            isinstance(region, dict) and all(key in region for key in required_keys)
            for region in self.regions
        ):
            self.logger.error(
                "Invalid regions provided. Must be a list of dictionaries, "
                "each with 'chromosome', 'start', and 'stop' keys."
            )
            return

        try:
            all_bed = BedTool(self.bed_path.as_posix())
        except Exception as e:  # pybedtools can raise various exceptions
            self.logger.error(f"Failed to load BED file '{self.bed_path}': {e}")
            return

        # One BED line per region; the 4th column carries the region index so
        # each intersected row can be traced back to its region in build_walks.
        regions_text = "\n".join(
            f"{reg['chromosome']}\t{reg['start']}\t{reg['stop']}\t{idx}"
            for idx, reg in enumerate(self.regions)
        )

        try:
            regions_bed = BedTool(regions_text, from_string=True)
            # saveas() writes the result to a temporary file so it can be re-read.
            intersected = all_bed.intersect(regions_bed, wa=True, wb=True).saveas()
            self.intersected_results_by_regions = intersected
            if intersected.count() == 0:
                # An empty intersection is not fatal, but nothing can be built from it.
                self._report_progress(
                    7, f"[red]'{self.sample_name}': No intersection found for region, aborted."
                )
                self.logger.error(f"'{self.sample_name}': No intersection found for region {self.regions}")
        except Exception as e:
            self.logger.error(f"Error during single intersection: {e}")

    def build_walks(self) -> None:
        """
        Build GFA walks (W lines) and links (L lines) from the intersected BED rows.

        Rows of ``self.intersected_results_by_regions`` are grouped by
        (chromosome, haplotype, region index) and then by fragment ID; each
        group becomes one W line spanning its segments' min start / max stop.
        ``self.segment_id_set`` is rebuilt from the segments seen, and for the
        query sample the first/last segment IDs and strands are recorded.
        """
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA walks and links")
        self.segment_id_set = set()  # Reset for current build
        self._report_progress(5, f"[cyan]'{self.sample_name}': Building walks and links")

        if self.intersected_results_by_regions is None:
            self.logger.warning(f"'{self.sample_name}': No intersection results available, no walks built.")
            return

        # (chromosome, haplotype, region_idx) -> fragment_id -> walk data.
        # Tuple keys avoid re-splitting a joined string, which broke on
        # chromosome names containing ':'.
        walks: Dict[Tuple[str, str, int], Dict[str, Dict[str, Any]]] = defaultdict(
            lambda: defaultdict(
                lambda: {"seg_list": [], "start": float("inf"), "stop": float("-inf")}
            )
        )
        is_query_sample = self.sample_name == self.sample_name_query

        for interval in self.intersected_results_by_regions:
            # Columns: sample BED columns (chrom, start, stop, oriented segment ID,
            # fragment ID, haplotype, ...) followed by the regions BED columns,
            # whose last one is the region index.
            region_idx = int(interval[-1])
            chromosome_name = interval.chrom
            oriented_segment_id = interval[3]  # e.g. "+seg1" / "-seg2"
            chromosome_fragment_id = interval[4]
            haplotype_index = interval[5]
            orientation, segment_id = oriented_segment_id[0], oriented_segment_id[1:]

            self.segment_id_set.add(str(segment_id))

            if is_query_sample:
                if not self.segment_id_first_query:
                    self.segment_id_first_query = segment_id
                    self.segment_id_first_strand = orientation
                    self.segment_id_first = segment_id
                # Always overwritten so it ends up holding the last segment seen.
                self.segment_id_last_query = segment_id
                self.segment_id_last_strand = orientation
                self.segment_id_last = segment_id

            fragment = walks[(chromosome_name, haplotype_index, region_idx)][chromosome_fragment_id]
            fragment["seg_list"].append(oriented_segment_id)
            # Walk coordinates are the fragment's own span (not the region's first start).
            fragment["start"] = min(fragment["start"], interval.start)
            fragment["stop"] = max(fragment["stop"], interval.stop)

        for (chr_name, haplotype_idx, _region_idx), fragments in walks.items():
            for data in fragments.values():
                ordered_segments = data["seg_list"]
                self._build_links(ordered_segments)
                gfa_path_string = "".join(self._to_gfa_step(step) for step in ordered_segments)
                # TODO: save to file for concat after on gratools
                self.gfa_walk_list.append(
                    "\t".join(
                        (
                            "W",
                            self.sample_name,
                            haplotype_idx,
                            chr_name,
                            str(data["start"]),
                            str(data["stop"]),
                            gfa_path_string,
                        )
                    )
                )

    @staticmethod
    def _to_gfa_step(oriented_segment_id: str) -> str:
        """
        Convert "+seg"/"-seg" into the GFA walk step ">seg"/"<seg".

        Only the leading orientation character is translated, so segment IDs
        that themselves contain '+' or '-' are preserved unchanged.
        """
        orientation, segment_id = oriented_segment_id[:1], oriented_segment_id[1:]
        return {"+": ">", "-": "<"}.get(orientation, orientation) + segment_id

    def _build_links(self, segments_order: List[str]) -> None:
        """
        Append one GFA link line (L line) per consecutive pair of segments.

        Parameters
        ----------
        segments_order : List[str]
            Ordered segment IDs, each prefixed with its orientation
            (e.g. ["+seg1", "-seg2", "+seg3"]). Fewer than two entries yields no link.
        """
        for source, target in zip(segments_order, segments_order[1:]):
            # L <from> <from_orient> <to> <to_orient> <overlap>; overlaps are assumed to be 0M.
            self.gfa_link_list.append(f"L\t{source[1:]}\t{source[0]}\t{target[1:]}\t{target[0]}\t0M")

    def build_segments(self) -> None:
        """
        Build GFA segment lines (S lines) from the BAM file for ``self.segment_id_set``.

        Delegates to ``GratoolsBam.build_segments``, which returns the values
        stored in ``gfa_segment_list``, ``dict_segments_samples`` and
        ``dict_segments_sequence``. With no segment IDs, those attributes are
        reset to empty containers and a warning is logged.
        """
        self._report_progress(6, f"[cyan]'{self.sample_name}': Building segments")
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA segments")

        if not self.segment_id_set:
            self.logger.warning(f"'{self.sample_name}': No segment IDs for segment lines building. Skipping.")
            self.gfa_segment_list = []
            self.dict_segments_samples = defaultdict(list)
            self.dict_segments_sequence = defaultdict(str)
            return

        bam_handler = GratoolsBam(bam_path=self.bam_path, logger=self.logger)
        (
            self.gfa_segment_list,
            self.dict_segments_samples,
            self.dict_segments_sequence,
        ) = bam_handler.build_segments(list_segments=self.segment_id_set)

    def filter_bed_with_awk(self) -> Tuple[BedTool, Path, Path]:
        """
        Keep only the BED lines whose segment ID belongs to ``self.segment_id_set``.

        The IDs are written to a file in ``works_path`` and awk filters the BED
        file (decompressed with ``gzip -dc`` first when it ends in ".gz").
        awk is run with an argument list (no shell), so paths are never
        interpreted by a shell. On failure, both temporary files are removed.

        Returns
        -------
        Tuple[BedTool, Path, Path]
            The filtered BED as a BedTool, the path of the filtered BED file and
            the path of the segment-ID file (the caller removes both files).

        Raises
        ------
        ValueError
            If ``works_path`` is not set.
        subprocess.CalledProcessError
            If awk (or gzip) exits with a non-zero status.
        """
        if self.works_path is None:
            raise ValueError(f"'{self.sample_name}': works_path must be set before filtering the BED file.")

        self.logger.info(f"'{self.sample_name}': Starting awk filtering for {len(self.segment_id_set)} IDs.")
        segment_id_file = self.works_path / f"{self.sample_name}_segment_ids.txt"
        filtered_bed_path = self.works_path / f"{self.sample_name}_filtered.bed"

        # First file (NR==FNR): load the wanted IDs. BED file: keep lines whose
        # 4th column, minus its leading orientation character, is a wanted ID.
        awk_program = "NR==FNR {ids[$1]; next} substr($4, 2) in ids"
        try:
            with open(segment_id_file, "w") as id_out:
                id_out.writelines(f"{seg_id}\n" for seg_id in self.segment_id_set)
            with open(filtered_bed_path, "w") as out_f:
                if self.bed_path.name.endswith(".gz"):
                    # awk cannot read gzip data directly; feed it the decompressed stream.
                    self._run_awk_on_gzip(awk_program, segment_id_file, out_f)
                else:
                    subprocess.run(
                        ["awk", awk_program, str(segment_id_file), str(self.bed_path)],
                        check=True,
                        stdout=out_f,
                    )
        except BaseException:
            segment_id_file.unlink(missing_ok=True)
            filtered_bed_path.unlink(missing_ok=True)
            raise

        self.logger.info(f"'{self.sample_name}': Filtered file with awk written to: {filtered_bed_path}")
        return BedTool(str(filtered_bed_path)), filtered_bed_path, segment_id_file

    def _run_awk_on_gzip(self, awk_program: str, segment_id_file: Path, out_f: TextIO) -> None:
        """Run ``awk_program`` on the decompressed ``bed_path`` (``gzip -dc`` piped into awk's stdin)."""
        with subprocess.Popen(["gzip", "-dc", str(self.bed_path)], stdout=subprocess.PIPE) as decompress:
            # "-" makes awk read its second input from stdin.
            subprocess.run(
                ["awk", awk_program, str(segment_id_file), "-"],
                stdin=decompress.stdout,
                stdout=out_f,
                check=True,
            )
        if decompress.returncode != 0:
            raise subprocess.CalledProcessError(decompress.returncode, decompress.args)

    def get_chr_pos(self, progress_dict: Optional[Dict], task_id: Optional[TaskID]) -> None:
        """
        Locate the segments of ``self.segment_id_set`` in this sample and build its subgraph.

        Steps:
        1. Filter the sample BED to the wanted segments (`filter_bed_with_awk`).
        2. Merge the hits into regions: abutting features first, then features
           within ``self.merge`` bp. Both merges are saved as CSV in ``works_path``.
        3. Intersect, then build walks/links and segments.

        Processing stops early, with a red progress entry, when no segment or no
        region is found. Temporary files are always removed.

        Parameters
        ----------
        progress_dict : Optional[Dict]
            Dictionary receiving progress updates (e.g. shared across processes); may be None.
        task_id : Optional[TaskID]
            Key of this sample's entry in ``progress_dict``.

        Raises
        ------
        ValueError
            If ``works_path`` is not set (raised by `filter_bed_with_awk`).
        """
        self.progress_dict = progress_dict
        self.task_id = task_id
        self._report_progress(2, f"[cyan]'{self.sample_name}': Searching corresponding chr pos")

        filtered_bed_all, filtered_bed_path, segment_id_file = self.filter_bed_with_awk()
        try:
            # An empty output file means awk matched none of the wanted segment IDs.
            if filtered_bed_path.stat().st_size == 0:
                self._report_progress(7, f"[red]'{self.sample_name}': No segment in BED.")
                self.logger.warning(
                    f"'{self.sample_name}': No matching segment found in BED file '{self.bed_path}' "
                    f"for segment IDs: {list(self.segment_id_set)[:5]}..."  # Log a few example IDs
                )
                return

            self._report_progress(3, f"'{self.sample_name}': BED filtering successful, proceeding to merge..")

            # bedtools merge keeping column 4 (collapsed) and columns 5-6 (distinct values).
            merge_kwargs = {"c": [4, 5, 6], "o": ["collapse", "distinct", "distinct"]}
            file_output_all = self.works_path / f"{self.bam_path.stem}_{self.sample_name}_regions_all.csv"
            file_output_collapse = (
                self.works_path / f"{self.bam_path.stem}_{self.sample_name}_regions_collapsed_d{self.merge}.csv"
            )
            merged_all_bed = filtered_bed_all.merge(**merge_kwargs)  # no -d: only abutting features
            merged_all_bed.saveas(str(file_output_all))
            merged_collapse_bed = merged_all_bed.merge(d=self.merge, **merge_kwargs)
            merged_collapse_bed.saveas(str(file_output_collapse))

            self.logger.info(f"'{self.sample_name}': compute regions base on merge dataframe")
            self.regions = [
                {"chromosome": feature.chrom, "start": feature.start, "stop": feature.stop}
                for feature in merged_collapse_bed
            ]
            if not self.regions:
                # Stop here: compute_intersection would otherwise fall back to the query region.
                self.logger.warning(f"'{self.sample_name}': No regions found for processing.")
                self._report_progress(7, f"[red]'{self.sample_name}':No regions found for processing")
                return

            self.logger.debug(
                f"'{self.sample_name}': Found {len(self.regions)} regions for processing. "
                f"Example: {self.regions[0] if self.regions else 'None'}."
            )
            self._report_progress(4, f"[cyan]'{self.sample_name}': Found corresponding chr pos")
        finally:
            segment_id_file.unlink(missing_ok=True)
            filtered_bed_path.unlink(missing_ok=True)

        self.compute_intersection()
        self.build_walks()  # rebuilds self.segment_id_set
        self.build_segments()  # uses self.segment_id_set
        self.intersected_results_by_regions = None  # release the temporary intersection
        self._report_progress(7, f"[green]'{self.sample_name}': End processing")

    def build_fasta(self) -> None:
        """
        Build FASTA records from the GFA walks in ``self.gfa_walk_list``.

        Each walk's segments are concatenated, reverse-complemented when
        traversed in '-' orientation. For the query sample
        (``sample_name == sample_name_query``), the first/last query segments
        are trimmed by ``offset_first``/``offset_last`` so the sequence matches
        ``start_query``/``stop_query``. Both trims are applied when the walk
        consists of a single segment. The records are appended to
        ``self.sequences_list``.
        """
        num_walks = len(self.gfa_walk_list)
        self.logger.info(f"'{self.sample_name}': Building FASTA for {num_walks} walks.")

        if not self.dict_segments_sequence:
            self.logger.warning(f"'{self.sample_name}': Segment sequences dictionary is empty. Cannot build FASTA.")
            return
        if not self.gfa_walk_list:
            self.logger.info(f"'{self.sample_name}': No GFA walks to process for FASTA generation.")
            return

        is_query_sample = self.sample_name == self.sample_name_query
        newly_created_records: List[SeqRecord] = []

        for walk_line in self.gfa_walk_list:
            try:
                (
                    _w_char,
                    current_walk_sample_name,
                    haplotype_index,
                    chromosome_name,
                    walk_genomic_start_str,
                    walk_genomic_stop_str,
                    gfa_path_specification,
                ) = walk_line.split("\t")
            except ValueError as e:
                self.logger.error(
                    f"'{self.sample_name}': Malformed walk line during FASTA build: '{walk_line}'. Error: {e}",
                    exc_info=True,
                )
                continue

            walk_genomic_start = int(walk_genomic_start_str)
            walk_genomic_stop = int(walk_genomic_stop_str)
            # Coordinates reported in the record ID; adjusted when bases are trimmed.
            effective_walk_chr_start = walk_genomic_start
            effective_walk_chr_stop = walk_genomic_stop

            if is_query_sample:
                if self.start_query is not None and self.stop_query is not None:
                    # Bases of the walk lying before/after the query region.
                    self.offset_first = self.start_query - walk_genomic_start
                    self.offset_last = walk_genomic_stop - self.stop_query
                else:
                    self.logger.warning(
                        f"'{self.sample_name}': Query sample set, but start/stop_query not fully set. "
                        f"Offsets may be incorrect."
                    )

            oriented_segments_in_path = self._parse_gfa_path(gfa_path_specification)
            if not oriented_segments_in_path:
                self.logger.warning(
                    f"'{self.sample_name}': Could not parse path specification '{gfa_path_specification}' "
                    f"in walk line: '{walk_line}'. Skipping this walk for FASTA building."
                )
                continue

            num_segments_in_walk = len(oriented_segments_in_path)
            last_segment_idx = num_segments_in_walk - 1
            current_walk_seq_fragments: List[str] = []
            self.logger.debug(
                f"'{self.sample_name}': Processing walk on {chromosome_name} ({walk_genomic_start}-{walk_genomic_stop}) "
                f"with {num_segments_in_walk} segments for FASTA."
            )

            for segment_idx, (segment_orient_char, segment_id) in enumerate(oriented_segments_in_path):
                if segment_id not in self.dict_segments_sequence:
                    self.logger.warning(
                        f"'{self.sample_name}': Sequence for segment '{segment_id}' in walk not found. "
                        f"Skipping this segment."
                    )
                    continue

                current_segment_seq = self.dict_segments_sequence[segment_id]
                if segment_orient_char == "-":
                    current_segment_seq = reverse_complement_string(current_segment_seq)

                if is_query_sample:
                    # The idx==0 and idx==last checks are independent `if`s (not if/elif) so
                    # that a single-segment walk gets both its start and its end trimmed.
                    if segment_id == self.segment_id_first_query:
                        # First query segment at the walk start, same strand: trim its beginning.
                        if segment_idx == 0 and segment_orient_char == self.segment_id_first_strand:
                            if self.offset_first > 0:
                                if self.offset_first < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[self.offset_first:]
                                    effective_walk_chr_start = walk_genomic_start + self.offset_first
                                else:
                                    self.logger.debug(
                                        f"Offset_first ({self.offset_first}bp) consumes entire first segment "
                                        f"'{segment_id}' ({len(current_segment_seq)}bp)."
                                    )
                                    current_segment_seq = ""
                        # First query segment at the walk end, opposite strand: the query start
                        # lies at the end of the reverse-complemented sequence.
                        if segment_idx == last_segment_idx and segment_orient_char != self.segment_id_first_strand:
                            if self.offset_first > 0:
                                if self.offset_first < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[:-self.offset_first]
                                    effective_walk_chr_stop = walk_genomic_stop - self.offset_first
                                else:
                                    current_segment_seq = ""

                    if segment_id == self.segment_id_last_query:
                        # Last query segment at the walk start, opposite strand: the query end
                        # lies at the beginning of the reverse-complemented sequence.
                        if segment_idx == 0 and segment_orient_char != self.segment_id_last_strand:
                            if self.offset_last > 0:
                                if self.offset_last < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[self.offset_last:]
                                    effective_walk_chr_start = walk_genomic_start + self.offset_last
                                else:
                                    current_segment_seq = ""
                        # Last query segment at the walk end, same strand: trim its end.
                        if segment_idx == last_segment_idx and segment_orient_char == self.segment_id_last_strand:
                            if self.offset_last > 0:
                                if self.offset_last < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[:-self.offset_last]
                                    effective_walk_chr_stop = walk_genomic_stop - self.offset_last
                                else:
                                    self.logger.error(
                                        f"Offset_last ({self.offset_last}bp) consumes entire last segment "
                                        f"'{segment_id}' ({len(current_segment_seq)}bp)."
                                    )
                                    current_segment_seq = ""

                if current_segment_seq:
                    current_walk_seq_fragments.append(current_segment_seq)

            # Fragments are only kept when non-empty, so a non-empty list joins to a non-empty sequence.
            if not current_walk_seq_fragments:
                self.logger.debug(
                    f"'{self.sample_name}': No sequence fragment generated for walk on {chromosome_name} "
                    f"(ID: {current_walk_sample_name}-{haplotype_index}). Skipping FASTA record for this walk."
                )
                continue
            final_walk_sequence = Seq("".join(current_walk_seq_fragments))

            if is_query_sample:
                # Query sample: ID carries the (offset-adjusted) query coordinates.
                q_sample = self.sample_name_query or "query_sample"
                q_chrom = self.chromosome_query or chromosome_name
                fasta_seq_id = f"{q_sample}-{q_chrom}:{effective_walk_chr_start}-{effective_walk_chr_stop}"
                fasta_description = ""
            else:
                # Other samples: ID carries the walk coordinates; description links back to the query.
                fasta_seq_id = (
                    f"{current_walk_sample_name}-{chromosome_name}:"
                    f"{effective_walk_chr_start}-{effective_walk_chr_stop}"
                )
                seq_actual_len = effective_walk_chr_stop - effective_walk_chr_start
                query_region_str = "N/A"
                if (
                    self.sample_name_query
                    and self.chromosome_query
                    and self.start_query is not None
                    and self.stop_query is not None
                ):
                    query_region_str = (
                        f"{self.sample_name_query}-{self.chromosome_query}:{self.start_query}-{self.stop_query}"
                    )
                fasta_description = f"SEQ_LEN={seq_actual_len}|QUERY={query_region_str}"

            fasta_seq_id = fasta_seq_id.strip()
            newly_created_records.append(
                SeqRecord(
                    final_walk_sequence,
                    id=fasta_seq_id,
                    name=fasta_seq_id,
                    description=fasta_description,
                )
            )

        self.sequences_list.extend(newly_created_records)
        self.logger.info(
            f"'{self.sample_name}': Finished FASTA building. Added {len(newly_created_records)} new records. "
            f"Total records in sequences_list: {len(self.sequences_list)}."
        )

    @staticmethod
    def _parse_gfa_path(gfa_path_specification: str) -> List[Tuple[str, str]]:
        """
        Split a GFA path into (orientation, segment_id) pairs using '+'/'-' orientations.

        Accepts the walk form ">s1<s2" and the "s1+s2-" form. An empty or
        unparsable specification yields an empty list.
        """
        if gfa_path_specification.startswith((">", "<")):
            return [
                ("+" if arrow == ">" else "-", seg_id)
                for arrow, seg_id in re.findall(r"([<>])([^<>]+)", gfa_path_specification)
            ]
        return [
            (orient_char, seg_id)
            for seg_id, orient_char in re.findall(r"(.+?)([+-])", gfa_path_specification)
        ]
[docs] class AsyncGfaDatabase: """ Manage an asynchronous SQLite database for storing and querying GFA link data. It uses `aiosqlite` for non-blocking operations within an asyncio event loop and serializes writes via an internal FIFO queue to prevent SQLite lock contention. Attributes ---------- db_file : Path Path to the SQLite database file. timeout : float Maximum timeout (in seconds) for SQLite lock acquisition. logger : logging.Logger Logger instance. _conn : Optional[aiosqlite.Connection] Shared SQLite connection (or None if not connected). _write_queue : asyncio.Queue Asynchronous queue for batches of links to be inserted. Max size 100. _sql_task : Optional[asyncio.Task] Background task consuming the queue and writing to the database. _shutdown : bool Flag to signal shutdown to the writer task. """
[docs] def __init__(self, db_file: Path, timeout: float = 30.0): """ Initialize the instance without opening the connection. Scheme is created upon the first call to `connect()`. Args: db_file (Path): Path to the SQLite file to be used as backend. timeout (float, optional): Maximum wait time for SQLite locks (default 30.0s). """ self.db_file: Path = db_file self.timeout: float = timeout # In seconds self._shutdown: bool = False self._conn: Optional[aiosqlite.Connection] = None self._write_queue: asyncio.Queue[Optional[List[Tuple[str, int, int, str, int, int]]]] = asyncio.Queue( maxsize=1000) self._sql_task: Optional[asyncio.Task] = None self.logger = logging.getLogger("GraTools") # More specific logger name
[docs] async def connect(self) -> None: """ Connect to the SQLite db (if not already connected), configure PRAGMA settings, create the 'links' table schema (if it doesn't exist), and start the SQL writer task. This method is idempotent: if already connected, it does nothing. """ if self._conn: return self.logger.debug(f"Connecting to the SQLite database: {self.db_file.name}") # Open the single connection self._conn = await aiosqlite.connect( self.db_file.as_posix(), timeout=self.timeout # timeout here is for connection ) # Performance optimizations and configuration await self._conn.execute("PRAGMA synchronous = OFF;") # Faster, but higher risk of corruption on crash await self._conn.execute("PRAGMA journal_mode = WAL;") # Write-Ahead Logging for better concurrency await self._conn.execute( f"PRAGMA busy_timeout = {int(self.timeout * 1000)};") # busy_timeout is in milliseconds # Create schema await self._conn.execute(""" CREATE TABLE IF NOT EXISTS links ( seg_id_1 TEXT, orient_seg_1 INTEGER, -- +1 for '+', -1 for '-' orient_key_seg_1 INTEGER, -- Often same as orient_seg_1, purpose might be specific seg_id_2 TEXT, orient_seg_2 INTEGER, orient_key_seg_2 INTEGER ); """) await self._conn.commit() # Start the background insertion task if self._sql_task is None or self._sql_task.done(): self._sql_task = asyncio.create_task(self._sql_writer()) self.logger.debug("AsyncGfaDatabase connected and writer task started/confirmed.")
async def _sql_writer(self) -> None:
    """
    Drain ``self._write_queue`` into the ``links`` table until told to stop.

    Each non-empty item taken from the queue is a batch of 6-tuples. It is
    inserted with a single ``executemany`` and committed. The loop exits when:

    - it receives the ``None`` sentinel, or
    - ``self._shutdown`` is set and the queue is empty after a 0.5 s poll
      timeout.

    A batch that fails (SQLite or any other error) is logged, its open
    transaction is rolled back, and the batch is dropped (not re-queued).
    The writer then continues with the next item. ``task_done()`` is called
    exactly once per item taken from the queue, so ``queue.join()`` can
    complete.
    """
    assert self._conn is not None, "Connection must be initialized before writer starts"
    self.logger.debug("SQL writer task started.")

    async def _rollback_if_needed() -> None:
        # A failed executemany may leave rows in an open transaction; without a
        # rollback the next successful commit() would persist that partial batch.
        if not self._conn.in_transaction:
            return
        try:
            await self._conn.rollback()
            self.logger.warning("SQL writer rolled back transaction due to SQLite error.")
        except Exception as rb_err:
            # Keep the writer alive even if the rollback itself fails.
            self.logger.error(f"Error during SQLite rollback: {rb_err}")

    processed_data_batches = 0  # Counter for actual data batches
    while True:  # Exit conditions are handled by the breaks below.
        item_was_retrieved = False
        try:
            # The timeout allows periodic checking of _shutdown.
            batch = await asyncio.wait_for(self._write_queue.get(), timeout=0.5)
            item_was_retrieved = True

            if batch is None:  # Sentinel to stop the writer
                self.logger.debug("SQL writer received end. Preparing to stop.")
                break
            if not batch:  # Empty list (not the sentinel): nothing to insert.
                continue

            processed_data_batches += 1
            self.logger.debug(
                f"SQL writer processing data batch {processed_data_batches} of size {len(batch)}. "
                f"Queue size: {self._write_queue.qsize()}"
            )
            await self._conn.executemany(
                """INSERT INTO links (
                    seg_id_1, orient_seg_1, orient_key_seg_1,
                    seg_id_2, orient_seg_2, orient_key_seg_2
                ) VALUES (?, ?, ?, ?, ?, ?);""",
                batch,
            )
            await self._conn.commit()
        except asyncio.TimeoutError:
            if self._shutdown and self._write_queue.empty():
                self.logger.info("SQL writer: Timeout on get() with _shutdown True and queue empty. Stopping.")
                break
            continue
        except sqlite3.Error as e:
            # processed_data_batches was already incremented for the failing batch.
            self.logger.error(
                f"SQLite error during processing of batch (approx. #{processed_data_batches}): {e}",
                exc_info=True,
            )
            await _rollback_if_needed()
        except Exception as e:
            self.logger.error(
                f"Unexpected error in SQL writer (approx. batch #{processed_data_batches}): {e}",
                exc_info=True,
            )
            await _rollback_if_needed()
        finally:
            # Only items actually taken from the queue are acknowledged.
            if item_was_retrieved:
                self._write_queue.task_done()

    self.logger.debug("SQL writer task has stopped.")
async def create_indexes(self) -> None:
    """
    Build indexes on ``links.seg_id_1`` and ``links.seg_id_2`` to speed up lookups.

    Connects first (via ``connect()``) when no connection is held. Best
    called once all inserts are finished. SQLite errors are logged, not
    raised.
    """
    if not self._conn:
        await self.connect()

    index_statements = (
        "CREATE INDEX IF NOT EXISTS idx_seg_id_1 ON links(seg_id_1);",
        "CREATE INDEX IF NOT EXISTS idx_seg_id_2 ON links(seg_id_2);",
    )
    self.logger.info("Creating indexes on links table.")
    try:
        for statement in index_statements:
            await self._conn.execute(statement)
        await self._conn.commit()
    except sqlite3.Error as err:
        self.logger.error(f"Failed to create indexes: {err}", exc_info=True)
async def find_children_and_grandchildren(
    self, node_id: str
) -> Dict[str, List[str]]:
    """
    Return the one-step and two-step successors of ``node_id`` in ``links``.

    Children are the ``seg_id_2`` values of rows whose ``seg_id_1`` equals
    ``node_id``. They keep the order SQLite returns them in, duplicates
    included. Grandchildren are the ``seg_id_2`` values of rows whose
    ``seg_id_1`` is one of those children, de-duplicated and sorted.

    Args:
        node_id: Segment ID to start from.

    Returns:
        ``{"children": [...], "grandchildren": [...]}``. On an SQLite error
        the error is logged and whatever was collected so far is returned.
    """
    if not self._conn:
        await self.connect()
    self.logger.debug(f"Finding children and grandchildren for node: {node_id}")

    children: List[str] = []
    grandchildren: List[str] = []
    try:
        # First hop: direct successors of node_id.
        async with self._conn.execute(
            "SELECT seg_id_2 FROM links WHERE seg_id_1 = ?;", (node_id,)
        ) as cursor:
            child_rows = await cursor.fetchall()
        children = [row[0] for row in child_rows]
        self.logger.debug(f"Found {len(children)} children for {node_id}.")

        # Second hop: successors of every child, in one IN (...) query.
        if children:
            marks = ",".join(["?"] * len(children))
            query = f"SELECT seg_id_2 FROM links WHERE seg_id_1 IN ({marks});"
            async with self._conn.execute(query, children) as cursor:
                gc_rows = await cursor.fetchall()
            # Several children may share a grandchild, hence the set.
            grandchildren = sorted({row[0] for row in gc_rows})
            self.logger.debug(
                f"Found {len(grandchildren)} unique grandchildren for {node_id} (from {len(gc_rows)} raw)."
            )
    except sqlite3.Error as err:
        self.logger.error(f"Error finding children/grandchildren for {node_id}: {err}", exc_info=True)

    return {"children": children, "grandchildren": grandchildren}
async def close(self, drain_timeout: Optional[float] = None) -> None:
    """
    Shut the database down cleanly.

    Steps:
      1. Set ``_shutdown`` so the writer stops once the queue is drained.
      2. If the writer task is running, enqueue the ``None`` sentinel and
         wait for every queued item to be processed, optionally bounded by
         ``drain_timeout``. If no writer is running, nothing would ever
         consume the queue, so the wait is skipped instead of hanging
         forever. Any batches still pending are reported with a warning.
      3. Cancel the writer task if it is still alive.
      4. Create the indexes (after all writes) and close the connection.
         The connection is closed even if index creation raises.

    Args:
        drain_timeout (Optional[float]): Maximum seconds to wait for the
            queue to drain. ``None`` (default) waits indefinitely.
    """
    self.logger.debug("Initiating AsyncGfaDatabase shutdown...")
    self._shutdown = True  # Signal writer to stop after draining queue

    writer_alive = self._sql_task is not None and not self._sql_task.done()
    if writer_alive:
        self.logger.debug("Putting sentinel on SQL write queue.")
        await self._write_queue.put(None)  # Sentinel to stop the writer loop
        self.logger.debug("Waiting for SQL queue to drain...")
        try:
            await asyncio.wait_for(self._write_queue.join(), timeout=drain_timeout)
            self.logger.debug("SQL queue drained successfully (all items processed).")
        except asyncio.TimeoutError:
            self.logger.warning("Timeout waiting for SQL queue to drain during close.")
        except Exception as e:
            self.logger.error(f"Error waiting for SQL queue to drain: {e}", exc_info=True)
    elif not self._write_queue.empty():
        # No consumer: join() would block forever, so report the loss instead.
        self.logger.warning(
            f"SQL writer is not running; {self._write_queue.qsize()} pending batch(es) will not be written."
        )

    if self._sql_task and not self._sql_task.done():
        self.logger.debug("SQL writer task still running, attempting to cancel.")
        self._sql_task.cancel()
        try:
            await self._sql_task  # Wait for task to acknowledge cancellation
            self.logger.debug("SQL writer task cancelled successfully.")
        except asyncio.CancelledError:
            self.logger.info("SQL writer task was cancelled as expected during close.")
        except Exception as e:
            self.logger.error(f"Error awaiting SQL writer task cancellation: {e}", exc_info=True)
    elif self._sql_task and self._sql_task.done():
        self.logger.debug("SQL writer task was already done.")
    else:
        self.logger.debug("No SQL writer task to shut down or task was never started.")

    # Indexes are built only after all writes, then the connection is released.
    if self._conn:  # Only if connection was ever established
        try:
            await self.create_indexes()
        finally:
            try:
                await self._conn.close()
            except Exception as e:
                self.logger.error(f"Error closing SQLite connection: {e}", exc_info=True)
            finally:
                self._conn = None  # Ensure _conn is None after attempting to close
class AsyncBedWriter:
    """
    Asynchronous BED file writer using `aiofiles`.

    Lines are queued per sample, buffered in memory, and appended in batches
    to ``<bed_dir>/<sample>.bed``.

    Attributes
    ----------
    bed_dir : Path
        Output directory for .bed files (created by the writer loop).
    batch_size : int
        Number of lines to buffer per sample before an automatic flush.
    progress : Optional[Progress]
        Rich Progress instance, if provided. Stored only; this class does not
        update it.
    logger : logging.Logger
        Logger instance ("GraTools").
    _queue : asyncio.Queue
        Internal queue of ``(sample_name, payload)`` items. The payload is a
        list of lines (from ``enqueue``) or a single line (from
        ``enqueue_single_line``). Max size 1,000,000.
    _shutdown : bool
        Flag to signal shutdown to the writer loop.
    _task : Optional[asyncio.Task]
        The background asyncio task running ``_writer_loop``.
    """

    def __init__(self, bed_dir: Path, batch_size: int = 1, progress: Optional[Progress] = None):
        self.bed_dir: Path = bed_dir
        self.batch_size: int = batch_size
        # Items are (sample, list_of_lines) or (sample, single_line).
        self._queue: asyncio.Queue[Tuple[str, Any]] = asyncio.Queue(maxsize=1000000)
        self._shutdown: bool = False
        self._task: Optional[asyncio.Task] = None
        self.progress: Optional[Progress] = progress  # Can be None
        self.logger = logging.getLogger("GraTools")

    def start(self) -> None:
        """Start the writer loop as a background task if not already running."""
        if self._task is None or self._task.done():
            self._shutdown = False  # Reset shutdown flag if restarting
            self.logger.debug("Starting AsyncBedWriter loop task.")
            self._task = asyncio.create_task(self._writer_loop())
            self.logger.debug("AsyncBedWriter task launched.")
        else:
            self.logger.debug("AsyncBedWriter task is already running.")

    async def enqueue(self, sample: str, lines: List[str]) -> None:
        """
        Queue a batch of lines for ``sample``.

        Waits while the queue is full (backpressure) so no lines are dropped.

        Args:
            sample (str): The sample name (used for the filename).
            lines (List[str]): Lines to write to the BED file. Each string
                should already end with its newline character.
        """
        if not lines:
            self.logger.debug(f"AsyncBedWriter: enqueue called with empty lines for sample '{sample}'. Skipping.")
            return
        self.logger.debug(
            f"BED writer enqueuing {len(lines)} lines for sample '{sample}'. Current queue size: {self._queue.qsize()}"
        )
        # put() only suspends when the queue is full; unlike put_nowait() it never loses data.
        await self._queue.put((sample, lines))

    async def enqueue_single_line(self, sample: str, line: str):
        """
        Queue a single line for ``sample``.

        Waits while the queue is full (backpressure). Empty lines are ignored.
        """
        if not line:
            return
        try:
            # Await put() so a full queue blocks the producer instead of failing.
            await self._queue.put((sample, line))
        except asyncio.CancelledError:
            # The calling task was cancelled: log it and propagate.
            self.logger.warning(f"Enqueue operation for sample '{sample}' was cancelled.")
            raise

    async def _writer_loop(self) -> None:
        """
        Consume the queue and append lines to per-sample BED files in batches.

        Runs until shutdown is requested and the queue is empty. All buffers
        are then flushed, including on cancellation.
        """
        self.logger.debug("BED writer loop started.")
        buffers: Dict[str, List[str]] = defaultdict(list)

        try:
            self.bed_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            self.logger.error(f"Failed to create BED output directory {self.bed_dir}: {e}. Writer loop cannot proceed.")
            return

        while not self._shutdown or not self._queue.empty():
            current_sample = ""
            got_item = False
            try:
                item = await asyncio.wait_for(self._queue.get(), timeout=0.5)
                # Set before unpacking so task_done() runs even for a malformed item.
                got_item = True
                current_sample, payload = item

                buffer = buffers[current_sample]
                # enqueue() sends a list of lines, enqueue_single_line() a single str;
                # list.extend() on a str would buffer it character by character.
                if isinstance(payload, str):
                    buffer.append(payload)
                else:
                    buffer.extend(payload)

                # Flush buffer for this sample if it's full.
                if len(buffer) >= self.batch_size:
                    self.logger.debug(
                        f"BED writer flushing BED buffer for sample '{current_sample}' (size: {len(buffer)})."
                    )
                    await self._flush_sample(current_sample, buffer)
                    # Start afresh even if the flush failed (failed lines are dropped).
                    buffers[current_sample] = list()
            except asyncio.TimeoutError:
                # No item within the timeout: stop only if shutting down with nothing left.
                if self._shutdown and self._queue.empty():
                    break
                continue
            except asyncio.CancelledError:
                self.logger.info("BED writer loop was cancelled.")
                # Final flush before exiting due to cancellation.
                await self._flush_all(buffers)
                raise
            except Exception as e:
                self.logger.error(f"Error in BED writer loop processing item for '{current_sample}': {e}", exc_info=True)
            finally:
                # Acknowledge only items actually taken from the queue.
                if got_item:
                    self._queue.task_done()

            # While draining on shutdown, flush everything to empty buffers sooner.
            if self._shutdown and not self._queue.empty():
                await self._flush_all(buffers)

        # Final flush for any remaining data in buffers after the loop finishes.
        self.logger.debug("BED writer loop finished. Performing final flush of all buffers.")
        await self._flush_all(buffers)
        self.logger.debug("BED writer loop has stopped.")

    async def _flush_sample(self, sample: str, buffer_lines: List[str]) -> None:
        """
        Append ``buffer_lines`` to ``<bed_dir>/<sample>.bed``.

        Clears ``buffer_lines`` on success. Errors are logged and the buffer is
        left untouched. The lines are written as-is (no newline is added), so
        each one must already end with its newline.
        """
        if not buffer_lines:
            self.logger.debug(f"No lines to flush for sample '{sample}'.")
            return
        path = self.bed_dir / f"{sample}.bed"
        try:
            async with aiofiles.open(path, mode="a", encoding="utf-8") as f:
                await f.writelines(buffer_lines)
            buffer_lines.clear()  # Clear only after a successful write
        except IOError as e:
            self.logger.error(
                f"IOError flushing sample '{sample}' to '{path}': {e}", exc_info=True
            )
        except Exception as e:
            self.logger.error(
                f"Unexpected error flushing sample '{sample}' to '{path}': {e}", exc_info=True
            )

    async def _flush_all(self, buffers: Dict[str, List[str]]) -> None:
        """Write every non-empty buffer and drop the entries that became empty."""
        self.logger.debug(f"Performing flush of all remaining buffers ({len(buffers)} total).")
        # Iterate over a snapshot because entries are deleted along the way.
        for sample_name, buf_list in list(buffers.items()):
            if buf_list:
                await self._flush_sample(sample_name, buf_list)
            if not buf_list and sample_name in buffers:
                del buffers[sample_name]
        self.logger.debug("Flush of all buffers complete.")

    async def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Signal the writer loop to stop and wait for it to finish.

        The loop drains the queue and flushes all buffers before exiting.

        Args:
            timeout (Optional[float]): Maximum seconds to wait for the writer
                task. ``None`` (default) waits indefinitely. On timeout the
                task is cancelled; its cancellation handler still flushes the
                buffers it holds.
        """
        self._shutdown = True
        if self._task is not None and not self._task.done():
            self.logger.debug("Waiting for BED writer task to complete...")
            try:
                await asyncio.wait_for(self._task, timeout=timeout)
                self.logger.debug("BED writer task completed successfully after shutdown signal.")
            except asyncio.TimeoutError:
                self.logger.warning(
                    "Timeout waiting for BED writer task to complete during shutdown. Attempting cancellation.")
                self._task.cancel()
                try:
                    await self._task
                except asyncio.CancelledError:
                    self.logger.info("BED writer task was cancelled during shutdown due to timeout.")
                except Exception as e_inner:
                    self.logger.error(f"Exception after cancelling BED writer task: {e_inner}", exc_info=True)
            except asyncio.CancelledError:
                # Shutdown itself was called from a cancelled context.
                self.logger.info("BED writer task was externally cancelled during its shutdown process.")
            except Exception as e:
                self.logger.error(
                    f"Exception while waiting for BED writer task during shutdown: {e}", exc_info=True
                )
        elif self._task and self._task.done():
            self.logger.debug("BED writer task was already done prior to shutdown call.")
        else:
            self.logger.debug(
                "No BED writer task to shut down (was never started or already None)."
            )
@dataclass
class GFA:
    """
    Parse a GFA (Graphical Fragment Assembly) file, compute statistics and
    generate derived files: a BAM file containing the segments and BED files
    describing paths per sample.

    Optionally fills an asynchronous SQLite database with the links and feeds
    an asynchronous BED writer.

    Attributes
    ----------
    gfa_path : Path
        Path to the input GFA file (``.gfa`` or ``.gfa.gz``).
    threads : int, optional
        Worker count for the shared ``ThreadPoolExecutor`` (values <= 0 use the
        executor default). Default is 1.
    logger : logging.Logger
        Logger object. Default is a logger named "GraTools".
    disable_progress_flag : Optional[bool]
        Requests that progress bars be disabled. Defaults to False.
        NOTE(review): not passed to the ``Progress`` built in ``__post_init__``;
        presumably honoured elsewhere — confirm.
    gfa_name : Optional[str]
        File name of ``gfa_path`` without ``.gfa.gz`` / ``.gfa`` (falls back to
        ``Path.stem`` for other extensions). Auto-initialized.
    version : Optional[str]
        GFA version extracted from the header (e.g., "1.0").
    header_gfa : List[str]
        Header lines (H lines) from the GFA file; written by ``save_header``.
    sample_reference : Optional[str]
        Reference sample name, potentially from the GFA header (RS tag).
    bam_segments_file : Optional[Path]
        BAM file receiving the segments (S lines):
        ``<works_path>/bam_files/<gfa_name>.bam``. Auto-initialized.
    processed_sample_names : Set[str]
        Sample names seen so far; initialized empty in ``__post_init__``.
    samples_chrom_file_writer : Optional[TextIO]
        Presumably the open handle used to write ``samples_chrom_file_path``;
        None by default.
    samples_chrom_file_path : Optional[Path]
        ``<works_path>/samples_chrom.txt``. Auto-initialized.
    dict_samples_chrom : defaultdict[str, OrderedDict[str, List[str]]]
        Sample -> OrderedDict of chromosome -> list of "start\\tstop" fragment
        strings derived from Walk (W) lines.
    dict_segments_size : defaultdict[str, int]
        Segment ID -> length in base pairs.
    dict_segments_samples : defaultdict[str, List[str]]
        Segment ID -> identifiers ("sample;chromosome;haplotype") of the
        samples containing the segment.
    dict_samples_bed : defaultdict[str, OrderedDict[str, Path]]
        Not populated by the parsing logic. It appears intended to track BED
        files produced by ``AsyncBedWriter``, which manages its own paths, so
        it may be redundant.
    works_path : Optional[Path]
        Working directory ``<gfa dir>/<gfa_name>_GraTools-IMPORT``.
        Auto-initialized.
    bed_path : Optional[Path]
        ``<works_path>/bed_files``. Auto-initialized.
    bam_path : Optional[Path]
        ``<works_path>/bam_files``. Auto-initialized.
    found_minigraph : bool
        True if a sample named 'MINIGRAPH' (case-insensitive) was found in
        Walk lines. Defaults to False.
    import_links : bool
        If True, GFA links (L lines) are stored in an SQLite database.
        Defaults to False.
    db_links : Optional[AsyncGfaDatabase]
        Link database handler, created in ``__post_init__`` only when
        ``import_links`` is True.
    segment_count : int
        Total number of segments (S lines) processed. Defaults to 0.
    total_segment_length : int
        Sum of lengths of all segments. Defaults to 0.
    link_count : int
        Total number of links (L lines) processed. Defaults to 0.
    degrees : defaultdict[str, int]
        Segment ID -> degree (number of links connected).
    walks_count : int
        Total number of walks (W lines) processed. Defaults to 0.
    max_walk_rank : int
        Maximum number of segments in any single walk. Defaults to 0.
    sum_rank0_length : int
        Sum of lengths of the first segments of all walks. Defaults to 0.
    input_genome_size : int
        Cumulative size of all paths (sum of segment lengths along each walk).
        Defaults to 0.
    walks_info : List[Dict[str, Any]]
        Per-walk info dictionaries (Path name, Sequence length, Num Segments).
    inverted_links_count : int
        Links whose orientations differ (e.g., S1+ -> S2-). Defaults to 0.
    negative_links_count : int
        Links where both segments are negative (S1- -> S2-). Defaults to 0.
    self_links_count : int
        Links from a segment to itself (S1 -> S1). Defaults to 0.
    isolated_segments : Set[str]
        Segment IDs with no links: initialized with all segments, then linked
        ones removed.
    shared_executor : Optional[ThreadPoolExecutor]
        Executor for running synchronous tasks in threads. Auto-initialized.
    progress : Optional[Progress]
        Transient Rich progress bar (1 refresh/s). Auto-initialized.
    line_type_counts : Counter
        Counts of each GFA line type (H, S, L, W, P, C, E, U).
    header_gfa_file : Optional[Path]
        ``<works_path>/header_<gfa_name>.txt``. Auto-initialized.
    stats_file : Optional[Path]
        ``<works_path>/stats_<gfa_name>.txt``. Auto-initialized.
    db_file_path : Optional[Path]
        ``<works_path>/links_<gfa_name>.db``. Auto-initialized.
    RE_ORIENTED_SEG_GT_LT : re.Pattern
        Matches oriented segments written ``>id`` / ``<id``. Auto-initialized.
    RE_ORIENTED_SEG_PLUS_MINUS : re.Pattern
        Matches oriented segments written ``id+`` / ``id-``. Auto-initialized.
    reusable_segment : AlignedSegment
        Unmapped ``AlignedSegment`` template (flag 4, reference_id -1,
        reference_start 0, mapping_quality 0, template_length 0) built once in
        ``__post_init__``.
    """
    gfa_path: Path
    threads: int = 1
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    disable_progress_flag: Optional[bool] = False
    gfa_name: Optional[str] = None
    version: Optional[str] = None
    header_gfa: List[str] = field(default_factory=list, repr=True)
    sample_reference: Optional[str] = None
    bam_segments_file: Optional[Path] = None

    # GFA samples
    processed_sample_names: Set[str] = field(init=False, repr=False)
    samples_chrom_file_writer: Optional[TextIO] = field(init=False, repr=False, default=None)
    samples_chrom_file_path: Optional[Path] = field(init=False, repr=False, default=None)

    # Dictionaries populated during parsing
    dict_samples_chrom: defaultdict = field(
        default_factory=lambda: defaultdict(OrderedDict),
        repr=False,  # Key: sample, Value: OrderedDict[chrom, list_of_fragments]
    )
    dict_segments_size: defaultdict = field(
        default_factory=lambda: defaultdict(int),
        repr=False,  # Key: seg_id, Value: length
    )
    dict_segments_samples: defaultdict = field(
        default_factory=lambda: defaultdict(list),
        repr=False,  # Key: seg_id, Value: list of "sample;chrom;haplo" identifiers
    )
    # Not filled by GFA parsing; apparently a tracker of files written by AsyncBedWriter.
    dict_samples_bed: defaultdict = field(
        default_factory=lambda: defaultdict(OrderedDict),
        repr=False,  # Key: sample, Value: OrderedDict[chrom, path_to_bed_for_chrom_walk]
    )

    # Paths
    works_path: Optional[Path] = None
    bed_path: Optional[Path] = None
    bam_path: Optional[Path] = None
    found_minigraph: bool = False
    import_links: bool = False  # Flag to control link importing
    db_links: Optional[AsyncGfaDatabase] = None  # Instantiated in post_init

    # Statistics counters
    segment_count: int = 0
    total_segment_length: int = 0
    link_count: int = 0
    degrees: defaultdict = field(default_factory=lambda: defaultdict(int), repr=False)  # segment_id -> degree
    walks_count: int = 0
    max_walk_rank: int = 0  # Max number of segments in a walk
    sum_rank0_length: int = 0  # Sum of lengths of first segments of walks
    input_genome_size: int = 0
    walks_info: List[Dict[str, Any]] = field(default_factory=list, repr=False)
    inverted_links_count: int = 0
    negative_links_count: int = 0
    self_links_count: int = 0
    isolated_segments: set = field(default_factory=set, repr=False)  # Store segment IDs

    # Internal operational attributes (initialized in post_init)
    shared_executor: Optional[ThreadPoolExecutor] = None
    progress: Optional[Progress] = None
    line_type_counts: Counter = field(default_factory=Counter)
    header_gfa_file: Optional[Path] = None
    stats_file: Optional[Path] = None
    db_file_path: Optional[Path] = None
    RE_ORIENTED_SEG_GT_LT: re.Pattern = field(init=False, repr=False)
    RE_ORIENTED_SEG_PLUS_MINUS: re.Pattern = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """
        Set up paths, directories, regexes, the executor, the optional link
        database and the progress bar, then run the whole import pipeline.

        Pipeline, in order: ``run()``, ``save_header()``, ``tag_bam()``,
        ``sort_file_in_place()``, ``compute_statistics()``.

        Raises:
            FileNotFoundError: if ``gfa_path`` does not exist or is not a file.
        """
        if not self.gfa_path.exists() or not self.gfa_path.is_file():
            self.logger.error(
                f"GFA file '{self.gfa_path}' does not exist or is not a file."
            )
            raise FileNotFoundError(f"GFA file not found: {self.gfa_path}")  # Fail fast

        # Derive the GFA base name by stripping the known extensions.
        name_to_process = self.gfa_path.name
        if name_to_process.endswith(".gfa.gz"):
            self.gfa_name = name_to_process.removesuffix(".gfa.gz")
        elif name_to_process.endswith(".gfa"):
            self.gfa_name = name_to_process.removesuffix(".gfa")
        else:
            self.gfa_name = self.gfa_path.stem  # Fallback, might include partial extensions
            self.logger.warning(
                f"GFA file '{self.gfa_path.name}' has an unusual extension. Using stem: '{self.gfa_name}'.")

        # Working directory layout.
        self.works_path = self.gfa_path.parent / f"{self.gfa_name}_GraTools-IMPORT"
        self.works_path.mkdir(exist_ok=True, parents=True)
        self.bed_path = self.works_path / "bed_files"
        self.bed_path.mkdir(exist_ok=True, parents=True)
        self.bam_path = self.works_path / "bam_files"
        self.bam_path.mkdir(exist_ok=True, parents=True)
        self.bam_segments_file = self.bam_path / f"{self.gfa_name}.bam"
        self.header_gfa_file = self.works_path / f"header_{self.gfa_name}.txt"
        self.stats_file = self.works_path / f"stats_{self.gfa_name}.txt"
        self.db_file_path = self.works_path / f"links_{self.gfa_name}.db"
        self.samples_chrom_file_path = self.works_path / "samples_chrom.txt"

        # Precompiled regexes for oriented segment tokens.
        self.RE_ORIENTED_SEG_GT_LT = re.compile(r"([<>])([^<>]+)")
        self.RE_ORIENTED_SEG_PLUS_MINUS = re.compile(r"(.+?)([+-])")

        # None lets ThreadPoolExecutor choose its default worker count.
        self.shared_executor = ThreadPoolExecutor(
            max_workers=self.threads if self.threads > 0 else None)

        if self.import_links:
            self.db_links = AsyncGfaDatabase(
                db_file=self.db_file_path,
                timeout=5.0,  # seconds (5000 ms busy timeout)
            )

        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            "{task.completed}/{task.total} lines",
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            refresh_per_second=1,  # Higher refresh can be costly
            transient=True,  # Progress bar disappears after completion
            console=shared_console,
        )
        self.processed_sample_names = set()

        # Build one reusable AlignedSegment and set its constant fields only once.
        segment_template = AlignedSegment()
        segment_template.flag = 4  # Unmapped
        segment_template.reference_id = -1
        segment_template.reference_start = 0
        segment_template.mapping_quality = 0
        segment_template.template_length = 0
        self.reusable_segment = segment_template

        # Main parsing, then post-processing steps that depend on its output.
        self.run()
        self.save_header()
        self.tag_bam()  # Assumes self.bam_segments_file was created during parsing
        self.sort_file_in_place()  # Sorts samples_chrom.txt written during walk parsing
        self.compute_statistics()  # Uses data populated during parsing

    @staticmethod
    def _count_lines_fast(gfa_path: Path) -> int:
        """
        Count the lines of a plain or gzip-compressed file.

        Reads the file in binary 4 MiB chunks and counts newline bytes. A final
        line without a trailing newline is counted as well. Compression is
        detected from the ``.gz`` suffix.

        Args:
            gfa_path: File to count.

        Returns:
            The number of lines, or 0 if the file cannot be read (a warning is
            logged).
        """
        is_gzipped = gfa_path.name.endswith(".gz")
        opener = gzip.open if is_gzipped else open
        chunk_size = 4 * 1024 * 1024
        count = 0
        last_byte = b"\n"  # An empty file must count as 0 lines.
        try:
            # Binary mode: counting bytes avoids any decoding cost.
            with opener(gfa_path, "rb") as handle:
                while chunk := handle.read(chunk_size):
                    count += chunk.count(b"\n")
                    last_byte = chunk[-1:]
        except Exception as e:
            logging.getLogger("GraTools").warning(f"Fast line count failed for {gfa_path}: {e}")
            return 0
        if last_byte != b"\n":
            count += 1  # Unterminated last line
        return count
def save_header(self) -> None:
    """
    Write the collected GFA header (H) lines to ``self.header_gfa_file``.

    The lines are joined with newlines; no trailing newline is added. If the
    target path is unset, or the write fails with an I/O error, an error is
    logged and nothing is raised.
    """
    target = self.header_gfa_file
    if not target:
        self.logger.error("Header GFA file path not set. Cannot save header.")
        return

    self.logger.info(f"Saving GFA header to '{target.name}'")
    try:
        with open(target, "w", encoding="utf-8") as handle:
            handle.write("\n".join(self.header_gfa))
    except IOError as err:
        self.logger.error(f"Failed to save header file '{target}': {err}")
def sort_file_in_place(self):
    """
    Sort ``<works_path>/samples_chrom.txt`` in place with the Unix ``sort`` tool.

    Expected line format: sample<TAB>chromosome<TAB>start<TAB>end. Rows are
    ordered by sample, then chromosome (both as text; collation follows the
    process locale), then start (numerically). The same path is passed to
    ``sort -o`` and as input, so the sorted output replaces the file.

    Raises:
        subprocess.CalledProcessError: if ``sort`` exits with a non-zero
            status (a warning is logged first).
    """
    target = str(self.works_path / "samples_chrom.txt")
    # Column 1 = sample, column 2 = chromosome, column 3 = start (numeric).
    key_options = ["-k1,1", "-k2,2", "-k3,3n"]
    command = ["sort", "-t", "\t", *key_options, "-o", target, target]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as err:
        self.logger.warning(f"Error during sorting: {err}")
        raise
def _process_single_bed(self, bed_path: Path, task_id: TaskID) -> None: """ Sorts a single BED file in place using BedTool. Updates a Rich progress bar task. This method is intended to be run in a separate thread or process. """ self.logger.debug(f"Sorting BED file: {bed_path.name}") try: tmp_path = bed_path.with_suffix(".sorted.bed") sample_bed = BedTool(bed_path.as_posix()).sort(output=tmp_path.as_posix()) tmp_path.replace(bed_path) self.logger.debug(f"Finished sorting BED file: {bed_path.name}") except Exception as e: # Bedtools can raise various errors self.logger.error(f"Error sorting BED file '{bed_path.name}': {e}", exc_info=True) # How to signal error to progress? Maybe update description. if self.progress: self.progress.update(task_id, description=f"[red]Error sorting {bed_path.name}") return # Exit on error for this BED file if self.progress: # Check if progress object exists (it should if called from run()) self.progress.update(task_id, advance=1) async def _read_gfa_lines_async(self): """ Asynchronously reads lines from a GFA file (plain or gzipped) by running the synchronous file I/O in a separate thread. It yields chunks of lines to minimize the overhead of context switching between the reader and the parser, which significantly improves performance. 
""" is_gzipped = self.gfa_path.name.endswith(".gz") # Determine the correct synchronous open function open_func = gzip.open if is_gzipped else open mode = "rt" if is_gzipped else "r" log_msg = f"Reading {'gzipped' if is_gzipped else 'plain text'} GFA file: {self.gfa_path.name}" self.logger.debug(f"{log_msg} using a dedicated I/O thread and line buffering.") try: # Run the synchronous open function in a separate thread to get the file iterator file_iterator = await asyncio.to_thread(open_func, self.gfa_path, mode, encoding="utf-8") except Exception as e: self.logger.error(f"Failed to open file {self.gfa_path.name} in a thread: {e}") return try: # --- Buffer logic re-introduced here --- lines_buffer = [] # This size can be tuned. 1000-10000 is often a good range. buffer_size = 100000 for line in file_iterator: lines_buffer.append(line) if len(lines_buffer) >= buffer_size: yield lines_buffer # Yield a full chunk of lines lines_buffer = [] # Reset the buffer # Yield any remaining lines in the last, partially-filled buffer if lines_buffer: yield lines_buffer except Exception as e: self.logger.error(f"Error while reading from {self.gfa_path.name} in thread: {e}") finally: # Ensure the file is closed, also in a thread self.logger.debug(f"Closing GFA file: {self.gfa_path.name}") await asyncio.to_thread(file_iterator.close)
[docs] async def parse_gfa(self) -> None: """ Parses the GFA file line by line, processing headers, segments, links, and walks. Segments are written to a BAM file. Links are (optionally) stored in an async SQLite DB. Walk information is used to generate data for BED files, written by AsyncBedWriter. """ # Standard GFA header for output BAM if input GFA has no SQ lines. # Pysam's AlignmentFile requires a header. If GFA S lines imply @SQ, it would be better. # For now, a minimal header. bam_header = { 'HD': {'VN': '1.8', 'SO': 'unsorted'}, # SO: Coordinate or unsorted, depending on S lines 'SQ': [] # Placeholder for sequence dictionary. Ideally, populate from S lines if reference context is known. } # Pre-calculate total lines for progress bar (can be slow for very large files) total_lines = None if not self.disable_progress_flag: try: total_lines = self._count_lines_fast(self.gfa_path) self.logger.info(f"Starting GFA parsing for: '{self.gfa_name}' with {total_lines} lines") except Exception as e: self.logger.warning(f"Could not count lines in GFA file for progress bar: {e}. 
Progress may be inaccurate.") else: self.logger.info(f"Starting GFA parsing for: '{self.gfa_name}' with unknown lines") # Initialize AsyncBedWriter bed_writer = AsyncBedWriter( bed_dir=self.bed_path, # batch_size=10000, # Tunable parameter progress=self.progress # Pass Rich progress instance if needed by BedWriter ) bed_writer.start() # Start its writer loop # Connect to AsyncGfaDatabase if importing links if self.import_links and self.db_links: await self.db_links.connect() links_batch: List[LinkInfo] = [] # Use List for type hinting, will store LinkInfo link_batch_size = 100000 # How many links to batch before writing to DB bam_segment_batch = [] BATCH_SIZE = 10000 # Add task to Rich Progress parse_task_id = self.progress.add_task( "GFA parsing...", total=total_lines, visible=not self.disable_progress_flag ) self.samples_chrom_file_writer = open(self.samples_chrom_file_path, "w", encoding="utf-8") # Open BAM file for writing segments # pysam.AlignmentFile is synchronous. Writes happen in the main async thread. # If BAM writing is slow, it can block the event loop. # OPTIMIZATION_SUGGESTION: For very high throughput GFA S lines, # consider batching AlignedSegment objects and writing them in a separate thread # using `await asyncio.to_thread(bam_file_out.write, batched_segment)`. 
try: with (AlignmentFile( self.bam_segments_file.as_posix(), "wb", header=bam_header, threads=self.threads # Use self.threads for multi-threading if pysam supports for write ) as bam_file_out, self.progress): # `with self.progress` starts/stops the Rich Progress context current_line_type = None lines_processed_count = 0 async for lines_chunk in self._read_gfa_lines_async(): for line_content in lines_chunk: lines_processed_count += 1 line_content = line_content.strip() # Strip newline and surrounding whitespace if not line_content or line_content.startswith("#"): # Skip empty or comment lines continue fields = line_content.split("\t") line_type = fields[0] # Mettre ร  jour la description moins souvent aussi if line_type != current_line_type: current_line_type = line_type self.progress.update( parse_task_id, description=f"GFA parsing: {current_line_type}" ) self.line_type_counts[line_type] += 1 match line_type: case "H": # Header line self.header_gfa.append(line_content) # Basic header parsing for version and reference sample (if present) # Example H line: H VN:Z:1.0 RS:Z:ref_sample for field in fields[1:]: if field.startswith("VN:Z:"): self.version = field[5:] elif field.startswith( "RS:Z:"): # Reference Sample tag (custom or specific GFA variant?) 
self.sample_reference = field[5:] case "S": # Segment line segment_for_bam = self._parse_segment(fields, bam_file_out) # Pass fields list if segment_for_bam: bam_segment_batch.append(copy(segment_for_bam)) if len(bam_segment_batch) >= BATCH_SIZE: for seg in bam_segment_batch: bam_file_out.write(seg) bam_segment_batch.clear() case "L": # Links line self.link_count += 1 if self.import_links and self.db_links: link_info = self._parse_link_line(fields) # Pass fields list if link_info: links_batch.append(link_info) if len(links_batch) >= link_batch_size: await self.db_links.batch_insert_links(list(links_batch)) # Pass a copy links_batch.clear() case "W": # Walk line # W <sample_name> <haplotype_index> <chr_name> <chr_start> <chr_stop> <segment_path> # For GFA2 OrderedGroup (OG lines), parsing would be different. Assuming GFA1 W. parse_result = await self._parse_walk_line(fields, bed_writer) # Pass fields list # if parse_result: # if self.walks_count % 10 == 0: # Every 2 walks # await asyncio.sleep(0) # Other GFA line types (P - Path, C - Containment, E - Edge (GFA2), U - Gap (GFA2), etc.) # Add parsing logic here if needed. For now, they are counted by line_type_counts. # update progress with number of line read on _read_gfe_lines_async self.progress.advance(parse_task_id, 100000) # if self.disable_progress_flag: # self.logger.info(f"Processed {lines_processed_count} lines") if bam_segment_batch: for seg in bam_segment_batch: bam_file_out.write(seg) bam_segment_batch.clear() except Exception as e: self.logger.error(f"Error during GFA parsing main loop: {e}", exc_info=True) # Ensure progress bar is stopped or marked as errored if possible self.progress.update(parse_task_id, description=f"[red]Error during GFA parsing: {e}", completed=lines_processed_count) # Proceed to shutdown, resources might be in an inconsistent state. 
finally: # Finalize operations after loop (or if error occurred) self.progress.remove_task(parse_task_id) # Or stop if not transient self.logger.info(f"Finished reading GFA file. Processed {lines_processed_count} lines.") if self.found_minigraph: self.logger.warning( "Sample 'MINIGRAPH' (or similar) detected and ignored. " "See Cactus issue https://github.com/ComparativeGenomicsToolkit/cactus/pull/1050 for context." ) # Flush any remaining links to the database if self.import_links and self.db_links and links_batch: self.logger.debug(f"Inserting final batch of {len(links_batch)} links into database.") await self.db_links.batch_insert_links(list(links_batch)) # Shutdown AsyncGfaDatabase (waits for queue, creates indexes, closes connection) if self.db_links: self.logger.info( f"Closing links database and saved to '{self.db_file_path.name if self.db_file_path else 'N/A'}'.") await self.db_links.close() # Shutdown AsyncBedWriter (waits for queue, flushes all buffers, closes files) self.logger.info("Shutting down BED writer...") await bed_writer.shutdown() if self.samples_chrom_file_writer: self.samples_chrom_file_writer.close() self.logger.info(f"Sample-chromosome fragment data saved to '{self.samples_chrom_file_path.name}'.")
def run(self):
    """
    Synchronous entry point: run ``parse_gfa`` on a fresh event loop, then sort BED files.

    The shared thread pool, when present, becomes the loop's default executor,
    so ``asyncio.to_thread`` work runs on it. After parsing, whether it
    succeeded or not, async generators are finalised, the shared executor is
    shut down and the loop is closed.

    If parsing succeeded, every non-empty ``<bed_path>/<sample>.bed`` for a
    sample in ``self.processed_sample_names`` is sorted in parallel by
    ``_process_single_bed``. All exceptions are logged, not raised.
    """
    # The uvloop policy is installed globally at import time, so new_event_loop() uses it.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    if self.shared_executor:
        loop.set_default_executor(self.shared_executor)

    try:
        try:
            loop.run_until_complete(self.parse_gfa())
            self.logger.info("Async GFA parsing complete. Proceeding to sort BED files.")
        finally:
            # Release loop resources even when parsing raised.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                if self.shared_executor:
                    self.logger.debug("Shutting down shared ThreadPoolExecutor.")
                    self.shared_executor.shutdown(wait=True)
                loop.close()
                asyncio.set_event_loop(None)

        # TODO: AsyncBedWriter could report the files it actually wrote; for now
        # probe <bed_path>/<sample>.bed for every sample seen in W lines.
        beds_to_sort = []
        if self.bed_path and self.bed_path.exists():
            for sample_name in self.processed_sample_names:
                potential_bed_file = self.bed_path / f"{sample_name}.bed"
                if potential_bed_file.exists() and potential_bed_file.stat().st_size > 0:
                    beds_to_sort.append(potential_bed_file)

        if not beds_to_sort:
            self.logger.info("No BED files found or specified for sorting.")
        else:
            self.logger.info(f"Found {len(beds_to_sort)} BED files to sort.")
            sort_task_id = self.progress.add_task(
                f"Sorting {len(beds_to_sort)} BED files",
                total=len(beds_to_sort),
                visible=not self.disable_progress_flag,
            )
            # self.threads <= 0 means "let the pool choose" (as for the shared executor);
            # min() alone would give 0 workers and ThreadPoolExecutor would raise ValueError.
            bed_sort_workers = min(self.threads, len(beds_to_sort)) if self.threads > 0 else None
            with ThreadPoolExecutor(max_workers=bed_sort_workers) as bed_sort_executor:
                with self.progress:
                    futures = [
                        bed_sort_executor.submit(self._process_single_bed, bed_file_path, sort_task_id)
                        for bed_file_path in beds_to_sort
                    ]
                    # Without a timeout, wait() returns only once every future is done.
                    done, _ = wait(futures)
                    for future in done:
                        error = future.exception()
                        if error is not None:
                            self.logger.error(f"BED sorting task failed: {error}", exc_info=error)
            self.progress.remove_task(sort_task_id)
            self.logger.info("All BED file sorting tasks complete.")
    except Exception as e:
        self.logger.error(f"Error during GFA.run execution: {e}", exc_info=True)
    finally:
        self.logger.debug("GFA.run() finished.")
@staticmethod def _parse_gfa_walk_path_manually_iter(path_spec: str): """ Parse manuellement une chaรฎne de chemin GFA (ex: '>10<25>150') en utilisant un gรฉnรฉrateur. C'est un remplacement optimisรฉ et plus rapide que re.finditer. Yields: Tuple[str, str]: Un tuple contenant (orientation_char, segment_id). Ex: ('>', '10'), ('<', '25'), ('>', '150') """ if not path_spec: return segment_start_index = 1 for i in range(1, len(path_spec)): char = path_spec[i] if char == '>' or char == '<': # On a trouvรฉ la fin du segment prรฉcรฉdent. segment_id = path_spec[segment_start_index:i] orientation_char = path_spec[segment_start_index - 1] yield orientation_char, segment_id # Le prochain segment commencera juste aprรจs le dรฉlimiteur actuel. segment_start_index = i + 1 # Traiter le tout dernier segment aprรจs la fin de la boucle if segment_start_index <= len(path_spec): last_segment_id = path_spec[segment_start_index:] if last_segment_id: orientation_char = path_spec[segment_start_index - 1] yield orientation_char, last_segment_id async def _parse_walk_line(self, fields: List[str], bed_writer: AsyncBedWriter) -> Optional[str]: """ Parses a GFA Walk line (W line) to extract path information and generate BED-like entries. Updates internal dictionaries related to samples, chromosomes, and segment occurrences. Parameters ---------- fields : List[str] A list of tab-separated fields from a W line. Expected GFA1 format: W <sample> <hap_idx> <chr> <start> <stop> <path_spec> Example path_spec: >seg1A+>seg1B->seg2A+ (GFA1 allows segment names with orientations) >s1>s2<s3 (another common style where orientation is part of path_spec) Returns ------- Optional[str] The sample name if successful, otherwise None. """ # Buffer local pour ce walk bed_lines_batch = [] batch_size = 1000000 # Un batch de taille raisonnable try: # GFA1 W line: W <Samp> <HapIdx> <SeqId> <Start> <End> <Path> # For GFA specification, check official docs. This parsing assumes a common interpretation. 
_type_line = fields[0] # "W" sample_name = fields[1] haplotype_index_str = fields[2] # Can be numeric or other ID chromosome_name = fields[3] chromosome_start_str = fields[4] # 0-based or 1-based depending on GFA source convention chromosome_stop_str = fields[5] # End coordinate gfa_path_specification = fields[6] # e.g., "s1+s2-s3+" or ">s1<s2>s3" except IndexError: self.logger.warning(f"Malformed W line (not enough fields): '{fields}'. Skipping.") return None if "MINIGRAPH" in sample_name.upper(): # Case-insensitive check self.found_minigraph = True self.logger.debug(f"Skipping W line for 'MINIGRAPH' sample: {fields}") return None # Skip processing for MINIGRAPH sample # 1. Ajouter le nom du sample au set self.processed_sample_names.add(sample_name) # 2. ร‰crire dans le fichier if self.samples_chrom_file_writer: self.samples_chrom_file_writer.write( f"{sample_name}\t{chromosome_name}\t{chromosome_start_str}\t{chromosome_stop_str}\n") # Information for SW tag in BAM (segment_id -> list of "sample;chrom;haplo") sample_chromosome_haplotype_id = f"{sample_name};{chromosome_name};{haplotype_index_str}" oriented_segments_in_path: List[Tuple[str, str]] = [] # 1. Create the C-speed, memory-efficient iterator instead of a list. # matches_iterator = self.RE_ORIENTED_SEG_GT_LT.finditer(gfa_path_specification) oriented_segments_iterator = self._parse_gfa_walk_path_manually_iter(gfa_path_specification) # 2. Handle the first segment separately to process its stats without building a list. try: first_match = next(oriented_segments_iterator) except StopIteration: # This means the path is empty or malformed. self.logger.warning( f"Could not parse any segments from path spec in W line: '{gfa_path_specification}'. Skipping.") return None # 3. Process the stats for the first segment. # We initialize counters here now. 
self.walks_count += 1 num_segments_in_walk = 0 current_walk_total_length_bp = 0 _dict_segments_samples = self.dict_segments_samples _dict_segments_size = self.dict_segments_size first_seg_id = first_match[1] self.sum_rank0_length += _dict_segments_size.get(first_seg_id, 0) # 4. Define a helper generator to elegantly combine the first element and the rest of the iterator. all_segments_iterator = chain([first_match], oriented_segments_iterator) # Prรฉ-calculer les parties constantes des lignes BED current_segment_start_on_chr = int(chromosome_start_str) bed_field_5_fragment_id = f"{chromosome_start_str}:{chromosome_stop_str}" bed_line_prefix = f"{chromosome_name}\t" bed_line_suffix = f"\t{bed_field_5_fragment_id}\t{haplotype_index_str}\n" # This loop now iterates over the combined generator, not a pre-built list. for orientation_char, segment_id in all_segments_iterator: # Count on the fly. num_segments_in_walk += 1 segment_orientation = "+" if orientation_char == ">" else "-" _dict_segments_samples[segment_id].append(sample_chromosome_haplotype_id) segment_length = _dict_segments_size.get(segment_id, 0) # Warning for 0-length is already there, keep it. if segment_length == 0: self.logger.warning( f"Segment '{segment_id}' in walk for sample '{sample_name}' has zero length " f"or size not found. BED entry may be inaccurate." 
) current_walk_total_length_bp += segment_length segment_end_on_chr = current_segment_start_on_chr + segment_length # Crรฉez la ligne BED et envoyez-la immรฉdiatement oriented_seg_id_for_bed = f"{segment_orientation}{segment_id}" bed_line = ( bed_line_prefix + f"{current_segment_start_on_chr}\t{segment_end_on_chr}\t" + oriented_seg_id_for_bed + bed_line_suffix ) # APPEL ASYNCHRONE ร  la nouvelle mรฉthode # await bed_writer.enqueue_single_line(sample_name, bed_line) bed_lines_batch.append(bed_line) # Envoyer le batch quand il est plein if len(bed_lines_batch) >= batch_size: await bed_writer.enqueue(sample_name, bed_lines_batch) bed_lines_batch = [] # Vider le batch current_segment_start_on_chr = segment_end_on_chr # Envoyer le reste du batch ร  la fin du walk if bed_lines_batch: await bed_writer.enqueue(sample_name, bed_lines_batch) # Update total genome size covered by all walks self.input_genome_size += current_walk_total_length_bp self.max_walk_rank = max(self.max_walk_rank, num_segments_in_walk) # Store summary information for this walk (for statistics) self.walks_info.append({ "PathName": f"{sample_name}_{chromosome_name}_{haplotype_index_str}", # Unique path identifier "SequenceLength_bp": current_walk_total_length_bp, "NumberOfSegments": num_segments_in_walk, }) return sample_name def _parse_link_line(self, fields: List[str]) -> Optional[LinkInfo]: """ Parses a GFA Link line (L line) to extract connection information between two segments. Updates link-related statistics. Parameters ---------- fields : List[str] A list of tab-separated fields from an L line. Expected GFA1 format: L <seg1> <orient1> <seg2> <orient2> <overlap_CIGAR> Returns ------- Optional[LinkInfo] A LinkInfo namedtuple containing parsed link data if successful, else None. """ try: _type_line = fields[0] # "L" seg_id_1 = fields[1] orient_char_1 = fields[2] # "+" or "-" seg_id_2 = fields[3] orient_char_2 = fields[4] # "+" or "-" # overlap_cigar = fields[5] # e.g., "0M", "10M", etc. 
Not used in LinkInfo current def. except IndexError: self.logger.warning(f"Malformed L line (not enough fields): '{fields}'. Skipping.") return None # Convert orientation characters to numeric values (+1 for '+', -1 for '-') # This is a common convention if orientations need to be multiplied or compared numerically. orient_seg_1_numeric = 1 if orient_char_1 == "+" else -1 orient_seg_2_numeric = 1 if orient_char_2 == "+" else -1 # Update link statistics if seg_id_1 == seg_id_2: self.self_links_count += 1 if orient_seg_1_numeric * orient_seg_2_numeric < 0: # e.g., (+1 * -1) = -1 (orientations differ) self.inverted_links_count += 1 if orient_seg_1_numeric == -1 and orient_seg_2_numeric == -1: # Both negative self.negative_links_count += 1 # Update segment degrees self.degrees[seg_id_1] += 1 self.degrees[seg_id_2] += 1 # Remove linked segments from the set of isolated segments self.isolated_segments.discard(seg_id_1) self.isolated_segments.discard(seg_id_2) # Create LinkInfo tuple. The `orient_key` fields seem to duplicate `orient_seg` fields. # If they have a distinct meaning, LinkInfo definition or parsing might need adjustment. return LinkInfo( seg_id_1, orient_seg_1_numeric, orient_seg_1_numeric, # Using numeric orientation for keys too seg_id_2, orient_seg_2_numeric, orient_seg_2_numeric ) def _parse_segment(self, fields: List[str], bam_file_out: AlignmentFile) -> None: """ Parses a GFA Segment line (S line) and writes it as an AlignedSegment to a BAM file. Updates segment-related statistics. Parameters ---------- fields : List[str] A list of tab-separated fields from an S line. Expected GFA1 format: S <seg_id> <sequence> [optional_tags...] Example optional tag: LN:i:<length> (though length is usually derived from sequence) bam_file_out : pysam.AlignmentFile Open BAM file object for writing. """ try: _type_line = fields[0] # "S" seg_id = fields[1] sequence = fields[2] # Sequence string # But here, it seems sequence is expected directly. 
optional_tags_list = fields[3:] # List of "TAG:TYPE:VALUE" strings except IndexError: self.logger.warning(f"Malformed S line (not enough fields): '{fields}'. Skipping.") return seg_length = len(sequence) # Update statistics self.dict_segments_size[seg_id] = seg_length self.segment_count += 1 self.total_segment_length += seg_length self.isolated_segments.add(seg_id) # Add initially, links will remove it if connected segment_to_write = self.reusable_segment # 1. Donnรฉes du segment segment_to_write.query_name = seg_id segment_to_write.query_sequence = sequence segment_to_write.query_qualities = qualitystring_to_array("*" * seg_length) # 2. CIGAR (qui dรฉpend de la longueur) # Store GFA optional tags in a custom BAM tag, e.g., ZG:Z:<concatenated_gfa_tags> # Original code uses "SU" tag. Let's stick to that. # The "SU" tag seems to store a comma-separated string of the GFA optional tags. if seg_length > 0: segment_to_write.cigartuples = [(0, seg_length)] else: segment_to_write.cigartuples = None # 3. Tags (doivent รชtre effacรฉs et recrรฉรฉs) segment_to_write.tags = [] # EFFACER les tags de l'itรฉration prรฉcรฉdente if optional_tags_list: gfa_tags_str = ",".join(optional_tags_list) segment_to_write.set_tag("SU", gfa_tags_str, value_type='Z') return segment_to_write # Write to BAM file # This is a synchronous call. If it's slow, it blocks the async event loop. # try: # bam_file_out.write(segment_to_write) # except Exception as e: # Catch errors from pysam write (e.g., header issues, malformed segment) # self.logger.error(f"Error writing segment '{seg_id}' to BAM: {e}", exc_info=True)
def tag_bam(self) -> None:
    """
    Add sample-walk (``SW``) tags to the segments BAM using ``GratoolsBam``.

    The per-segment walk memberships in ``self.dict_segments_samples``
    (``segment_id -> ["sample;chrom;haplotype", ...]``) are handed to
    ``GratoolsBam.tag``, which is expected to overwrite
    ``self.bam_segments_file`` in place. The following are logged rather than
    raised:

    * an unset or missing BAM file,
    * an empty mapping (tagging still proceeds),
    * an exception raised by ``tag``.
    """
    bam_path = self.bam_segments_file
    # Check before dereferencing `.name`: an unset path must be reported, not crash.
    if not bam_path or not bam_path.exists():
        self.logger.error("BAM segments file not found. Cannot perform tagging.")
        return
    self.logger.info(f"Initiating BAM tagging for '{bam_path.name}'.")

    if not self.dict_segments_samples:
        self.logger.warning("dict_segments_samples is empty. No SW tags will be added.")
        # Still proceed, as an untagged but indexed BAM might be expected.

    bam_manager = GratoolsBam(
        bam_path=bam_path,
        threads=self.threads,
        logger=self.logger,  # Pass parent logger
        tagging=False,  # NOTE(review): meaning is defined by GratoolsBam; confirm what False selects
        disable_progress_flag=self.disable_progress_flag,
    )
    try:
        updated_bam_path = bam_manager.tag(dict_segments_samples=self.dict_segments_samples,
                                           nb_segments=self.segment_count)
        if updated_bam_path != bam_path:
            self.logger.warning(f"BAM tagging returned a new path '{updated_bam_path}', "
                                f"but expected overwrite of '{bam_path.name}'. Check GratoolsBam.tag logic.")
    except Exception as e:
        self.logger.error(f"Error during BAM tagging process: {e}", exc_info=True)
def _get_connected_component(self, start_seg: str, cursor: sqlite3.Cursor) -> Set[str]:
    """
    Return the IDs of all segments in the same connected component as ``start_seg``.

    Links stored in the ``links`` table are followed in both directions
    (``seg_id_1 -> seg_id_2`` and ``seg_id_2 -> seg_id_1``) by a recursive CTE
    executed on the supplied cursor.

    Parameters
    ----------
    start_seg : str
        Segment ID the traversal starts from.
    cursor : sqlite3.Cursor
        Open cursor on the links database.

    Returns
    -------
    Set[str]
        Segment IDs of the component, always including ``start_seg`` (it is the
        CTE's base case). On an ``sqlite3.Error`` the error is logged and
        ``{start_seg}`` is returned.
    """
    # NOTE: a CTE with two recursive SELECTs requires SQLite >= 3.34.0.
    # Indexes on links.seg_id_1 and links.seg_id_2 keep each recursive step cheap.
    query = """
        WITH RECURSIVE connected_component(segment_id) AS (
            VALUES(?)  -- Base case: the starting segment
            UNION
            -- Segments linked from current component members
            SELECT links.seg_id_2 FROM links
            JOIN connected_component ON links.seg_id_1 = connected_component.segment_id
            UNION
            -- Same links followed in the other direction
            SELECT links.seg_id_1 FROM links
            JOIN connected_component ON links.seg_id_2 = connected_component.segment_id
        )
        SELECT segment_id FROM connected_component;
    """
    try:
        cursor.execute(query, (start_seg,))
        return {row[0] for row in cursor.fetchall()}
    except sqlite3.Error as e:
        self.logger.error(f"SQLite error getting connected component for '{start_seg}': {e}", exc_info=True)
        return {start_seg}  # Return at least the start segment on error


def _compute_segment_stats(self) -> Dict[str, Any]:
    """
    Compute segment-length statistics from ``self.dict_segments_size``.

    Returns an empty dict when no segment sizes are known. Otherwise the result
    holds the median segment length, the mean and median of the longest 5% of
    segments (at least one segment), and a histogram over fixed length bins.
    """
    self.logger.debug("Starting segment length statistics computation...")
    stats: Dict[str, Any] = {}
    if not self.dict_segments_size:
        return stats

    # The median needs every value, so the list is materialised once.
    node_lengths = list(self.dict_segments_size.values())
    stats["Median Segment Length (bp)"] = float(median(node_lengths))

    # heapq.nlargest finds the top k in O(N log k) instead of sorting all N lengths.
    num_top_5_percent = max(1, int(0.05 * len(node_lengths)))
    top_5_percent_nodes_lengths = heapq.nlargest(num_top_5_percent, node_lengths)
    if top_5_percent_nodes_lengths:
        stats["Avg Length of Top 5% Longest Segments (bp)"] = float(mean(top_5_percent_nodes_lengths))
        stats["Median Length of Top 5% Longest Segments (bp)"] = float(median(top_5_percent_nodes_lengths))

    length_bins = [0, 500, 1000, 2000, float("inf")]
    bin_labels = ["0-499bp", "500-999bp", "1000-1999bp", "2000+bp"]
    hist_counts, _ = histogram(node_lengths, bins=length_bins)
    stats["Segment Length Distribution"] = dict(zip(bin_labels, map(int, hist_counts)))

    self.logger.debug("Finished segment length statistics.")
    return stats


def _compute_similarity_stats(self) -> Dict[str, Any]:
    """
    Compute segment sharing (similarity) and occurrence (depth) statistics.

    For every entry of ``self.dict_segments_samples`` the depth is the number of
    occurrence strings, and the similarity is the number of distinct sample names
    (the text before the first ``;`` of each occurrence). Mean, median and
    population standard deviation of both are returned. An empty mapping
    yields an empty dict.
    """
    self.logger.debug("Starting segment similarity/depth statistics computation...")
    stats: Dict[str, Any] = {}
    if not self.dict_segments_samples:
        return stats

    similarities = []  # Number of unique samples per segment
    depths = []  # Total occurrences of a segment across all walks
    for sample_occurrences_list in self.dict_segments_samples.values():
        depths.append(len(sample_occurrences_list))
        # Rebuild the unique sample names directly from the occurrence strings.
        unique_samples = {occ.split(";", 1)[0] for occ in sample_occurrences_list}
        similarities.append(len(unique_samples))

    if similarities:
        similarities_np = np.array(similarities)
        stats["Avg Unique Samples per Segment (Similarity Mean)"] = float(np.mean(similarities_np))
        stats["Median Unique Samples per Segment (Similarity Median)"] = float(np.median(similarities_np))
        # Population std of a single value is 0.0; the guard makes that explicit.
        stats["StdDev Unique Samples per Segment (Similarity Std)"] = (
            float(np.std(similarities_np)) if len(similarities) > 1 else 0.0
        )

    if depths:
        depths_np = np.array(depths)
        stats["Avg Occurrences per Segment (Depth Mean)"] = float(np.mean(depths_np))
        stats["Median Occurrences per Segment (Depth Median)"] = float(np.median(depths_np))
        stats["StdDev Occurrences per Segment (Depth Std)"] = (
            float(np.std(depths_np)) if len(depths) > 1 else 0.0
        )

    self.logger.debug("Finished segment similarity/depth statistics.")
    return stats


def _compute_connectivity_stats(self) -> Dict[str, Any]:
    """
    Compute connected-component statistics by querying the links SQLite database.

    Every known segment (keys of ``self.dict_segments_size``) is assigned to a
    component with ``_get_connected_component``. Component sizes are measured in
    base pairs. "Disconnected" components are all components except one
    largest component. Ties for the largest size are therefore still counted
    as disconnected.

    Returns the zero-filled default statistics when link import is disabled,
    the database file is missing, or an ``sqlite3.Error`` occurs (the error is
    logged).
    """
    self.logger.debug("Starting connectivity statistics computation...")
    stats: Dict[str, Any] = {
        "Number of Connected Components (CCs)": 0,
        "Largest CC Size (bp)": 0,
        "Number of Disconnected CCs (excluding largest)": 0,
        "Total Length of Disconnected CCs (bp)": 0,
    }

    if not (self.import_links and self.db_links and self.db_links.db_file and self.db_links.db_file.exists()):
        self.logger.info("Link importing disabled or DB not found; skipping connectivity stats.")
        return stats

    conn_sqlite = None
    try:
        # Each worker thread needs its own connection; open it read-only.
        # Path.as_uri() percent-encodes characters such as '?', '#' and '%' that
        # would otherwise be parsed as URI syntax by SQLite.
        db_uri = f"{self.db_links.db_file.resolve().as_uri()}?mode=ro"
        conn_sqlite = sqlite3.connect(db_uri, uri=True, timeout=5.0)
        cursor = conn_sqlite.cursor()
        cursor.execute("PRAGMA query_only = ON;")

        all_components_nodes = []
        visited_segments = set()
        for seg_id in self.dict_segments_size:
            if seg_id not in visited_segments:
                component = self._get_connected_component(seg_id, cursor)
                if component:
                    all_components_nodes.append(component)
                    visited_segments.update(component)

        if not all_components_nodes:
            return stats

        stats["Number of Connected Components (CCs)"] = len(all_components_nodes)

        # Size of each component in base pairs.
        component_lengths_bp = [
            sum(self.dict_segments_size.get(s, 0) for s in comp) for comp in all_components_nodes
        ]
        largest_cc_size_bp = max(component_lengths_bp)
        stats["Largest CC Size (bp)"] = largest_cc_size_bp
        # Exclude exactly one largest component. Filtering on `length < largest`
        # would also drop other components that tie for the largest size.
        stats["Number of Disconnected CCs (excluding largest)"] = len(component_lengths_bp) - 1
        stats["Total Length of Disconnected CCs (bp)"] = sum(component_lengths_bp) - largest_cc_size_bp

    except sqlite3.Error as e:
        self.logger.error(f"SQLite error during connectivity stats: {e}", exc_info=True)
    finally:
        if conn_sqlite:
            conn_sqlite.close()

    self.logger.debug("Finished connectivity statistics.")
    return stats
def compute_statistics(self) -> Dict[str, Any]:
    """
    Compute GFA graph statistics, save them to ``self.stats_file`` and return them.

    The three heavier computations (``_compute_segment_stats``,
    ``_compute_similarity_stats``, ``_compute_connectivity_stats``) run in a
    ``ThreadPoolExecutor``. Cheap aggregates derived from counters collected
    during parsing are then added in the calling thread.

    A failure in one worker is logged and does not abort the others. That
    category then holds only what the aggregation step adds to it, if anything.

    Returns
    -------
    Dict[str, Any]
        Mapping of category name to a ``{metric name: value}`` dictionary.
        Values are returned unformatted; only the saved TSV is formatted.
    """
    self.logger.info("Computing GFA graph statistics in parallel...")
    final_stats: Dict[str, Any] = {}

    # Category name -> worker computing that category.
    tasks_to_run = {
        "Segment Statistics": self._compute_segment_stats,
        "Segment Sharing & Depth": self._compute_similarity_stats,
        "Graph Structure": self._compute_connectivity_stats,  # Usually the longest task
    }
    total_steps = len(tasks_to_run) + 1  # +1 for final aggregation and formatting
    stats_task_id = self.progress.add_task(
        "[bold blue]Computing graph statistics...",
        total=total_steps,
        visible=not self.disable_progress_flag,
    )

    try:
        with self.progress, ThreadPoolExecutor(max_workers=self.threads) as executor:
            future_to_category = {executor.submit(func): category for category, func in tasks_to_run.items()}
            for future in as_completed(future_to_category):
                category = future_to_category[future]
                try:
                    final_stats[category] = future.result()
                    self.logger.info(f"✓ Statistics for '{category}' computed successfully.")
                except Exception as e:
                    self.logger.error(f"Error computing stats for category '{category}': {e}", exc_info=True)
                finally:
                    # One progress step per finished worker, successful or not.
                    self.progress.update(stats_task_id, advance=1)

        # --- Fast aggregates computed in the calling thread ---
        self.logger.info("Aggregating final statistics...")

        # 1. Graph overview
        final_stats["Graph Overview"] = {
            "GFA File Name": self.gfa_name,
            "GFA Version": self.version,
            "Total Segments (S lines)": self.segment_count,
            "Total Links (L lines)": self.link_count,
            "Total Walks (W lines)": self.walks_count,
            "Unique Samples in Walks": len(self.processed_sample_names),
        }

        # 2. Basic segment and link statistics
        avg_segment_length = self.total_segment_length / self.segment_count if self.segment_count else 0.0
        segment_stats = final_stats.setdefault("Segment Statistics", {})
        segment_stats["Total Segment Length (bp)"] = self.total_segment_length
        segment_stats["Average Segment Length (bp)"] = avg_segment_length

        # Each link contributes to the degree of two segment ends.
        avg_degree = 2 * self.link_count / self.segment_count if self.segment_count else 0.0
        max_degree = max(self.degrees.values()) if self.degrees else 0
        final_stats["Link Statistics"] = {
            "Max Segment Degree": max_degree,
            "Average Segment Degree": avg_degree,
            "Self-Links (S1 -> S1)": self.self_links_count,
            "Inverted Links (S1+ -> S2-)": self.inverted_links_count,
            "Both Negative Links (S1- -> S2-)": self.negative_links_count,
        }

        # 3. Path (walk) statistics
        graph_size_bp = self.total_segment_length
        compress_ratio = self.input_genome_size / graph_size_bp if graph_size_bp else 0.0
        final_stats["Path (Walk) Statistics"] = {
            "Total Length of All Paths (bp, input_genome_size)": self.input_genome_size,
            "Graph Compression Ratio (input_genome_size / graph_size_bp)": compress_ratio,
            "Max Segments in a Single Walk": self.max_walk_rank,
            "Sum of First Segment Lengths in Walks": self.sum_rank0_length,
        }

        # 4. Graph structure, merged with the connectivity results (if any)
        graph_density = (
            2 * self.link_count / (self.segment_count * (self.segment_count - 1))
            if self.segment_count > 1
            else 0.0
        )
        dead_ends_count = sum(1 for degree in self.degrees.values() if degree == 1)
        final_stats.setdefault("Graph Structure", {}).update({
            "Graph Density": graph_density,
            "Segments/Links Ratio": self.segment_count / self.link_count if self.link_count else 0.0,
            "Dead-End Segments (degree 1)": dead_ends_count,
            "Isolated Segments (degree 0)": len(self.isolated_segments),
        })

        self.progress.update(stats_task_id, advance=1, description="[green]Formatting & Saving")

        # --- Flatten to (category, metric, value) rows, sorted for stable output ---
        flat_stats_for_df = [
            (category, metric_name, metrics_dict[metric_name])
            for category, metrics_dict in sorted(final_stats.items())
            if isinstance(metrics_dict, dict)
            for metric_name in sorted(metrics_dict)
        ]
        df_stats = DataFrame(flat_stats_for_df, columns=["Category", "Metric", "Value"])

        def format_value(x):
            """Render a metric value for the TSV output."""
            if isinstance(x, dict):
                return ", ".join(f"{k}: {v}" for k, v in x.items())
            if isinstance(x, float):
                return f"{x:.2e}" if (0 < abs(x) < 1e-3 or abs(x) > 1e6) else f"{x:.2f}"
            if isinstance(x, int):
                return f"{x:,}"
            return str(x)

        df_stats["Value"] = df_stats["Value"].apply(format_value)

        if self.stats_file:
            try:
                df_stats.to_csv(self.stats_file, sep="\t", index=False)
                self.logger.info(f"Saved GFA statistics to '{self.stats_file.name}'")
            except IOError as e:
                self.logger.error(f"Failed to save statistics CSV '{self.stats_file}': {e}")
    finally:
        # Remove the task even when aggregation or formatting raised, so it does not linger.
        self.progress.remove_task(stats_task_id)

    return final_stats
@dataclass
class GratoolsBam:
    """
    Handles operations related to BAM files in the GraTools context, such as
    indexing, extracting segment information, tagging segments, and performing
    various analyses (core/dispensable ratio, depth statistics, etc.).

    Attributes
    ----------
    bam_path : Path
        Path to the BAM file. A ``str`` is also accepted and converted to ``Path``.
    threads : int, optional
        Number of threads for BAM operations (reading, writing, indexing). Defaults to 1.
    logger : logging.Logger
        Logger instance. Defaults to the logger named "GraTools".
    suffix : Optional[str], optional
        Suffix appended to output filenames produced by analyses. ``None`` is
        normalised to ``""``.
    works_path : Optional[Path], optional
        Directory for output files. Defaults to the BAM file's parent directory.
    gfa_name : Optional[str], optional
        Name of the associated GFA (used for output filenames). Defaults to the
        BAM file name without a trailing ``.bam``.
    tagging : bool, optional
        When True, ``index_bam`` creates a missing index or refreshes an index
        older than the BAM. When False, ``index_bam`` does not index. Defaults to False.
    disable_progress_flag : bool, optional
        Used to hide progress-bar tasks. Defaults to False.
    progress : Optional[Progress]
        Rich progress tracker created in ``__post_init__`` (not an init argument).
    """
    bam_path: Path
    threads: int = 1
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    suffix: Optional[str] = None  # For output file naming
    works_path: Optional[Path] = None  # Directory for outputs
    gfa_name: Optional[str] = None  # For naming consistency
    tagging: bool = False  # Controls whether index_bam (re-)indexes
    disable_progress_flag: bool = False
    # Internal attribute for the progress bar
    progress: Optional[Progress] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        """
        Create the progress tracker, resolve default settings and index the BAM if needed.

        ``works_path``, ``gfa_name`` and ``suffix`` defaults depend only on
        ``bam_path``. They are resolved before the existence check so that the
        instance is fully initialised even when the BAM file is missing. A
        missing BAM is logged, and indexing and output-directory creation are
        then skipped.
        """
        self.progress = Progress(
            TextColumn("[bold blue]{task.description}"),
            BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%",
            TextColumn("{task.completed}/{task.total} segs"),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            refresh_per_second=1,  # Low refresh rate for potentially long BAM operations
            transient=True,
        )

        # Accept plain strings for paths; Path(Path) is a no-op.
        self.bam_path = Path(self.bam_path)

        if self.works_path is None:
            self.works_path = self.bam_path.parent
        else:
            self.works_path = Path(self.works_path)

        if self.gfa_name is None:
            # e.g. "mygfa.bam" -> "mygfa"
            self.gfa_name = self.bam_path.name.removesuffix(".bam")

        # An empty suffix keeps filename concatenation uniform.
        if self.suffix is None:
            self.suffix = ""

        if not self.bam_path.exists():
            self.logger.error(f"BAM file does not exist: {self.bam_path}")
            return

        self.index_bam()  # Indexes only when needed (see index_bam)
        self.works_path.mkdir(parents=True, exist_ok=True)
def index_bam(self) -> None:
    """
    (Re-)index the BAM file with ``pysam.index`` when tagging may have changed it.

    Indexing happens only when ``self.tagging`` is True and either no index
    exists yet (``<bam>.bai`` or ``<bam>.crai``) or the BAM file is newer than its
    index (mtime comparison). When ``self.tagging`` is False nothing is indexed.

    A missing BAM file is logged as a warning. Errors raised while indexing are
    logged with traceback and not propagated.
    """
    if not self.bam_path.exists():
        self.logger.warning(f"Cannot index BAM file: '{self.bam_path}' does not exist.")
        return

    # Index files keep the data file's suffix and append their own, e.g. file.bam.bai.
    bam_index_path_bai = self.bam_path.with_suffix(self.bam_path.suffix + ".bai")
    bam_index_path_crai = self.bam_path.with_suffix(self.bam_path.suffix + ".crai")
    actual_index_path = bam_index_path_bai if bam_index_path_bai.exists() else bam_index_path_crai
    index_exists = actual_index_path.exists()

    needs_indexing = False
    if self.tagging:
        if not index_exists:
            self.logger.info(
                f"Indexing BAM file '{self.bam_path.name}': No existing index found."
            )
            needs_indexing = True
        elif self.bam_path.stat().st_mtime > actual_index_path.stat().st_mtime:
            self.logger.info(
                f"Re-indexing BAM file '{self.bam_path.name}': BAM file is newer than its index."
            )
            needs_indexing = True

    if not needs_indexing:
        self.logger.debug(f"BAM file '{self.bam_path.name}' index is up-to-date or tagging is false.")
        return

    try:
        # NOTE: pysam's samtools wrappers take command-line arguments positionally
        # and do not forward arbitrary keyword arguments such as ``threads=``, so
        # the thread count is passed as samtools' ``-@`` option. ``-@`` counts
        # threads in addition to the main thread.
        extra_threads = max(0, self.threads - 1)
        index("-@", str(extra_threads), self.bam_path.as_posix())
        self.logger.info(f"Successfully indexed BAM file: '{self.bam_path.name}'.")
    except Exception as e:  # pysam can raise various errors (SamtoolsError, ValueError, ...)
        self.logger.error(f"Error indexing BAM file '{self.bam_path.name}': {e}", exc_info=True)
def _get_total_segments_in_bam(self) -> int:
    """
    Return how many alignment records (segments) the BAM file contains.

    Every record is counted by scanning the file to EOF, so no index is needed.

    Returns
    -------
    int
        Number of records, or 0 when the file is missing or cannot be read
        (read errors are logged).
    """
    bam_path = self.bam_path
    if bam_path.exists():
        try:
            with AlignmentFile(bam_path.as_posix(), "rb", check_sq=False, threads=self.threads) as handle:
                return handle.count(until_eof=True)
        except Exception as err:
            self.logger.error(f"Error counting segments in BAM '{self.bam_path.name}': {err}")
            return 0
    return 0
def build_segments(self, list_segments: Optional[List[str]] = None) -> Tuple[List[str], defaultdict, defaultdict]:
    """
    Extract the given segments from the BAM file and rebuild their GFA S-lines.

    The segment IDs are written to a temporary list file and extracted with
    ``pysam.view -hbN`` into a temporary BAM, which is then read back. Both
    temporary files are deleted afterwards.

    Parameters
    ----------
    list_segments : Optional[List[str]], optional
        Segment IDs (BAM ``query_name``) to extract. ``None`` or an empty list
        logs a warning and returns empty results.

    Returns
    -------
    Tuple[List[str], defaultdict[str, List[str]], defaultdict[str, str]]
        - gfa_s_lines_list: one GFA S-line per extracted segment. The original
          optional fields, stored comma-joined in the SU tag, are restored as
          tab-separated fields.
        - dict_seg_samples: segment ID -> entries of its SW tag (split on ",").
        - dict_seg_sequence: segment ID -> sequence ("" when the record stores none).

        On any error the error is logged and empty results are returned.
    """
    gfa_s_lines_list: List[str] = []
    dict_seg_samples: defaultdict = defaultdict(list)
    dict_seg_sequence: defaultdict = defaultdict(str)

    if not list_segments:
        self.logger.warning("Build_segments called with no segment IDs. Returning empty results.")
        return gfa_s_lines_list, dict_seg_samples, dict_seg_sequence

    # The SU tag holds the original GFA optional fields joined with ",". Commas
    # may also occur inside B-array or Z values, so split only where a new
    # "TAG:TYPE:" field begins.
    su_field_separator = re.compile(r",(?=[A-Za-z][A-Za-z0-9]:[AifZJHB]:)")

    temp_dir = self.works_path if self.works_path else self.bam_path.parent
    temp_seg_list_path: Optional[Path] = None
    temp_bam_path: Optional[Path] = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w+", dir=temp_dir, suffix=".txt", delete=False
        ) as list_seg_file, tempfile.NamedTemporaryFile(
            mode="wb", dir=temp_dir, suffix=".bam", delete=False
        ) as bam_temp_file:
            temp_seg_list_path = Path(list_seg_file.name)
            temp_bam_path = Path(bam_temp_file.name)

            list_seg_file.writelines(f"{seg_id}\n" for seg_id in list_segments)
            list_seg_file.flush()  # pysam view reads this file by name

            self.logger.debug(
                f"Extracting {len(list_segments)} segments using pysam view. List: {temp_seg_list_path}, Temp BAM: {temp_bam_path}")

            # -N: read names listed in the file; -h: keep header; -b: BAM output (captured as bytes).
            bam_view_output_bytes = view("-hbN", temp_seg_list_path.as_posix(), self.bam_path.as_posix())
            bam_temp_file.write(bam_view_output_bytes)
            bam_temp_file.flush()

            with AlignmentFile(
                temp_bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
            ) as bam_file_filtered:
                for segment in bam_file_filtered.fetch(until_eof=True):
                    seg_id = segment.query_name
                    sequence = segment.query_sequence  # None for records without a stored sequence

                    if segment.has_tag('SU'):
                        su_tag_val = "\t" + su_field_separator.sub("\t", segment.get_tag('SU'))
                    else:
                        su_tag_val = ""
                    # GFA uses "*" for an absent sequence; never emit the literal "None".
                    gfa_s_lines_list.append(f"S\t{seg_id}\t{sequence or '*'}{su_tag_val}")

                    if segment.has_tag('SW'):
                        # SW format: "sample1;chrom1;hap1,sample2;chrom2;hap2,..."; skip empty items.
                        sw_tag_val = segment.get_tag('SW')
                        dict_seg_samples[seg_id] = [s.strip() for s in sw_tag_val.split(',') if s.strip()]

                    dict_seg_sequence[seg_id] = sequence or ""

        return gfa_s_lines_list, dict_seg_samples, dict_seg_sequence

    except Exception as e:
        self.logger.error(f"Error in build_segments: {e}", exc_info=True)
        return [], defaultdict(list), defaultdict(str)  # Return empty on error
    finally:
        for temp_path in (temp_seg_list_path, temp_bam_path):
            if temp_path is not None and temp_path.exists():
                temp_path.unlink()
def tag(self, dict_segments_samples: Dict[str, List[str]], nb_segments: int) -> Path:
    """
    Add or replace the 'SW' (Sample Walks) tag on every segment of the BAM file.

    For each record, the walk strings listed for its ``query_name`` are joined
    with ``","`` and stored as a ``Z``-type SW tag. Segments absent from the
    mapping get an empty SW tag.

    Tagged records are written to a temporary BAM, which then replaces the
    original file. The BAM is then re-indexed.

    Parameters
    ----------
    dict_segments_samples : Dict[str, List[str]]
        Mapping of segment ID (BAM ``query_name``) to walk strings, presumably
        "sample;chromosome;haplotype". The mapping is only read, never modified.
    nb_segments : int
        Expected number of segments; used as the progress-bar total.

    Returns
    -------
    Path
        ``self.bam_path``, which now holds the tagged BAM.

    Raises
    ------
    FileNotFoundError
        If the input BAM file does not exist.
    Exception
        Any error while reading, writing or replacing the BAM is logged, the
        temporary file is removed, and the error is re-raised.
    """
    if not self.bam_path.exists():
        self.logger.error(f"Input BAM file for tagging not found: {self.bam_path}")
        raise FileNotFoundError(f"Input BAM file for tagging not found: {self.bam_path}")

    temp_tagged_bam_path = self.bam_path.with_name(f"{self.bam_path.stem}_tagged_temp.bam")
    self.logger.info(f"Tagging {nb_segments} segments in BAM file: '{self.bam_path.name}'.")

    # Split the thread budget between reader and writer, giving each at least one.
    io_threads = max(1, self.threads // 2)
    try:
        with AlignmentFile(
            self.bam_path.as_posix(), "rb", check_sq=False, threads=io_threads
        ) as bam_file_in, AlignmentFile(
            temp_tagged_bam_path.as_posix(),
            "wb",
            template=bam_file_in,  # Reuse the input header
            threads=io_threads,
        ) as bam_file_out:
            tag_task_id = self.progress.add_task(
                f"Tagging BAM '{self.bam_path.name}'",
                total=nb_segments,
                visible=not self.disable_progress_flag,
            )
            with self.progress:
                for segment in bam_file_in.fetch(until_eof=True):
                    self.progress.advance(tag_task_id)
                    # .get(): no KeyError for segments outside every walk, and no
                    # empty entries inserted when the mapping is a defaultdict.
                    walks = dict_segments_samples.get(segment.query_name, ())
                    segment.set_tag("SW", ",".join(walks), value_type='Z', replace=True)
                    # Written straight away: write() serialises the record
                    # immediately, so no copy or manual buffer is needed.
                    bam_file_out.write(segment)

        # Path.replace overwrites the destination (atomic on the same filesystem).
        temp_tagged_bam_path.replace(self.bam_path)
        self.logger.info(f"BAM file '{self.bam_path.name}' successfully tagged.")

        # The BAM changed: enable tagging so index_bam rebuilds the index.
        self.tagging = True
        self.index_bam()

    except Exception as e:
        self.logger.error(f"Error during BAM tagging for '{self.bam_path.name}': {e}", exc_info=True)
        temp_tagged_bam_path.unlink(missing_ok=True)
        raise

    return self.bam_path
def pan_ratio(
    self,
    nb_samples_gfa: int,
    input_as_number: bool,
    shared_min_cutoff: int = 1,
    specific_max_cutoff: Optional[int] = None,  # can be None
    filter_min_len: int = 1,
) -> None:
    """
    Classify BAM segments as shared (core) or specific (dispensable) and save the ratios.

    A segment's sample count is the number of distinct sample names (text before
    the first ";") in its SW tag. Segments without an SW tag are skipped and
    reported separately. Results are written to a CSV in ``self.works_path``
    and summarised on the console.

    Parameters
    ----------
    nb_samples_gfa : int
        Total number of unique samples in the GFA. Used for percentage cutoffs
        and to validate the cutoffs.
    input_as_number : bool
        True: cutoffs are absolute sample counts. False: cutoffs are percentages
        of ``nb_samples_gfa``, rounded to sample counts.
    shared_min_cutoff : int, optional
        A segment found in at least this many samples is "shared". Defaults to 1.
    specific_max_cutoff : Optional[int], optional
        A segment found in at most this many samples is "specific". ``None``
        means 1 sample. Defaults to None.
    filter_min_len : int, optional
        Minimum segment length (bp) for the length-filtered counts. Defaults to 1.

    Notes
    -----
    Returns early (logging an error) if a resolved cutoff exceeds
    ``nb_samples_gfa``. Returns early (logging a warning) if the BAM is empty
    or unreadable.
    """
    # --- Resolve cutoffs to absolute sample counts ---
    if not input_as_number:  # Percentages
        shared_min_abs = round((shared_min_cutoff * nb_samples_gfa) / 100)
        specific_max_abs = (
            round((specific_max_cutoff * nb_samples_gfa) / 100) if specific_max_cutoff is not None else 1
        )
        self.logger.info(
            f"Cutoffs (percentage input): Shared_min={shared_min_cutoff}% ({shared_min_abs} samples), "
            f"Specific_max={specific_max_cutoff}% ({specific_max_abs} samples if specified)."
        )
    else:  # Absolute numbers
        shared_min_abs = shared_min_cutoff
        specific_max_abs = specific_max_cutoff if specific_max_cutoff is not None else 1
        self.logger.info(
            f"Cutoffs (absolute number input): Shared_min={shared_min_abs} samples, "
            f"Specific_max={specific_max_abs} samples if specified."
        )

    # A segment whose sample count equals both cutoffs satisfies both tests, so
    # the categories already overlap at equality.
    if shared_min_abs <= specific_max_abs:
        self.logger.warning(  # Warning, not error: may be an intentional definition
            f"shared_min_cutoff ({shared_min_abs}) is less than or equal to specific_max_cutoff ({specific_max_abs}). "
            "This means some segments could be classified as both shared and specific. "
            "Ensure this is the intended definition."
        )

    if shared_min_abs > nb_samples_gfa or specific_max_abs > nb_samples_gfa:
        self.logger.error(
            f"Cutoff values exceed total GFA samples ({nb_samples_gfa}). "
            f"Shared_min={shared_min_abs}, Specific_max={specific_max_abs}. Adjust parameters."
        )
        return

    self.logger.info(
        f"Analyzing core/dispensable ratio for {nb_samples_gfa} GFA samples. "
        f"Effective cutoffs: Shared_min >= {shared_min_abs}, Specific_max <= {specific_max_abs}, Min_Length_Filter = {filter_min_len}bp."
    )

    # Keys: shared_raw, specific_raw, shared_len_filtered, specific_len_filtered,
    # length_filtered_out, missing_sw_tag
    category_counts = Counter()

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file is empty or unreadable. Cannot perform core/dispensable analysis.")
        return

    core_disp_task_id = self.progress.add_task(
        "Core/Dispensable Ratio Analysis",
        total=total_segments_in_bam,
        visible=not self.disable_progress_flag,
    )

    # --- Count segments per category ---
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(core_disp_task_id)
            seg_length = segment.query_length

            if not segment.has_tag("SW"):
                self.logger.warning(
                    f"Segment '{segment.query_name}' missing SW tag. Cannot determine sample count. Skipping.")
                category_counts["missing_sw_tag"] += 1
                continue

            # SW entries are "sample_id;chrom;haplotype"; count distinct sample IDs.
            sw_tag_value = segment.get_tag("SW")
            num_samples_for_this_segment = len({s.split(";")[0] for s in sw_tag_value.split(',') if s})

            is_shared = num_samples_for_this_segment >= shared_min_abs
            is_specific = num_samples_for_this_segment <= specific_max_abs

            if is_shared:
                category_counts["shared_raw"] += 1
            if is_specific:  # May also be shared when the cutoffs overlap
                category_counts["specific_raw"] += 1

            if seg_length >= filter_min_len:
                if is_shared:
                    category_counts["shared_len_filtered"] += 1
                if is_specific:
                    category_counts["specific_len_filtered"] += 1
            else:
                category_counts["length_filtered_out"] += 1

    # --- Build the result table ---
    total_segments_processed = total_segments_in_bam - category_counts["missing_sw_tag"]
    segments_passing_length_filter = total_segments_processed - category_counts["length_filtered_out"]

    results_data = []
    if total_segments_processed > 0:
        results_data.extend([
            ("Shared (Core) - Raw", category_counts["shared_raw"], total_segments_processed,
             (category_counts["shared_raw"] / total_segments_processed) * 100),
            ("Specific (Dispensable) - Raw", category_counts["specific_raw"], total_segments_processed,
             (category_counts["specific_raw"] / total_segments_processed) * 100),
        ])
    if segments_passing_length_filter > 0:
        results_data.extend([
            (f"Shared (Core) - Filtered (Length >= {filter_min_len}bp)", category_counts["shared_len_filtered"],
             segments_passing_length_filter,
             (category_counts["shared_len_filtered"] / segments_passing_length_filter) * 100),
            (f"Specific (Dispensable) - Filtered (Length >= {filter_min_len}bp)",
             category_counts["specific_len_filtered"], segments_passing_length_filter,
             (category_counts["specific_len_filtered"] / segments_passing_length_filter) * 100),
        ])
    results_data.append(
        ("Segments Filtered Out by Length", category_counts["length_filtered_out"], total_segments_processed,
         (category_counts["length_filtered_out"] / total_segments_processed) * 100
         if total_segments_processed > 0 else 0))
    if category_counts["missing_sw_tag"] > 0:
        results_data.append(
            ("Segments Missing SW Tag (Skipped)", category_counts["missing_sw_tag"], total_segments_in_bam,
             (category_counts["missing_sw_tag"] / total_segments_in_bam) * 100))

    df_results = DataFrame(results_data, columns=["Category", "Count", "Total Relevant Segments", "Percentage"])

    # --- Write the CSV and print the summary ---
    spec_max_str = str(specific_max_cutoff) if specific_max_cutoff is not None else "N_A"
    pct = 'pct' if not input_as_number else ''
    csv_filename = (
        f"{self.gfa_name}_core_disp_ratio"
        f"_shared{shared_min_cutoff}{pct}"
        f"_spec{spec_max_str}{pct}"
        f"_len{filter_min_len}{self.suffix}.csv"
    )
    output_csv_path = self.works_path / csv_filename

    # Guard the division: every segment may have been skipped for a missing SW tag.
    passing_pct = (
        (segments_passing_length_filter / total_segments_processed) * 100
        if total_segments_processed > 0 else 0.0
    )

    try:
        df_results.to_csv(output_csv_path, index=False, float_format="%.2f")
        self.logger.info(f"Core/Dispensable ratio analysis saved to: {output_csv_path}")

        shared_console.rule("[bold green]Summary[/bold green]")
        shared_console.print(
            f"[cyan]Total segments in GFA:[/cyan] {total_segments_in_bam:,}"
        )
        shared_console.print(
            f"[cyan]Total segments analyzed:[/cyan] {total_segments_processed:,}"
        )
        shared_console.print(
            f"[cyan]Total segments passing length filter (≥ {filter_min_len}bp):[/cyan] {segments_passing_length_filter:,} "
            f"({passing_pct:.2f}%)"
        )

        table = Table(
            title=f"[bold magenta]Core vs. Dispensable Segments[/bold magenta] — [cyan]{self.gfa_name}[/cyan]",
            box=box.ROUNDED,
            show_lines=False,
        )
        table.add_column("Category", style="bold cyan", justify="left")
        table.add_column("Count", style="green", justify="right")
        table.add_column("Total", style="yellow", justify="right")
        table.add_column("Percentage", style="bold green", justify="right")
        for _, row in df_results.iterrows():
            table.add_row(
                str(row["Category"]),
                f"{int(row['Count']):,}",
                f"{int(row['Total Relevant Segments']):,}",
                f"{row['Percentage']:.2f}%",
            )
        shared_console.print(table)

    except IOError as e:
        self.logger.error(f"Failed to save Core/Dispensable ratio CSV to '{output_csv_path}': {e}")
def depth_nodes_stat(self, nb_samples_gfa: int, filter_min_len: int = 1) -> None:
    """
    Compute, save and display the distribution of segment depths.

    The depth of a segment is the number of distinct sample names found in its
    ``SW`` tag (the text before the first ``;`` of each comma-separated entry).
    Segments without an ``SW`` tag are skipped and reported in a warning.

    Two distributions are produced: one over every tagged segment ("Raw") and
    one restricted to segments whose length is ``>= filter_min_len``
    ("Filtered"). They are written to
    ``<works_path>/<gfa_name>_node_depth_stats_len<filter_min_len><suffix>.csv``
    and printed as a Rich table on ``shared_console``.

    Parameters
    ----------
    nb_samples_gfa : int
        Total number of unique samples in the GFA. Only reported in the log
        message; it does not enter the calculations.
    filter_min_len : int, optional
        Minimum length (bp) for a segment to enter the filtered distribution.
        Defaults to 1 (no effective length filter).

    Returns
    -------
    None
        Results are only written to disk and printed.
    """
    self.logger.info(
        f"Calculating node depth statistics for GFA with {nb_samples_gfa} total samples. "
        f"Length filter for 'Filtered_Counts': >= {filter_min_len}bp."
    )

    # segment depth (number of distinct samples) -> number of segments with that depth
    depth_counts_raw: Dict[int, int] = defaultdict(int)
    # same mapping, restricted to segments passing the length filter
    depth_counts_filtered: Dict[int, int] = defaultdict(int)

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file is empty or unreadable. Cannot perform node depth analysis.")
        return

    depth_stat_task_id = self.progress.add_task(
        "Node Depth Statistics Analysis", total=total_segments_in_bam
    )
    segments_missing_sw_tag = 0
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(depth_stat_task_id)
            if not segment.has_tag("SW"):
                segments_missing_sw_tag += 1
                continue
            sw_tag_value = segment.get_tag("SW")
            unique_samples = {s.split(";")[0] for s in sw_tag_value.split(",") if s}
            depth = len(unique_samples)
            depth_counts_raw[depth] += 1
            if segment.query_length >= filter_min_len:
                depth_counts_filtered[depth] += 1

    if segments_missing_sw_tag > 0:
        self.logger.warning(f"{segments_missing_sw_tag} segments were skipped due to missing SW tag.")

    filtered_count_col = f"Filtered_Count_Len>={filter_min_len}"
    filtered_pct_col = f"Filtered_Percentage_Len>={filter_min_len}"

    # Totals are kept as plain ints. Merging the two frames with an outer join and
    # fillna(0) upcast the filtered counts to float, which printed e.g. "1,234.0".
    total_raw_segments_counted = sum(depth_counts_raw.values())
    total_filtered_segments_counted = sum(depth_counts_filtered.values())

    def _percentage(count: int, total: int) -> float:
        """Return count as a percentage of total, or 0.0 when total is 0."""
        return (count / total) * 100 if total > 0 else 0.0

    # Every filtered increment is paired with a raw increment for the same depth,
    # so the raw depths cover every row of the table.
    records = []
    for depth in sorted(depth_counts_raw):
        raw_count = depth_counts_raw[depth]
        filtered_count = depth_counts_filtered.get(depth, 0)
        records.append((
            depth,
            raw_count,
            filtered_count,
            _percentage(raw_count, total_raw_segments_counted),
            _percentage(filtered_count, total_filtered_segments_counted),
        ))

    # Column order matches the historical CSV layout.
    df_stats_sorted = DataFrame(
        records,
        columns=["Segment_Depth", "Raw_Count", filtered_count_col, "Raw_Percentage", filtered_pct_col],
    )

    csv_filename = f"{self.gfa_name}_node_depth_stats_len{filter_min_len}{self.suffix}.csv"
    output_csv_path = self.works_path / csv_filename
    try:
        df_stats_sorted.to_csv(output_csv_path, index=False, float_format="%.2f")
        self.logger.info(f"Node depth statistics saved to: {output_csv_path}")
    except IOError as e:
        self.logger.error(f"Failed to save node depth statistics CSV to '{output_csv_path}': {e}")

    try:
        table = Table(
            title=f"[bold magenta]Node Depth Statistics[/bold magenta] — [cyan]{self.gfa_name}[/cyan] (Len ≥ {filter_min_len}bp)",
            box=box.ROUNDED,
        )
        table.add_column("Depth", style="bold cyan", justify="center")
        table.add_column("Segments", style="green", justify="center")
        table.add_column("Percentage", style="green", justify="center")
        table.add_column("Filtered Segments", style="yellow", justify="center")
        table.add_column("Filtered %", style="yellow", justify="center")

        # Share of analyzed segments that also pass the length filter.
        filtered_pct = _percentage(total_filtered_segments_counted, total_raw_segments_counted)

        shared_console.rule("\n[green]Summary[/green]")
        shared_console.print(
            f"[cyan]Total segments analyzed:[/cyan] {total_raw_segments_counted:,}"
        )
        shared_console.print(
            f"[cyan]Total segments passing length filter:[/cyan] {total_filtered_segments_counted:,} "
            f"({filtered_pct:.2f}%)\n"
        )

        for depth, raw_count, filtered_count, raw_pct, filt_pct in records:
            table.add_row(
                f"{depth}",
                f"{raw_count:,}",
                f"{raw_pct:.2f}%",
                f"{filtered_count:,}",
                f"{filt_pct:.2f}%",
            )
        shared_console.print(table)
    except Exception as e:
        # Display is best-effort: the CSV has already been written above.
        self.logger.error(f"Failed to display node depth statistics: {e}", exc_info=True)
def get_specific_and_shared_segments(
    self,
    samples_list_A: List[str],
    samples_list_B: Optional[List[str]] = None,
    filter_min_len: Optional[int] = None,
    output_csv: Optional[bool] = None,
) -> Tuple[Set[str], Set[str]]:
    """
    Count segments shared by a group of samples and, optionally, specific to it.

    The samples of a segment are the distinct names before the first ``;`` of
    each comma-separated entry of its ``SW`` tag. Segments without an ``SW``
    tag are skipped and reported in a warning. A summary panel is printed on
    ``shared_console``.

    Parameters
    ----------
    samples_list_A : List[str]
        Sample names. A segment is "shared" if it is present in ALL of them.
    samples_list_B : Optional[List[str]], optional
        Second list of sample names. When given, a segment is "specific" if it
        is shared by all of ``samples_list_A`` AND absent from ALL samples of
        this list. If the two lists overlap, no segment can be specific and a
        warning is logged.
    filter_min_len : Optional[int], optional
        If set, only segments with a length ``>=`` this value are considered.
    output_csv : Optional[bool], optional
        If truthy, the IDs of shared and specific segments are collected and
        returned; otherwise the returned sets are empty.

    Returns
    -------
    Tuple[Set[str], Set[str]]
        - IDs of segments shared by all samples of ``samples_list_A``.
        - IDs of segments specific to ``samples_list_A`` relative to
          ``samples_list_B`` (a subset of the first set).

    Notes
    -----
    Percentages are relative to the segments actually analyzed. These are the
    segments that carry an ``SW`` tag and pass the length filter.
    """
    if not samples_list_A:
        self.logger.error("samples_list_A cannot be empty for specific/shared segment analysis.")
        return set(), set()

    set_A = set(samples_list_A)
    set_B = set(samples_list_B) if samples_list_B else set()
    segment_list_shared: Set[str] = set()
    segment_list_specific: Set[str] = set()

    self.logger.info(f"Analyzing segments: Shared by ALL in {samples_list_A}.")
    if samples_list_B:
        self.logger.info(
            f"Also finding segments specific to list A (and absent from ALL in {samples_list_B}).")
        common_samples = set_A & set_B
        if common_samples:
            # A segment shared by every sample of A contains these samples,
            # so it can never be absent from every sample of B.
            self.logger.warning(
                f"List A and list B overlap on {sorted(common_samples)}: "
                "no segment can be specific to list A."
            )
    if filter_min_len is not None:
        self.logger.info(f"Applying length filter: >= {filter_min_len}bp.")

    # Keys: shared_A_count, shared_A_length, specific_to_A_count, specific_to_A_length
    counts = Counter()

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.error("BAM file empty/unreadable. Cannot perform specific/shared analysis.")
        return set(), set()

    spec_shared_task_id = self.progress.add_task(
        "Specific/Shared Segment Analysis", total=total_segments_in_bam
    )
    segments_missing_sw_tag = 0
    # Segments with an SW tag that pass the length filter: the percentage denominator.
    total_segments_analyzed = 0
    with AlignmentFile(
        self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads
    ) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(spec_shared_task_id)
            if not segment.has_tag("SW"):
                segments_missing_sw_tag += 1
                continue

            seg_length = segment.query_length
            if filter_min_len is not None and seg_length < filter_min_len:
                continue
            total_segments_analyzed += 1

            samples_in_segment_gfa = {s.split(";")[0] for s in segment.get_tag("SW").split(",") if s}
            if not set_A.issubset(samples_in_segment_gfa):
                continue

            counts["shared_A_count"] += 1
            counts["shared_A_length"] += seg_length
            if output_csv:
                segment_list_shared.add(segment.query_name)

            # Specific: shared by all of A and no sample of B present.
            if set_B and set_B.isdisjoint(samples_in_segment_gfa):
                counts["specific_to_A_count"] += 1
                counts["specific_to_A_length"] += seg_length
                if output_csv:
                    segment_list_specific.add(segment.query_name)

    if segments_missing_sw_tag > 0:
        self.logger.warning(f"{segments_missing_sw_tag} segments were skipped due to missing SW tag.")

    def _percent(count: int) -> float:
        """Percentage of analyzed segments; 0.0 when nothing was analyzed."""
        return (count / total_segments_analyzed) * 100 if total_segments_analyzed else 0.0

    # Borderless, headerless table used as an aligned key/value list.
    results_table = Table(box=None, show_header=False, expand=True)
    results_table.add_column("Metric", style="cyan", no_wrap=True)
    results_table.add_column("Value", justify="right")

    results_table.add_row(f"[bold]Segments shared by {len(samples_list_A)} in list {samples_list_A}[/bold]", "")
    results_table.add_row(" ├─ Count", f"{counts['shared_A_count']:,} / {total_segments_analyzed:,}")
    results_table.add_row(" ├─ Percentage", f"[bold green]{_percent(counts['shared_A_count']):.2f}%[/bold green]")
    results_table.add_row(" └─ Total Length", f"{counts['shared_A_length']:,} bp")

    if samples_list_B:
        # Empty row for visual spacing.
        results_table.add_row()
        results_table.add_row(f"[bold]Segments specific to {len(samples_list_A)} in list {samples_list_A}[/bold] and absent in {len(samples_list_B)} in list {samples_list_B}", "")
        results_table.add_row(" ├─ Count", f"{counts['specific_to_A_count']:,} / {total_segments_analyzed:,}")
        results_table.add_row(" ├─ Percentage", f"[bold yellow]{_percent(counts['specific_to_A_count']):.2f}%[/bold yellow]")
        results_table.add_row(" └─ Total Length", f"{counts['specific_to_A_length']:,} bp")

    summary_panel = Panel(
        results_table,
        title="[bold blue]📊 Shared & Specific Segment Analysis[/bold blue]",
        border_style="blue",
        expand=False,  # fit the panel to its content
    )
    shared_console.print(summary_panel)

    return segment_list_shared, segment_list_specific
def get_segments_and_positions_by_depth(
    self,
    total_gfa_samples: int,
    input_as_number: bool,
    lower_bound_depth: int,
    upper_bound_depth: int,
    filter_min_len: int,
    bed_path: Optional[Path] = None,
) -> Tuple[Dict[str, int], Dict[str, Dict[str, List[Tuple[str, int, int]]]]]:
    """
    Find segments within a depth range and retrieve their positions from BED files.

    Two steps:
    1. Scan the BAM file for segments with an ``SW`` tag, a length
       ``>= filter_min_len`` and a depth (number of distinct samples in the
       ``SW`` tag) inside the requested range.
    2. Query the BED files of the samples seen in those segments for the
       segments' genomic coordinates (via ``_find_positions_for_segments``).

    Parameters
    ----------
    total_gfa_samples : int
        Total number of unique samples in the GFA, used to convert percentages.
    input_as_number : bool
        If True, depth bounds are absolute sample counts. If False, they are
        percentages of ``total_gfa_samples``. Percentage bounds are converted by
        rounding the lower bound up and the upper bound down, so every selected
        depth ``d`` satisfies ``lower% <= 100 * d / total <= upper%``.
    lower_bound_depth : int
        Minimum depth (count or percentage) for a segment to be included.
    upper_bound_depth : int
        Maximum depth (count or percentage) for a segment to be included.
    filter_min_len : int
        Minimum segment length in base pairs.
    bed_path : Optional[Path], optional
        Directory containing the ``<sample>.bed`` files. When given, it is
        stored on ``self.bed_path`` before the search.

    Returns
    -------
    Tuple[Dict[str, int], Dict[str, Dict[str, List[Tuple[str, int, int]]]]]
        1. ``{segment_id: depth}`` for every segment matching the criteria.
        2. ``{segment_id: {sample_name: [(chrom, start, end), ...]}}``.
        Both are empty when the range admits no depth, the BAM is empty, or no
        segment matches.
    """
    import math  # local import: ceil/floor are only needed for the percentage conversion

    # --- Part 1: select segments from the BAM file ---
    if bed_path:
        self.bed_path = bed_path

    if not input_as_number:
        # Round the lower bound up and the upper bound down so that only whole
        # depths inside [lower%, upper%] are kept. Plain round() would use
        # banker's rounding and admit depths outside the range (e.g. 50% of 5
        # samples -> 2, i.e. 40%). round(..., 9) only strips float noise such as
        # 143.00000000000003 before ceil/floor.
        abs_lower_bound = math.ceil(round(lower_bound_depth * total_gfa_samples / 100, 9))
        abs_upper_bound = math.floor(round(upper_bound_depth * total_gfa_samples / 100, 9))
        self.logger.info(
            f"Depth range (percentage): {lower_bound_depth}%-{upper_bound_depth}% -> "
            f"Effective count range: [{abs_lower_bound}, {abs_upper_bound}]."
        )
    else:
        abs_lower_bound = lower_bound_depth
        abs_upper_bound = upper_bound_depth
        self.logger.info(f"Depth range (absolute): [{abs_lower_bound}, {abs_upper_bound}] samples.")

    if abs_lower_bound > abs_upper_bound:
        self.logger.error(f"Lower bound ({abs_lower_bound}) > upper bound ({abs_upper_bound}). Aborting.")
        return {}, {}

    self.logger.info(f"Applying minimum length filter: >= {filter_min_len} bp.")

    total_segments_in_bam = self._get_total_segments_in_bam()
    if total_segments_in_bam == 0:
        self.logger.warning("BAM file is empty or unreadable. Cannot select segments by depth.")
        return {}, {}

    segments_with_depth: Dict[str, int] = {}
    # Samples seen in the selected segments: only their BED files need scanning.
    samples_to_scan: Set[str] = set()

    bam_task_id = self.progress.add_task(
        f"Scanning BAM for segments in depth range [{abs_lower_bound}-{abs_upper_bound}]",
        total=total_segments_in_bam
    )
    with AlignmentFile(self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads) as bam_file, self.progress:
        for segment in bam_file.fetch(until_eof=True):
            self.progress.advance(bam_task_id)
            if segment.query_length < filter_min_len or not segment.has_tag("SW"):
                continue
            unique_samples = {s.split(";")[0] for s in segment.get_tag("SW").split(",") if s}
            depth = len(unique_samples)
            if abs_lower_bound <= depth <= abs_upper_bound:
                segments_with_depth[segment.query_name] = depth
                samples_to_scan.update(unique_samples)

    self.logger.info(f"Found {len(segments_with_depth)} segments meeting criteria. Now finding their positions.")

    # --- Part 2: find positions for the selected segments ---
    if not segments_with_depth:
        return {}, {}

    segment_locations = self._find_positions_for_segments(
        segment_ids=set(segments_with_depth.keys()),
        sample_list=list(samples_to_scan)
    )
    return segments_with_depth, segment_locations
def _find_positions_for_segments(
    self, segment_ids: Set[str], sample_list: List[str]
) -> Dict[str, Dict[str, List[Tuple[str, int, int]]]]:
    """
    Find every BED position of the given segments across a list of samples.

    For each sample, ``<self.bed_path>/<sample>.bed`` is filtered by an ``awk``
    subprocess. It keeps the lines whose 4th column, after removing one leading
    ``+``/``-``, is one of ``segment_ids``. Samples are processed in a thread
    pool of ``self.threads`` workers. Samples whose BED file does not exist are
    skipped with a warning, and awk failures are logged, not raised.

    Parameters
    ----------
    segment_ids : Set[str]
        Segment identifiers to look up.
    sample_list : List[str]
        Sample names whose BED files should be scanned.

    Returns
    -------
    Dict[str, Dict[str, List[Tuple[str, int, int]]]]
        ``{segment_id: {sample_name: [(chrom, start, end), ...]}}``, built from
        nested ``defaultdict`` objects. Empty when there is nothing to look up.
    """
    locations = defaultdict(lambda: defaultdict(list))
    if not segment_ids or not sample_list:
        # Nothing to look up: skip the temp file, subprocesses and progress task.
        return locations

    # First file (the IDs) fills `ids`; for the BED file, print lines whose name
    # column, minus one leading orientation sign, is a wanted ID.
    awk_script = 'NR==FNR{ids[$1]; next} {id=$4; sub(/^[+-]/,"",id); if(id in ids) print $0}'

    # The ID file must outlive every awk call, so all work happens inside this block.
    with tempfile.NamedTemporaryFile(mode='w+', delete=True, suffix="_segment_ids.txt") as tmp_file:
        tmp_file.write('\n'.join(segment_ids))
        tmp_file.flush()
        segment_id_filepath = tmp_file.name

        def _process_sample_bed(sample_name: str):
            """Run awk on one sample's BED file; return {seg_id: {sample_name: [positions]}}."""
            sample_bed_path = self.bed_path / f"{sample_name}.bed"
            if not sample_bed_path.exists():
                self.logger.warning(
                    f"BED file not found for sample '{sample_name}': {sample_bed_path}",
                    extra={"markup": False},
                )
                return {}

            cmd = ['awk', awk_script, segment_id_filepath, str(sample_bed_path)]
            sample_results = defaultdict(lambda: defaultdict(list))
            try:
                with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as process:
                    for line in process.stdout:
                        parts = line.strip().split('\t')
                        if len(parts) < 4:
                            continue
                        name = parts[3]
                        # Remove exactly one leading sign, as awk's sub(/^[+-]/, "") does,
                        # so the key is the very ID awk matched.
                        seg_id = name[1:] if name.startswith(("+", "-")) else name
                        sample_results[seg_id][sample_name].append(
                            (parts[0], int(parts[1]), int(parts[2]))
                        )
                # Popen.__exit__ waits for the process, so returncode is set here.
                if process.returncode != 0:
                    self.logger.error(
                        f"Awk failed for '{sample_name}' (code {process.returncode}).",
                        extra={"markup": False},
                    )
            except Exception as e:
                self.logger.error(f"Awk execution failed for '{sample_name}': {e}", extra={"markup": False})
            return sample_results

        with ThreadPoolExecutor(max_workers=self.threads) as executor, self.progress:
            bed_task_id = self.progress.add_task("Querying BED files for positions", total=len(sample_list))
            futures = [executor.submit(_process_sample_bed, sample) for sample in sample_list]
            for future in as_completed(futures):
                for seg_id, sample_map in future.result().items():
                    for sample_name, pos_list in sample_map.items():
                        locations[seg_id][sample_name].extend(pos_list)
                self.progress.advance(bed_task_id)

    return locations
def export_nodes_to_csv(
    self,
    output_csv_path: Path,
    core_threshold_percent: float = 0.95,
) -> None:
    """
    Export one row per unique node (length, depth, category, samples) to a CSV file.

    The BAM file is read once. For each distinct ``query_name``, only the first
    record is parsed. This assumes every record of a node carries the same
    ``SW`` tag (TODO confirm). The samples of a node are the distinct names
    before the first ``;`` of each non-empty comma-separated ``SW`` entry.

    Output columns: ``Node_ID``, ``Length_bp``, ``Depth_Sample_Count``,
    ``Category``, ``Sample_List`` (samples sorted and joined with ``;``).

    Categories:

    - ``private``: the node is in exactly one sample (takes precedence).
    - ``core``: depth ``>= ceil(total_samples * core_threshold_percent)``,
      where ``total_samples`` is the number of distinct samples seen in all
      ``SW`` tags.
    - ``shell``: any other node.

    Parameters
    ----------
    output_csv_path : Path
        Where the CSV file is written.
    core_threshold_percent : float, optional
        Fraction (0-1) of all samples a node must be present in to be labelled
        ``core``. Defaults to 0.95.

    Notes
    -----
    If the BAM cannot be read, the error is logged and no file is written. If
    no node is found, a header-only CSV is written.
    """
    import math  # local import: ceil is only needed for the core threshold

    self.logger.info(f"Starting CSV export for Bandage to: {output_csv_path}")
    output_columns = ["Node_ID", "Length_bp", "Depth_Sample_Count", "Category", "Sample_List"]

    # --- Step 1: single pass over the BAM, one entry per unique node ---
    # { node_name: {"length": int, "samples": set_of_samples} }
    nodes_data: Dict[str, Dict[str, Any]] = {}
    all_samples: Set[str] = set()

    self.logger.info("Performing a single pass over the BAM file to aggregate node data...")
    try:
        with AlignmentFile(self.bam_path.as_posix(), "rb", check_sq=False, threads=self.threads) as bam_file:
            for segment in bam_file.fetch(until_eof=True):
                node_name = segment.query_name
                if node_name in nodes_data:
                    continue

                samples_for_this_node: Set[str] = set()
                if segment.has_tag("SW"):
                    # Entries look like "sampleA;chr1;0"; empty entries (e.g. from a
                    # trailing comma) are ignored, as in the other SW parsers.
                    samples_for_this_node = {
                        record.split(";")[0]
                        for record in segment.get_tag("SW").split(",")
                        if record
                    }

                nodes_data[node_name] = {
                    "length": segment.query_length,
                    "samples": samples_for_this_node,
                }
                all_samples.update(samples_for_this_node)
    except Exception as e:
        self.logger.error(f"Failed to read or process BAM file '{self.bam_path}': {e}")
        return

    self.logger.info(f"Aggregated data for {len(nodes_data)} unique nodes from {len(all_samples)} total samples.")

    if not nodes_data:
        self.logger.warning("No node data was aggregated. The output CSV will be empty.")
        # Header-only file so downstream readers still find the expected columns.
        try:
            DataFrame(columns=output_columns).to_csv(output_csv_path, index=False)
        except IOError as e:
            self.logger.error(f"Failed to write CSV file to '{output_csv_path}': {e}")
        return

    # --- Step 2: DataFrame from the aggregated data ---
    self.logger.info("Creating DataFrame and calculating final metrics...")
    records = [
        {
            "Node_ID": node_name,
            "Length_bp": data["length"],
            "Depth_Sample_Count": len(data["samples"]),
            "Sample_List": ";".join(sorted(data["samples"])),
        }
        for node_name, data in nodes_data.items()
    ]
    df = DataFrame(records)

    # --- Step 3: category column and export ---
    total_samples = len(all_samples)
    # Round up: truncating with int() let nodes below the threshold count as core
    # (0.95 * 10 samples -> 9, i.e. 90%). round(..., 9) only strips float noise
    # (e.g. 0.29 * 100 == 28.999999999999996) before ceil.
    core_threshold_count = (
        math.ceil(round(total_samples * core_threshold_percent, 9)) if total_samples > 0 else 0
    )

    df["Category"] = "shell"
    df.loc[df["Depth_Sample_Count"] >= core_threshold_count, "Category"] = "core"
    # Applied last so single-sample nodes are always 'private'.
    df.loc[df["Depth_Sample_Count"] == 1, "Category"] = "private"

    df = df[output_columns]
    try:
        df.to_csv(output_csv_path, index=False)
    except IOError as e:
        self.logger.error(f"Failed to write CSV file to '{output_csv_path}': {e}")