# Standard library imports
import asyncio
import gzip
import logging
import sqlite3
import subprocess
import tempfile
import re
from collections import Counter, OrderedDict, defaultdict, namedtuple
from concurrent.futures import ThreadPoolExecutor, wait, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Set
# Third-party imports
import aiofiles
import aiosqlite
from numpy import mean, median, std, histogram
from pandas import set_option, DataFrame, merge
import uvloop
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pybedtools import BedTool, cleanup
from pysam import AlignedSegment, AlignmentFile, index, qualitystring_to_array, view
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
TaskID,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich import box
from natsort import natsorted
# Local application/library specific imports
from .logger_config import shared_console # Assuming shared_console is used for rich printing
from .useful_function import RC_TRANS, reverse_complement_string
# Install uvloop's event loop policy at import time, so event loops later
# created through the asyncio policy (e.g. by asyncio.run) are uvloop loops.
# NOTE(review): this is a process-wide side effect of importing this module.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
# Show floats with 2 decimal places when pandas renders DataFrames/Series
# (affects display only; stored values are not rounded).
set_option("display.precision", 2)
LinkInfo = namedtuple(
    "LinkInfo",
    (
        "seg_id_1",
        "orient_seg_1",
        "orient_key_seg_1",
        "seg_id_2",
        "orient_seg_2",
        "orient_key_seg_2",
    ),
)
# A bare string after the assignment is a no-op expression and never becomes
# the class documentation; assign it to __doc__ so help()/Sphinx can see it.
LinkInfo.__doc__ = """\
Information about a link between two segments.

Attributes
----------
seg_id_1 : str
    Identifier of the first segment.
orient_seg_1 : int
    Orientation of the first segment (+1 or -1).
orient_key_seg_1 : int
    Orientation key for the first segment (often the same as orient_seg_1).
seg_id_2 : str
    Identifier of the second segment.
orient_seg_2 : int
    Orientation of the second segment (+1 or -1).
orient_key_seg_2 : int
    Orientation key for the second segment (often the same as orient_seg_2).
"""
@dataclass
class SubGraph:
    """
    Subgraph of a pangenome graph for one sample, restricted to query regions.

    The sample BED file (one GFA segment per line: chromosome, start, stop,
    oriented segment id such as ``+seg1``, fragment id, haplotype index) is
    intersected with the query regions. From the result the class builds GFA
    walk (W), link (L) and segment (S) lines and, optionally, FASTA records
    of the walks.

    Attributes
    ----------
    bam_path : Path
        The file path to the BAM file holding the segment sequences.
    bed_path : Path
        The file path to the sample BED file (``.bed`` or ``.bed.gz``).
    logger : logging.Logger
        The logger instance. Defaults to the logger named "GraTools".
    sample_name : Optional[str]
        Sample name derived from ``bed_path`` in ``__post_init__``.
    sample_name_query : Optional[str]
        Name of the query sample. Default to None.
    chromosome_query : Optional[str]
        Chromosome of the query region. Default to None.
    start_query : Optional[int]
        Start of the query region. Default to None.
    stop_query : Optional[int]
        Stop of the query region. Default to None.
    offset_first : int
        Bases trimmed at the walk start for the query sample (set by build_fasta).
    offset_last : int
        Bases trimmed at the walk end for the query sample (set by build_fasta).
    add_start_bases_first_segment : int
        Not read by the methods of this class. Default to 0.
    intersect_bed : Optional[BedTool]
        The BedTool object for intersected BED regions. Default to None.
    segment_id_set : Set[str]
        Unique segment ids (without orientation) found by ``build_walks``.
    segment_id_first_query / segment_id_last_query : Optional[str]
        First / last segment id seen by ``build_walks`` for the query sample.
    segment_id_first_strand / segment_id_last_strand : Optional[str]
        Strand ('+' or '-') of those segments.
    segment_id_first / segment_id_last : Optional[str]
        Copies of the query first/last segment ids.
    works_path : Optional[Path]
        Working directory for intermediate files and CSV outputs. When None,
        awk intermediates go to the system temp directory and CSVs are skipped.
    merge : Optional[int]
        Distance for the second ``bedtools merge -d``. None uses bedtools' default.
    build_fasta_flag : Optional[bool]
        Flag indicating whether FASTA sequences should be built. Default to False.
    gfa_walk_list / gfa_link_list / gfa_segment_list : List[str]
        Generated GFA W / L / S lines.
    dict_segments_samples : defaultdict[str, List[str]]
        Segment id -> sample identifiers, as returned by ``GratoolsBam``.
    dict_segments_sequence : defaultdict[str, str]
        Segment id -> sequence, as returned by ``GratoolsBam``.
    sequences_list : List[SeqRecord]
        FASTA records produced by ``build_fasta``.
    progress_dict : Optional[Dict]
        Shared mapping task_id -> progress info (e.g. a multiprocessing dict).
    task_id : Optional[TaskID]
        Key of this task in ``progress_dict``.
    regions : Optional[List[Dict[str, Any]]]
        Regions as dicts with 'chromosome', 'start' and 'stop' keys.
    intersected_results_by_regions : Optional[BedTool]
        Intersection of the sample BED with all regions (see compute_intersection).
    """

    bam_path: Path
    bed_path: Path
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    sample_name: Optional[str] = None
    sample_name_query: Optional[str] = None
    chromosome_query: Optional[str] = None
    start_query: Optional[int] = None
    stop_query: Optional[int] = None
    offset_first: int = 0
    offset_last: int = 0
    add_start_bases_first_segment: int = 0
    intersect_bed: Optional[BedTool] = None
    segment_id_set: Set[str] = field(default_factory=set, repr=False)
    segment_id_first_query: Optional[str] = None
    segment_id_first_strand: Optional[str] = None
    segment_id_last_query: Optional[str] = None
    segment_id_last_strand: Optional[str] = None
    segment_id_first: Optional[str] = None  # set by build_walks for the query sample only
    segment_id_last: Optional[str] = None  # set by build_walks for the query sample only
    works_path: Optional[Path] = None
    merge: Optional[int] = None
    build_fasta_flag: Optional[bool] = False
    gfa_walk_list: List[str] = field(default_factory=list, repr=False)
    gfa_link_list: List[str] = field(default_factory=list, repr=False)
    gfa_segment_list: List[str] = field(default_factory=list, repr=False)
    dict_segments_samples: defaultdict = field(
        default_factory=lambda: defaultdict(list), repr=False
    )
    dict_segments_sequence: defaultdict = field(
        default_factory=lambda: defaultdict(str), repr=False
    )
    sequences_list: List[SeqRecord] = field(default_factory=list, repr=False)
    progress_dict: Optional[Dict] = None
    task_id: Optional[TaskID] = None
    regions: Optional[List[Dict[str, Any]]] = None
    intersected_results_by_regions: Optional[BedTool] = None

    def __post_init__(self) -> None:
        """Normalise path fields and derive ``sample_name`` from ``bed_path``.

        ``name.bed.gz`` and ``name.bed`` both yield ``name``. Any other file
        name is logged as an error and ``sample_name`` is left unchanged.
        """
        # Accept plain strings for the path fields as well as Path objects.
        self.bam_path = Path(self.bam_path)
        self.bed_path = Path(self.bed_path)
        if self.works_path is not None:
            self.works_path = Path(self.works_path)

        # Test the end of the name: a substring test would accept names such
        # as "x.bedgraph.gz" and removesuffix would then leave them unchanged.
        file_name = self.bed_path.name
        if file_name.endswith(".bed.gz"):
            self.sample_name = file_name.removesuffix(".bed.gz")
        elif file_name.endswith(".bed"):
            self.sample_name = file_name.removesuffix(".bed")
        else:
            self.logger.error(f"The file {self.bed_path} is not a recognized BED file type.")

    def _report_progress(self, progress: int, description: str) -> None:
        """Publish step ``progress`` (out of 7) to ``progress_dict`` when tracking is enabled.

        ``is not None`` is used on purpose: a freshly created shared dict is
        empty, and therefore falsy, yet must still receive updates.
        """
        if self.progress_dict is not None and self.task_id is not None:
            self.progress_dict[self.task_id] = {
                "progress": progress,
                "total": 7,
                "description": description,
            }

    def compute_intersection(self) -> None:
        """
        Intersect the sample BED file with every region in ``self.regions``.

        When ``self.regions`` is empty it defaults to the single query region
        (``chromosome_query``, ``start_query``, ``stop_query``). All regions
        are written to one BED with the region index in column 4, so a single
        ``intersect -wa -wb`` call covers them. Each output row is the sample
        BED feature followed by the region columns, the last one being the
        region index. The result is stored in
        ``self.intersected_results_by_regions``. Invalid regions or a bedtools
        failure are logged as errors and the method returns early.
        """
        if not self.regions:
            self.regions = [
                {
                    "chromosome": self.chromosome_query,
                    "start": self.start_query,
                    "stop": self.stop_query,
                }
            ]

        required_keys = ("chromosome", "start", "stop")
        regions_are_valid = isinstance(self.regions, list) and all(
            isinstance(region, dict) and all(key in region for key in required_keys)
            for region in self.regions
        )
        if not regions_are_valid:
            self.logger.error(
                "Invalid regions provided. Must be a list of dictionaries, "
                "each with 'chromosome', 'start', and 'stop' keys."
            )
            return

        try:
            all_bed = BedTool(self.bed_path.as_posix())
        except Exception as e:  # pybedtools can raise various exceptions
            self.logger.error(f"Failed to load BED file '{self.bed_path}': {e}")
            return

        regions_bed_text = "\n".join(
            f"{reg['chromosome']}\t{reg['start']}\t{reg['stop']}\t{idx}"
            for idx, reg in enumerate(self.regions)
        )
        regions_bed = BedTool(regions_bed_text, from_string=True)

        try:
            # saveas() materialises the result in a temporary file on disk.
            intersected = all_bed.intersect(regions_bed, wa=True, wb=True).saveas()
            self.intersected_results_by_regions = intersected
            if intersected.count() <= 0:
                # Reported as an error and the task is marked finished; the
                # empty result is still stored so later steps see no rows.
                self._report_progress(
                    7, f"[red]'{self.sample_name}': No intersection found for region, aborted."
                )
                self.logger.error(f"'{self.sample_name}': No intersection found for region {self.regions}")
        except Exception as e:
            self.logger.error(f"Error during single intersection: {e}")
            return

    def build_walks(self) -> None:
        """
        Build GFA walk (W) and link (L) lines from the intersection result.

        Rows of ``self.intersected_results_by_regions`` are grouped by
        (chromosome, haplotype index, region index) and then by fragment id.
        Each group becomes one W line spanning the smallest start and largest
        stop of its segments. Consecutive segments of a group become L lines.
        Segment ids (without orientation) are collected in a freshly reset
        ``self.segment_id_set``. For the query sample, the first and last
        segments seen are also recorded.
        """
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA walks and links")
        self.segment_id_set = set()
        self._report_progress(5, f"[cyan]'{self.sample_name}': Building walks and links")

        if self.intersected_results_by_regions is None:
            # compute_intersection failed or was not run; nothing to iterate.
            self.logger.warning(f"'{self.sample_name}': No intersection result available, no walk built.")
            return

        # (chromosome, haplotype index, region index) -> fragment id -> walk data.
        # A tuple key avoids the ambiguity of joining names that may contain ':'.
        walk_dict: Dict[Tuple[str, str, int], Dict[str, Dict[str, Any]]] = defaultdict(
            lambda: defaultdict(
                lambda: {"seg_list": [], "start": float("inf"), "stop": float("-inf")}
            )
        )
        is_query_sample = self.sample_name == self.sample_name_query

        for fields in self.intersected_results_by_regions:
            # Sample BED columns first, then the region columns (index is last).
            region_idx = int(fields[-1])
            oriented_segment_id = fields[3]  # e.g. "+seg1" / "-seg2"
            segment_strand = oriented_segment_id[0]
            segment_id = oriented_segment_id[1:]
            chromosome_fragment_id = fields[4]
            haplotype_index = fields[5]

            self.segment_id_set.add(segment_id)

            if is_query_sample:
                if not self.segment_id_first_query:
                    self.segment_id_first_query = segment_id
                    self.segment_id_first_strand = segment_strand
                    self.segment_id_first = segment_id
                # Overwritten on every row so it ends on the last segment seen.
                self.segment_id_last_query = segment_id
                self.segment_id_last_strand = segment_strand
                self.segment_id_last = segment_id

            fragment_data = walk_dict[(fields.chrom, haplotype_index, region_idx)][chromosome_fragment_id]
            fragment_data["seg_list"].append(oriented_segment_id)
            # Each fragment spans its own segments, independently of row order
            # and of other fragments/haplotypes in the same region.
            fragment_data["start"] = min(fragment_data["start"], fields.start)
            fragment_data["stop"] = max(fragment_data["stop"], fields.stop)

        # GFA walks mark orientation with '>' (forward) and '<' (reverse).
        orientation_marks = {"+": ">", "-": "<"}
        for (chr_name, haplotype_idx, _region_idx), fragments in walk_dict.items():
            for data in fragments.values():
                ordered_segments = data["seg_list"]
                self._build_links(ordered_segments)
                # Translate only the leading orientation character so segment
                # ids that contain '+' or '-' are preserved.
                gfa_path_string = "".join(
                    orientation_marks.get(seg[0], seg[0]) + seg[1:] for seg in ordered_segments
                )
                # TODO: save to file for concat after on gratools
                self.gfa_walk_list.append(
                    "\t".join(
                        [
                            "W",
                            self.sample_name,
                            haplotype_idx,
                            chr_name,
                            str(data["start"]),
                            str(data["stop"]),
                            gfa_path_string,
                        ]
                    )
                )

    def _build_links(self, segments_order: List[str]) -> None:
        """
        Append one GFA L line per pair of consecutive oriented segments.

        Parameters
        ----------
        segments_order : List[str]
            Ordered GFA segment ids, each prefixed with its orientation
            (e.g. ["+seg1", "-seg2", "+seg3"]). Fewer than two ids yield no link.
        """
        for source, target in zip(segments_order, segments_order[1:]):
            # L <from> <from_orient> <to> <to_orient> <overlap>; overlap is always 0M.
            self.gfa_link_list.append(
                f"L\t{source[1:]}\t{source[0]}\t{target[1:]}\t{target[0]}\t0M"
            )

    def build_segments(self) -> None:
        """Build GFA segment (S) lines from the BAM file for ``self.segment_id_set``.

        Fills ``gfa_segment_list``, ``dict_segments_samples`` and
        ``dict_segments_sequence`` from ``GratoolsBam.build_segments``. When
        the set is empty they are reset to empty containers instead.
        """
        self._report_progress(6, f"[cyan]'{self.sample_name}': Building segments")
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA segments")
        if not self.segment_id_set:
            self.logger.warning(f"'{self.sample_name}': No segment IDs for segment lines building. Skipping.")
            self.gfa_segment_list = []
            self.dict_segments_samples = defaultdict(list)
            self.dict_segments_sequence = defaultdict(str)
            return
        bam_handler = GratoolsBam(bam_path=self.bam_path, logger=self.logger)
        (
            self.gfa_segment_list,
            self.dict_segments_samples,
            self.dict_segments_sequence,
        ) = bam_handler.build_segments(list_segments=self.segment_id_set)

    def filter_bed_with_awk(self) -> Tuple[BedTool, Path, Path]:
        """
        Keep only the BED lines whose segment id is in ``self.segment_id_set``.

        The ids are written one per line to ``<sample>_segment_ids.txt``. awk
        then keeps the BED lines whose column 4, minus its leading orientation
        character, is one of them. ``.gz`` BED files are decompressed with
        ``gzip -dc`` and piped into awk. Files are written to
        ``self.works_path``, or to the system temp directory when it is unset.

        Returns
        -------
        Tuple[BedTool, Path, Path]
            The filtered BED, the filtered BED file path and the id file path.
            The caller is responsible for deleting both files.

        Raises
        ------
        subprocess.CalledProcessError
            If awk or gzip exits with a non-zero status.
        """
        self.logger.info(f"'{self.sample_name}': Starting awk filtering for {len(self.segment_id_set)} IDs.")
        work_dir = Path(self.works_path) if self.works_path else Path(tempfile.gettempdir())
        segment_id_file = work_dir / f"{self.sample_name}_segment_ids.txt"
        filtered_bed_path = work_dir / f"{self.sample_name}_filtered.bed"

        segment_id_file.write_text("".join(f"{seg_id}\n" for seg_id in self.segment_id_set))

        # First file (NR==FNR): load the ids. Second file: print the lines
        # whose column 4 without its orientation character is a known id.
        awk_program = "NR==FNR {ids[$1]; next} substr($4, 2) in ids"
        awk_cmd = ["awk", awk_program, segment_id_file.as_posix()]
        # Argument lists (no shell) keep quotes or '$' in paths and sample
        # names from being interpreted by a shell.
        with open(filtered_bed_path, "w") as out_f:
            if self.bed_path.name.endswith(".gz"):
                # awk cannot read gzip data: decompress into its stdin ("-").
                with subprocess.Popen(
                    ["gzip", "-dc", self.bed_path.as_posix()], stdout=subprocess.PIPE
                ) as gunzip:
                    subprocess.run(awk_cmd + ["-"], stdin=gunzip.stdout, stdout=out_f, check=True)
                if gunzip.returncode != 0:
                    raise subprocess.CalledProcessError(gunzip.returncode, gunzip.args)
            else:
                subprocess.run(awk_cmd + [self.bed_path.as_posix()], stdout=out_f, check=True)

        self.logger.info(f"'{self.sample_name}': Filtered file with awk written to: {filtered_bed_path}")
        return BedTool(filtered_bed_path.as_posix()), filtered_bed_path, segment_id_file

    def get_chr_pos(self, progress_dict: Optional[Dict], task_id: Optional[TaskID]) -> None:
        """
        Find the regions holding ``self.segment_id_set`` and build the subgraph there.

        Steps: filter the sample BED to the known segment ids, then merge the
        hits into regions. The first merge joins overlapping/abutting features.
        The second merge uses ``-d self.merge`` when it is set. Both merged
        BEDs are saved as CSV files when ``works_path`` is set. Finally the
        method runs ``compute_intersection``, ``build_walks`` and
        ``build_segments`` on those regions. It aborts, with an error logged
        and the task marked finished, if the id set is empty or no BED line
        matches.

        Parameters
        ----------
        progress_dict : Optional[Dict]
            A dictionary to track progress for multiprocessing.
        task_id : Optional[TaskID]
            The task ID for `rich.progress` tracking.
        """
        self.progress_dict = progress_dict
        self.task_id = task_id
        self._report_progress(2, f"[cyan]'{self.sample_name}': Searching corresponding chr pos")

        if not self.segment_id_set:
            self._report_progress(7, f"[red]'{self.sample_name}': the segment_id_set is empty, aborted.")
            self.logger.error(
                f"'{self.sample_name}': segment_id_set is empty. "
                "Cannot determine chromosome positions. Ensure that build_walks (or equivalent setting segment_id_set) was called."
            )
            return

        filtered_bed_all, filtered_bed_path, segment_id_file = self.filter_bed_with_awk()
        try:
            if filtered_bed_all.file_type == 'empty':
                self._report_progress(7, f"[red]'{self.sample_name}': No segment in BED.")
                self.logger.error(
                    f"'{self.sample_name}': No matching segment found in BED file '{self.bed_path}' "
                    f"for segment IDs: {list(self.segment_id_set)[:5]}..."  # Log a few example IDs
                )
                return  # nothing to merge; continuing would only fail later
            self._report_progress(3, f"'{self.sample_name}': BED filtering successful, proceeding to merge..")

            # Columns 4-6 (oriented segment id, fragment id, haplotype index)
            # are carried through both merges.
            merge_columns = {"c": [4, 5, 6], "o": ["collapse", "distinct", "distinct"]}
            # First merge without -d: only overlapping/abutting features join.
            merged_all_bed = filtered_bed_all.merge(**merge_columns)
            # Second merge joins features closer than self.merge; -d is left
            # out when unset so bedtools applies its own default.
            collapse_kwargs = dict(merge_columns)
            if self.merge is not None:
                collapse_kwargs["d"] = self.merge
            merged_collapse_bed = merged_all_bed.merge(**collapse_kwargs)

            if self.works_path:
                output_stem = f"{self.bam_path.stem}_{self.sample_name}"
                works_dir = Path(self.works_path)
                merged_all_bed.saveas((works_dir / f"{output_stem}_regions_all.csv").as_posix())
                merged_collapse_bed.saveas(
                    (works_dir / f"{output_stem}_regions_collapsed_{self.merge}.csv").as_posix()
                )

            self.logger.info(f"'{self.sample_name}': compute regions base on merge dataframe")
            self.regions = [
                {"chromosome": feature.chrom, "start": feature.start, "stop": feature.stop}
                for feature in merged_collapse_bed
            ]
            self.logger.debug(
                f"'{self.sample_name}': Found {len(self.regions)} regions for processing. Example: {self.regions[0] if self.regions else 'None'}."
            )
            self._report_progress(4, f"[cyan]'{self.sample_name}': Found corresponding chr pos")
        finally:
            # The awk intermediates are only needed until the merges have run.
            filtered_bed_path.unlink(missing_ok=True)
            segment_id_file.unlink(missing_ok=True)

        self.compute_intersection()
        self.build_walks()  # resets and repopulates self.segment_id_set
        self.build_segments()  # reads self.segment_id_set
        self._report_progress(7, f"[green]'{self.sample_name}': End processing")

    def build_fasta(self) -> None:
        """
        Build FASTA records from the GFA walks in ``self.gfa_walk_list``.

        Each walk path is parsed into oriented segments. Their sequences (from
        ``self.dict_segments_sequence``) are reverse-complemented on the ``-``
        strand and concatenated; missing segments are skipped with a warning.
        For the query sample (``self.sample_name == self.sample_name_query``),
        the segments matching ``segment_id_first_query`` /
        ``segment_id_last_query`` at either end of the walk are trimmed by
        ``offset_first`` / ``offset_last``. This makes the sequence match
        ``start_query``-``stop_query``. When both query bounds are set, those
        offsets are recomputed and stored on the instance for every walk. New
        ``SeqRecord`` objects are appended to ``self.sequences_list``.
        """
        num_walks = len(self.gfa_walk_list)
        self.logger.info(f"'{self.sample_name}': Building FASTA for {num_walks} walks.")
        if not self.dict_segments_sequence:
            self.logger.warning(f"'{self.sample_name}': Segment sequences dictionary is empty. Cannot build FASTA.")
            return
        if not self.gfa_walk_list:
            self.logger.info(f"'{self.sample_name}': No GFA walks to process for FASTA generation.")
            return

        is_query_sample = self.sample_name == self.sample_name_query
        # ">s1<s2" style (orientation before the id) and "s1+s2-" style.
        arrow_path_pattern = re.compile(r"([<>])([^<>]+)")
        suffix_path_pattern = re.compile(r"(.+?)([+-])")
        newly_created_records: List[SeqRecord] = []

        for walk_line in self.gfa_walk_list:
            try:
                (
                    _w_char,
                    current_walk_sample_name,
                    haplotype_index,
                    chromosome_name,
                    walk_genomic_start_str,
                    walk_genomic_stop_str,
                    gfa_path_specification,
                ) = walk_line.split("\t")
                walk_genomic_start = int(walk_genomic_start_str)
                walk_genomic_stop = int(walk_genomic_stop_str)
            except ValueError as e:
                self.logger.error(
                    f"'{self.sample_name}': Malformed walk line during FASTA build: '{walk_line}'. Error: {e}",
                    exc_info=True)
                continue

            # Coordinates reported in the record ID; narrowed when trimming.
            effective_walk_chr_start = walk_genomic_start
            effective_walk_chr_stop = walk_genomic_stop

            if is_query_sample:
                if self.start_query is not None and self.stop_query is not None:
                    # Positive values mean the walk extends beyond the query.
                    self.offset_first = self.start_query - walk_genomic_start
                    self.offset_last = walk_genomic_stop - self.stop_query
                else:
                    self.logger.warning(
                        f"'{self.sample_name}': Query sample set, but start/stop_query not fully set. Offsets may be incorrect.")

            # List of (orientation_char, segment_id).
            oriented_segments_in_path: List[Tuple[str, str]] = []
            if gfa_path_specification.startswith((">", "<")):
                for orient_mark, seg_id_str in arrow_path_pattern.findall(gfa_path_specification):
                    oriented_segments_in_path.append(("+" if orient_mark == ">" else "-", seg_id_str))
            else:
                for seg_id_str, orient_char in suffix_path_pattern.findall(gfa_path_specification):
                    oriented_segments_in_path.append((orient_char, seg_id_str))
            if not oriented_segments_in_path:
                self.logger.warning(
                    f"'{self.sample_name}': Could not parse path specification '{gfa_path_specification}' in walk line: '{walk_line}'. Skipping this walk for FASTA building.")
                continue

            num_segments_in_walk = len(oriented_segments_in_path)
            last_segment_idx = num_segments_in_walk - 1
            current_walk_seq_fragments: List[str] = []
            self.logger.debug(
                f"'{self.sample_name}': Processing walk on {chromosome_name} ({walk_genomic_start}-{walk_genomic_stop}) "
                f"with {num_segments_in_walk} segments for FASTA."
            )

            for segment_idx, (segment_orient_char, segment_id) in enumerate(oriented_segments_in_path):
                if segment_id not in self.dict_segments_sequence:
                    self.logger.warning(
                        f"'{self.sample_name}': Sequence for segment '{segment_id}' in walk not found. Skipping this segment.")
                    continue
                current_segment_seq = self.dict_segments_sequence[segment_id]
                if segment_orient_char == "-":
                    current_segment_seq = reverse_complement_string(current_segment_seq)

                if is_query_sample:
                    # First query segment: trimmed at the walk start when it is
                    # the first walk segment in the expected orientation, or at
                    # the walk end when it is the last one in reverse orientation.
                    if segment_id == self.segment_id_first_query:
                        if segment_idx == 0:
                            if segment_orient_char == self.segment_id_first_strand and self.offset_first > 0:
                                if self.offset_first < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[self.offset_first:]
                                    effective_walk_chr_start = walk_genomic_start + self.offset_first
                                else:
                                    self.logger.debug(
                                        f"Offset_first ({self.offset_first}bp) consumes entire first segment '{segment_id}' ({len(current_segment_seq)}bp).")
                                    current_segment_seq = ""
                        elif segment_idx == last_segment_idx:
                            if segment_orient_char != self.segment_id_first_strand and self.offset_first > 0:
                                if self.offset_first < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[:-self.offset_first]
                                    effective_walk_chr_stop = walk_genomic_stop - self.offset_first
                                else:
                                    current_segment_seq = ""
                    # Last query segment: mirror image of the rule above.
                    if segment_id == self.segment_id_last_query:
                        if segment_idx == 0:
                            if segment_orient_char != self.segment_id_last_strand and self.offset_last > 0:
                                if self.offset_last < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[self.offset_last:]
                                    effective_walk_chr_start = walk_genomic_start + self.offset_last
                                else:
                                    current_segment_seq = ""
                        elif segment_idx == last_segment_idx:
                            if segment_orient_char == self.segment_id_last_strand and self.offset_last > 0:
                                if self.offset_last < len(current_segment_seq):
                                    current_segment_seq = current_segment_seq[:-self.offset_last]
                                    effective_walk_chr_stop = walk_genomic_stop - self.offset_last
                                else:
                                    self.logger.error(
                                        f"Offset_last ({self.offset_last}bp) consumes entire last segment '{segment_id}' ({len(current_segment_seq)}bp).")
                                    current_segment_seq = ""

                if current_segment_seq:
                    current_walk_seq_fragments.append(current_segment_seq)

            if not current_walk_seq_fragments:
                self.logger.debug(
                    f"'{self.sample_name}': No sequence fragment generated for walk on {chromosome_name} "
                    f"(ID: {current_walk_sample_name}-{haplotype_index}). Skipping FASTA record for this walk.")
                continue
            # Only non-empty fragments were kept, so the sequence is non-empty.
            final_walk_sequence = Seq("".join(current_walk_seq_fragments))

            if is_query_sample:
                q_sample = self.sample_name_query or "query_sample"
                q_chrom = self.chromosome_query or chromosome_name
                fasta_seq_id = f"{q_sample}-{q_chrom}:{effective_walk_chr_start}-{effective_walk_chr_stop}"
                fasta_description = ""
            else:
                fasta_seq_id = f"{current_walk_sample_name}-{chromosome_name}:{effective_walk_chr_start}-{effective_walk_chr_stop}"
                # Genomic span of the walk (not the assembled sequence length).
                seq_actual_len = effective_walk_chr_stop - effective_walk_chr_start
                query_region_str = "N/A"
                if (
                    self.sample_name_query
                    and self.chromosome_query
                    and self.start_query is not None
                    and self.stop_query is not None
                ):
                    query_region_str = f"{self.sample_name_query}-{self.chromosome_query}:{self.start_query}-{self.stop_query}"
                fasta_description = f"SEQ_LEN={seq_actual_len}|QUERY={query_region_str}"

            fasta_seq_id = fasta_seq_id.strip()
            newly_created_records.append(
                SeqRecord(
                    final_walk_sequence,
                    id=fasta_seq_id,
                    name=fasta_seq_id,
                    description=fasta_description,
                )
            )

        self.sequences_list.extend(newly_created_records)
        self.logger.info(
            f"'{self.sample_name}': Finished FASTA building. Added {len(newly_created_records)} new records. "
            f"Total records in sequences_list: {len(self.sequences_list)}.")
class AsyncGfaDatabase:
"""
Manage an asynchronous SQLite database for storing and querying GFA link data.
It uses `aiosqlite` for non-blocking operations within an asyncio event loop
and serializes writes via an internal FIFO queue to prevent SQLite lock contention.
Attributes
----------
db_file : Path
Path to the SQLite database file.
timeout : float
Maximum timeout (in seconds) for SQLite lock acquisition.
logger : logging.Logger
Logger instance.
_conn : Optional[aiosqlite.Connection]
Shared SQLite connection (or None if not connected).
_write_queue : asyncio.Queue
Asynchronous queue for batches of links to be inserted. Max size 1000.
_sql_task : Optional[asyncio.Task]
Background task consuming the queue and writing to the database.
_shutdown : bool
Flag to signal shutdown to the writer task.
"""
def __init__(self, db_file: Path, timeout: float = 30.0):
    """
    Store the configuration without opening the database.

    The connection is opened, and the ``links`` schema created, by the first
    call to ``connect()``.

    Args:
        db_file (Path): Path to the SQLite file to be used as backend.
        timeout (float, optional): Maximum wait time, in seconds, for SQLite
            locks (default 30.0s).
    """
    self.db_file: Path = db_file
    self.timeout: float = timeout  # seconds
    self._shutdown: bool = False  # asks the writer task to stop
    self._conn: Optional[aiosqlite.Connection] = None  # opened by connect()
    # Bounded queue of link batches consumed by the writer task; the
    # Optional[...] element type presumably lets None act as a stop sentinel
    # (TODO confirm in _sql_writer).
    self._write_queue: asyncio.Queue[Optional[List[Tuple[str, int, int, str, int, int]]]] = asyncio.Queue(
        maxsize=1000)
    self._sql_task: Optional[asyncio.Task] = None  # started by connect()
    self.logger = logging.getLogger("GraTools")  # shared application logger
async def connect(self) -> None:
"""
Connect to the SQLite db (if not already connected), configure PRAGMA settings,
create the 'links' table schema (if it doesn't exist), and start the SQL writer task.
This method is idempotent: if already connected, it does nothing.
"""
if self._conn:
return
self.logger.debug(f"Connecting to the SQLite database: {self.db_file.name}")
# Open the single connection
self._conn = await aiosqlite.connect(
self.db_file.as_posix(), timeout=self.timeout # timeout here is for connection
)
# Performance optimizations and configuration
await self._conn.execute("PRAGMA synchronous = OFF;") # Faster, but higher risk of corruption on crash
await self._conn.execute("PRAGMA journal_mode = WAL;") # Write-Ahead Logging for better concurrency
await self._conn.execute(
f"PRAGMA busy_timeout = {int(self.timeout * 1000)};") # busy_timeout is in milliseconds
# Create schema
await self._conn.execute("""
CREATE TABLE IF NOT EXISTS links (
seg_id_1 TEXT,
orient_seg_1 INTEGER, -- +1 for '+', -1 for '-'
orient_key_seg_1 INTEGER, -- Often same as orient_seg_1, purpose might be specific
seg_id_2 TEXT,
orient_seg_2 INTEGER,
orient_key_seg_2 INTEGER
);
""")
await self._conn.commit()
# Start the background insertion task
if self._sql_task is None or self._sql_task.done():
self._sql_task = asyncio.create_task(self._sql_writer())
self.logger.debug("AsyncGfaDatabase connected and writer task started/confirmed.")
async def _sql_writer(self) -> None:
"""
Background task that continuously fetches link batches from the queue and inserts them into the database.
It handles shutdown signals and manages SQLite transactions for batch inserts.
"""
assert self._conn is not None, "Connection must be initialized before writer starts"
self.logger.debug("SQL writer task started.")
processed_data_batches = 0 # Counter for actual data batches
while True: # The exit condition is handled by internal breaks
batch = None
item_was_retrieved = False
try:
# Wait for an item from the queue with a timeout
# The timeout allows periodic checking of _shutdown
batch = await asyncio.wait_for(self._write_queue.get(), timeout=0.5)
item_was_retrieved = True
if batch is None: # Sentinel to stop the writer
self.logger.debug("SQL writer received end. Preparing to stop.")
break # Exit the main while loop
if not batch: # An empty batch (empty list), not the sentinel
# task_done will be called in finally
continue # Go to the next iteration to get another batch
# This is a valid data batch
processed_data_batches += 1
self.logger.debug(
f"SQL writer processing data batch {processed_data_batches} of size {len(batch)}. "
f"Queue size: {self._write_queue.qsize()}"
)
# Database insertion
await self._conn.executemany(
"""INSERT INTO links (
seg_id_1, orient_seg_1, orient_key_seg_1,
seg_id_2, orient_seg_2, orient_key_seg_2
) VALUES (?, ?, ?, ?, ?, ?);""",
batch,
)
await self._conn.commit()
# self.logger.debug(f"SQL writer successfully committed data batch {processed_data_batches}.")
except asyncio.TimeoutError:
# Timeout while waiting for an item from the queue
if self._shutdown and self._write_queue.empty():
self.logger.info("SQL writer: Timeout on get() with _shutdown True and queue empty. Stopping.")
break # Exit the main while loop
continue
except sqlite3.Error as e:
self.logger.error(
f"SQLite error during processing of batch (approx. #{processed_data_batches + 1}): {e}",
exc_info=True)
if self._conn.in_transaction:
try:
await self._conn.rollback()
self.logger.warning("SQL writer rolled back transaction due to SQLite error.")
except sqlite3.Error as rb_err:
self.logger.error(f"Error during SQLite rollback: {rb_err}")
# Continue to try and process other batches, the current batch is lost without re-queuing.
except Exception as e:
self.logger.error(f"Unexpected error in SQL writer (approx. batch #{processed_data_batches + 1}): {e}",
exc_info=True)
# Handle transaction if necessary, as for sqlite3.Error
finally:
if item_was_retrieved:
self._write_queue.task_done()
# if batch is not None and batch: # If it was an actual data batch
# self.logger.debug(f"SQL writer marked data batch {processed_data_batches} as done.")
# No special log for an empty batch after task_done
self.logger.debug("SQL writer task has stopped.")
[docs]
async def batch_insert_links(
self, links: List[Tuple[str, int, int, str, int, int]]
) -> None:
"""
Enqueue a batch of links for non-blocking insertion.
Must be called after `await connect()`.
Args:
links : List of tuples, each representing a link:
(seg_id_1, orient_seg_1, orient_key_seg_1,
seg_id_2, orient_seg_2, orient_key_seg_2).
"""
if not self._conn or (self._sql_task and self._sql_task.done()):
await self.connect() # Will also restart task if needed and done
if not links: # Do not put empty list on queue if it's not a sentinel
return
self.logger.debug(
f"SQL writer adding batch of {len(links)} links to write queue. Current queue size: {self._write_queue.qsize()}"
)
await self._write_queue.put(links)
[docs]
async def create_indexes(self) -> None:
"""
Create indexes on seg_id_1 and seg_id_2 to accelerate queries.
Should ideally be called after all data insertions are complete.
"""
if not self._conn:
await self.connect()
self.logger.info("Creating indexes on links table.")
try:
await self._conn.execute(
"CREATE INDEX IF NOT EXISTS idx_seg_id_1 ON links(seg_id_1);"
)
await self._conn.execute(
"CREATE INDEX IF NOT EXISTS idx_seg_id_2 ON links(seg_id_2);"
)
await self._conn.commit()
except sqlite3.Error as e:
self.logger.error(f"Failed to create indexes: {e}", exc_info=True)
[docs]
async def query_links_by_segment(self, segment_id: str) -> List[Tuple[Any, ...]]:
"""
Retrieves all links where `segment_id` appears as seg_id_1 or seg_id_2.
Args:
segment_id : The ID of the target segment.
Returns:
List of tuples, each representing a full link row from the database.
"""
if not self._conn:
await self.connect()
self.logger.debug(f"Querying links for segment: {segment_id}")
try:
async with self._conn.execute( # `async with cursor` is good practice
"SELECT * FROM links WHERE seg_id_1 = ? OR seg_id_2 = ?;",
(segment_id, segment_id),
) as cursor:
results = await cursor.fetchall()
self.logger.debug(f"Found {len(results)} links for segment {segment_id}.")
return results
except sqlite3.Error as e:
self.logger.error(f"Error querying links for segment {segment_id}: {e}", exc_info=True)
return []
[docs]
async def test_query_links(
self, segment_id: str
) -> List[Tuple[str, str, int, int]]:
"""
Retrieve and categorize links related to a given segment.
- "before": links where `segment_id` is seg_id_2.
- "after": links where `segment_id` is seg_id_1.
Args:
segment_id : The segment to analyze.
Return:
List of tuples: (connected_segment_id, position_type, orient_seg_1, orient_seg_2).
"""
if not self._conn:
await self.connect()
self.logger.debug(f"Testing query links for segment: {segment_id}")
before_rows = []
after_rows = []
try:
# "before" links (segment_id is the destination)
async with self._conn.execute(
"SELECT seg_id_1, orient_seg_1, orient_seg_2 FROM links WHERE seg_id_2 = ?;",
(segment_id,),
) as cur_before:
before_rows = await cur_before.fetchall()
self.logger.debug(
f"Found {len(before_rows)} 'before' links for {segment_id}."
)
# "after" links (segment_id is the source)
async with self._conn.execute(
"SELECT seg_id_2, orient_seg_1, orient_seg_2 FROM links WHERE seg_id_1 = ?;",
(segment_id,),
) as cur_after:
after_rows = await cur_after.fetchall()
self.logger.debug(
f"Found {len(after_rows)} 'after' links for {segment_id}."
)
except sqlite3.Error as e:
self.logger.error(f"Error during test_query_links for {segment_id}: {e}", exc_info=True)
# Return whatever was fetched before the error, or empty
# This part of error handling depends on desired behavior.
# Merge and sort results
# The columns in `before_rows` are (seg_id_1, orient_seg_1, orient_seg_2)
# The columns in `after_rows` are (seg_id_2, orient_seg_1, orient_seg_2)
# The resulting tuple should be (connected_segment_id, "before"/"after", orient_of_link_seg1, orient_of_link_seg2)
merged = natsorted([(seg_id, "before", o1, o2) for seg_id, o1, o2 in before_rows] + \
[(seg_id, "after", o1, o2) for seg_id, o1, o2 in after_rows])
self.logger.debug(f"Total merged links for {segment_id}: {len(merged)}")
return merged
[docs]
async def find_children_and_grandchildren(
self, node_id: str
) -> Dict[str, List[str]]:
"""
Find direct successors (children) and second-degree successors (grandchildren) of a segment.
A child is `seg_id_2` where `node_id` is `seg_id_1`.
A grandchild is a child of a child.
Args:
node_id : The starting segment ID.
Returns:
A dictionary: {"children": [IDs], "grandchildren": [IDs]}.
"""
if not self._conn:
await self.connect()
self.logger.debug(f"Finding children and grandchildren for node: {node_id}")
children: List[str] = []
grandchildren: List[str] = []
try:
# Direct children
async with self._conn.execute(
"SELECT seg_id_2 FROM links WHERE seg_id_1 = ?;", (node_id,)
) as cur_children:
children = [r[0] for r in await cur_children.fetchall()] # r is a tuple e.g. ('seg_X',)
self.logger.debug(f"Found {len(children)} children for {node_id}.")
# Grandchildren (children of children)
if children:
# Create placeholders for the IN clause: (?, ?, ...)
placeholders = ",".join("?" for _ in children)
sql_grandchildren = f"SELECT seg_id_2 FROM links WHERE seg_id_1 IN ({placeholders});"
async with self._conn.execute(sql_grandchildren, children) as cur_gc:
grandchildren_tuples = await cur_gc.fetchall()
# Deduplicate grandchildren, as multiple children might lead to the same grandchild.
grandchildren = sorted(list(set(r[0] for r in grandchildren_tuples)))
self.logger.debug(
f"Found {len(grandchildren)} unique grandchildren for {node_id} (from {len(grandchildren_tuples)} raw)."
)
except sqlite3.Error as e:
self.logger.error(f"Error finding children/grandchildren for {node_id}: {e}", exc_info=True)
# Return lists as they are up to the point of error.
return {"children": children, "grandchildren": grandchildren}
[docs]
async def close(self) -> None:
"""
Properly shut down the database:
- Signal the SQL writer task to stop.
- Wait for the writer task to finish processing its queue (with timeout).
- Cancel the task if it doesn't finish in time.
- Create indexes (important to do this after all writes).
- Close the SQLite connection.
"""
self.logger.debug("Initiating AsyncGfaDatabase shutdown...")
self._shutdown = True # Signal writer to stop after draining queue
if self._write_queue:
self.logger.debug("Putting sentinel on SQL write queue.")
await self._write_queue.put(None) # Sentinel to stop the writer loop
self.logger.debug("Waiting for SQL queue to drain...")
try:
# Wait for all items in queue to be processed.
# Timeout for queue.join() should be generous if large batches or slow I/O expected.
await self._write_queue.join()
self.logger.debug("SQL queue drained successfully (all items processed).")
except asyncio.TimeoutError:
self.logger.warning("Timeout waiting for SQL queue to drain during close.")
except Exception as e:
self.logger.error(f"Error waiting for SQL queue to drain: {e}", exc_info=True)
if self._sql_task and not self._sql_task.done():
self.logger.debug("SQL writer task still running, attempting to cancel.")
self._sql_task.cancel()
try:
await self._sql_task # Wait for task to acknowledge cancellation
self.logger.debug("SQL writer task cancelled successfully.")
except asyncio.CancelledError:
self.logger.info("SQL writer task was cancelled as expected during close.")
except Exception as e:
self.logger.error(f"Error awaiting SQL writer task cancellation: {e}", exc_info=True)
elif self._sql_task and self._sql_task.done():
self.logger.debug("SQL writer task was already done.")
else:
self.logger.debug("No SQL writer task to shut down or task was never started.")
# Create indexes after all data is written and writer task is stopped
if self._conn: # Only if connection was ever established
await self.create_indexes()
try:
await self._conn.close()
except Exception as e: # Catch potential errors during aiosqlite close
self.logger.error(
f"Error closing SQLite connection: {e}", exc_info=True
)
finally:
self._conn = None # Ensure _conn is None after attempting to close
class AsyncBedWriter:
    """
    Asynchronous BED file writer using `aiofiles`.

    Lines are enqueued per sample. A single background task buffers them in memory
    and appends them in batches to ``<bed_dir>/<sample>.bed``.

    Attributes
    ----------
    bed_dir : Path
        Output directory for .bed files (created by the writer loop if missing).
    batch_size : int
        Number of lines to buffer per sample before an automatic flush.
    progress : Optional[Progress]
        Rich Progress instance (if provided). It is stored only; the writer methods do not use it.
    logger : logging.Logger
        Logger instance.
    _queue : asyncio.Queue[Tuple[str, List[str]]]
        Internal queue for (sample_name, lines_to_write) tuples. Max size 1000.
    _shutdown : bool
        Flag to signal shutdown to the writer loop.
    _task : Optional[asyncio.Task]
        The background asyncio task running the `_writer_loop`.
    """

    def __init__(self, bed_dir: Path, batch_size: int = 10000,
                 progress: "Optional[Progress]" = None):
        self.bed_dir: Path = bed_dir
        self.batch_size: int = batch_size
        self._queue: asyncio.Queue[Tuple[str, List[str]]] = asyncio.Queue(maxsize=1000)
        self._shutdown: bool = False
        self._task: Optional[asyncio.Task] = None
        self.progress = progress  # Can be None
        self.logger = logging.getLogger("GraTools")

    def start(self) -> None:
        """Start the writer loop as a background task if not already running (needs a running event loop)."""
        if self._task is None or self._task.done():
            self._shutdown = False  # Reset shutdown flag if restarting
            self.logger.debug("Starting AsyncBedWriter loop task.")
            self._task = asyncio.create_task(self._writer_loop())
            self.logger.debug("AsyncBedWriter task launched.")
        else:
            self.logger.debug("AsyncBedWriter task is already running.")

    def enqueue(self, sample: str, lines: List[str]) -> None:
        """
        Add a sample and its corresponding lines to the write queue without waiting.

        If the queue is full, the lines are dropped and an error is logged.

        Args:
            sample (str): The sample name (used for filename).
            lines (List[str]): Lines to write to the BED file. Each string should
                include its own newline character; none is added on write.
        """
        if not lines:
            self.logger.debug(f"AsyncBedWriter: enqueue called with empty lines for sample '{sample}'. Skipping.")
            return
        try:
            self.logger.debug(
                f"BED writer enqueuing {len(lines)} lines for sample '{sample}'. Current queue size: {self._queue.qsize()}"
            )
            self._queue.put_nowait((sample, lines))
        except asyncio.QueueFull:
            self.logger.error(
                f"AsyncBedWriter queue is full. Failed to enqueue {len(lines)} lines for sample '{sample}'. "
                "Consider increasing queue maxsize or slowing down producers."
            )

    async def _writer_loop(self) -> None:
        """
        Consume the queue and append lines to sample-specific BED files in batches.

        A sample's buffer is flushed once it reaches `batch_size` lines. While shutting
        down with items still queued, all buffers are flushed after each item. Every
        buffer is flushed when the loop ends or is cancelled.
        """
        self.logger.debug("BED writer loop started.")
        buffers: Dict[str, List[str]] = defaultdict(list)
        processed_items_count = 0
        try:
            self.bed_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            self.logger.error(f"Failed to create BED output directory {self.bed_dir}: {e}. Writer loop cannot proceed.")
            return
        while not self._shutdown or not self._queue.empty():
            sample_name_for_log = ""
            item_was_retrieved = False
            try:
                # The timeout lets the loop periodically re-check `_shutdown`.
                current_sample, new_lines = await asyncio.wait_for(self._queue.get(), timeout=0.5)
                item_was_retrieved = True
                sample_name_for_log = current_sample
                processed_items_count += 1
                self.logger.debug(
                    f"BED writer received item {processed_items_count} for sample '{current_sample}' "
                    f"with {len(new_lines)} lines. Queue size: {self._queue.qsize()}"
                )
                buffers[current_sample].extend(new_lines)
                if len(buffers[current_sample]) >= self.batch_size:
                    self.logger.debug(
                        f"BED writer flushing BED buffer for sample '{current_sample}' (size: {len(buffers[current_sample])})."
                    )
                    # _flush_sample clears the buffer it is given on success.
                    await self._flush_sample(current_sample, buffers[current_sample])
            except asyncio.TimeoutError:
                if self._shutdown and self._queue.empty():
                    break
                continue
            except asyncio.CancelledError:
                self.logger.info("BED writer loop was cancelled.")
                await self._flush_all(buffers)  # Don't lose buffered lines on cancellation
                raise
            except Exception as e:
                self.logger.error(f"Error in BED writer loop processing item for '{sample_name_for_log}': {e}",
                                  exc_info=True)
            finally:
                # task_done() must match exactly one successful get().
                if item_was_retrieved:
                    self._queue.task_done()
            # While shutting down with items still queued, flush eagerly.
            if self._shutdown and not self._queue.empty():
                await self._flush_all(buffers)
        self.logger.debug("BED writer loop finished. Performing final flush of all buffers.")
        await self._flush_all(buffers)
        self.logger.debug("BED writer loop has stopped.")

    async def _flush_sample(self, sample: str, buffer_lines: List[str]) -> None:
        """
        Append `buffer_lines` to ``<bed_dir>/<sample>.bed`` and clear the list on success.

        On failure the error is logged and the lines stay in the buffer for a later flush.
        """
        if not buffer_lines:
            self.logger.debug(f"No lines to flush for sample '{sample}'.")
            return
        path = self.bed_dir / f"{sample}.bed"
        try:
            async with aiofiles.open(path, mode="a", encoding="utf-8") as f:
                # Lines are expected to already carry their newline characters.
                await f.writelines(buffer_lines)
            buffer_lines.clear()
        except IOError as e:
            self.logger.error(
                f"IOError flushing sample '{sample}' to '{path}': {e}", exc_info=True
            )
        except Exception as e:
            self.logger.error(
                f"Unexpected error flushing sample '{sample}' to '{path}': {e}", exc_info=True
            )

    async def _flush_all(self, buffers: Dict[str, List[str]]) -> None:
        """Flush every non-empty buffer and drop the entries that end up empty."""
        self.logger.debug(f"Performing flush of all remaining buffers ({len(buffers)} total).")
        for sample_name, buf_list in list(buffers.items()):  # Snapshot: entries are deleted below
            if buf_list:
                await self._flush_sample(sample_name, buf_list)
            if not buf_list and sample_name in buffers:
                del buffers[sample_name]
        self.logger.debug("Flush of all buffers complete.")

    async def shutdown(self, timeout: Optional[float] = None) -> None:
        """
        Signal the writer loop to shut down and wait for it to complete.

        All pending data is flushed. If lines were enqueued but the loop was never
        started, it is started now so those lines are still written.

        Args:
            timeout: Optional maximum wait in seconds. If it is exceeded, the task is
                cancelled, which still flushes its buffers. None waits indefinitely.
        """
        self._shutdown = True
        if self._task is None and not self._queue.empty():
            self.logger.debug("BED writer was never started but has pending items; starting it to drain the queue.")
            self.start()
            self._shutdown = True  # start() resets the flag; re-assert it so the loop drains and exits
        if self._task is not None and not self._task.done():
            self.logger.debug("Waiting for BED writer task to complete...")
            try:
                await asyncio.wait_for(self._task, timeout=timeout)
                self.logger.debug("BED writer task completed successfully after shutdown signal.")
            except asyncio.TimeoutError:
                self.logger.warning(
                    "Timeout waiting for BED writer task to complete during shutdown. Attempting cancellation.")
                self._task.cancel()
                try:
                    await self._task
                except asyncio.CancelledError:
                    self.logger.info("BED writer task was cancelled during shutdown due to timeout.")
                except Exception as e_inner:
                    self.logger.error(f"Exception after cancelling BED writer task: {e_inner}", exc_info=True)
            except asyncio.CancelledError:
                self.logger.info("BED writer task was externally cancelled during its shutdown process.")
            except Exception as e:
                self.logger.error(
                    f"Exception while waiting for BED writer task during shutdown: {e}", exc_info=True
                )
        elif self._task is not None:
            self.logger.debug("BED writer task was already done prior to shutdown call.")
        else:
            self.logger.debug(
                "No BED writer task to shut down (was never started or already None)."
            )
@dataclass
class GFA:
    """
    Manage parsing of a GFA (Graphical Fragment Assembly) file, compute statistics,
    and generate related files, such as the BAM file of segments and per-sample BED files of paths.
    Fill an asynchronous database for links and an asynchronous BED writer.

    Attributes
    ----------
    gfa_path : Path
        Path to the input GFA file (can be .gfa or .gfa.gz).
    threads : int, optional
        Number of worker threads for the shared executor. Values <= 0 use the
        ThreadPoolExecutor default. Default is 1.
    logger : logging.Logger
        Logger object. Default is a logger named "GraTools".
    gfa_name : Optional[str]
        Name of the GFA file derived from `gfa_path` (without extensions). Auto-initialized.
    version : Optional[str]
        GFA version extracted from the header (e.g., "1.0"). Auto-initialized.
    header_gfa : List[str]
        List of header lines (H lines) from the GFA file. Auto-initialized.
    sample_reference : Optional[str]
        Reference sample name, potentially from GFA header (RS tag). Auto-initialized.
    bam_segments_file : Optional[Path]
        Path to the BAM file where segments (S lines) will be written. Auto-initialized.
    dict_samples_chrom : defaultdict[str, OrderedDict[str, List[str]]]
        Map sample names to an OrderedDict of chromosome names,
        which in turn map to a list of "start\\tstop" fragment strings derived from Walk (W) lines. Auto-initialized.
    dict_segments_size : defaultdict[str, int]
        Map segment IDs to their length (in base pairs). Auto-initialized.
    dict_segments_samples : defaultdict[str, List[str]]
        Map segment IDs to a list of sample identifiers ("sample;chromosome;haplotype")
        that contain the segment. Auto-initialized.
    dict_samples_bed : defaultdict[str, OrderedDict[str, Path]]
        (Note: this attribute is not directly populated by the parsing logic.
        It was likely intended to track the paths of generated BED files;
        `AsyncBedWriter` manages those paths internally.) Auto-initialized.
    works_path : Optional[Path]
        Path to the working directory (".../{gfa_name}_GraTools_INDEX"). Auto-initialized.
    bed_path : Optional[Path]
        Path to the subdirectory for BED files within `works_path`. Auto-initialized.
    bam_path : Optional[Path]
        Path to the subdirectory for BAM files within `works_path`. Auto-initialized.
    found_minigraph : bool
        Flag indicating if a sample named 'MINIGRAPH' (case-insensitive) was found in Walk lines. Defaults to False.
    index_links : bool, optional
        If True, GFA links (L lines) are stored in an SQLite database. Defaults to False.
    db_links : Optional[AsyncGfaDatabase]
        Asynchronous database handler for GFA links. Auto-initialized only if `index_links` is True.
    segment_count : int
        Total number of segments (S lines) processed. Defaults to 0.
    total_segment_length : int
        Sum of lengths of all segments. Defaults to 0.
    link_count : int
        Total number of links (L lines) processed. Defaults to 0.
    degrees : defaultdict[str, int]
        Map segment IDs to their degree (number of links connected). Defaults to an empty defaultdict.
    walks_count : int
        Total number of walks (W lines) processed. Defaults to 0.
    max_walk_rank : int
        Maximum number of segments in any single walk. Defaults to 0.
    sum_rank0_length : int
        Sum of lengths of the first segments of all walks. Defaults to 0.
    input_genome_size : int
        Cumulative size of all paths (sum of segment lengths along each walk). Defaults to 0.
    walks_info : List[Dict[str, Any]]
        List of dictionaries, each containing info for a walk (Path name, Sequence length, Num Segments). Defaults to an empty list.
    inverted_links_count : int
        Count of links where orientations differ (e.g., S1+ -> S2-). Defaults to 0.
    negative_links_count : int
        Count of links where both segments have negative orientation (S1- -> S2-). Defaults to 0.
    self_links_count : int
        Count of links where a segment links to itself (S1 -> S1). Defaults to 0.
    isolated_segments : Set[str]
        Set of segment IDs that have no links connected to them. It is filled with all segments, then linked ones are removed. Defaults to an empty set.
    shared_executor : Optional[ThreadPoolExecutor]
        Executor for running synchronous tasks in threads. Auto-initialized.
    progress : Optional[Progress]
        Rich Progress instance for displaying progress. Auto-initialized.
    line_type_counts : Counter
        Counts of each GFA line type (H, S, L, W, P, C, E, U). Auto-initialized.
    header_gfa_file : Optional[Path]
        Path to where the GFA header is saved. Auto-initialized.
    stats_file : Optional[Path]
        Path to where GFA statistics are saved. Auto-initialized.
    db_file_path : Optional[Path]
        Path to the SQLite database file for links. Auto-initialized.
    RE_ORIENTED_SEG_GT_LT : re.Pattern
        Compiled regular expression for parsing oriented segments using '>' and '<'. Auto-initialized.
    RE_ORIENTED_SEG_PLUS_MINUS : re.Pattern
        Compiled regular expression for parsing oriented segments using '+' and '-'. Auto-initialized.
    """
    gfa_path: Path
    threads: int = 1
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    gfa_name: Optional[str] = None
    version: Optional[str] = None
    header_gfa: List[str] = field(default_factory=list, repr=True)
    sample_reference: Optional[str] = None
    bam_segments_file: Optional[Path] = None
    # Dictionaries populated during parsing
    # Key: sample, Value: OrderedDict[chrom, list_of_fragments]
    dict_samples_chrom: defaultdict = field(
        default_factory=lambda: defaultdict(OrderedDict), repr=False
    )
    # Key: seg_id, Value: length
    dict_segments_size: defaultdict = field(
        default_factory=lambda: defaultdict(int), repr=False
    )
    # Key: seg_id, Value: list of "sample;chrom;haplo"
    dict_segments_samples: defaultdict = field(
        default_factory=lambda: defaultdict(list), repr=False
    )
    # Not populated by the GFA parsing itself; presumably tracks BED outputs.
    # Key: sample, Value: OrderedDict[chrom, path_to_bed_for_chrom_walk]
    dict_samples_bed: defaultdict = field(
        default_factory=lambda: defaultdict(OrderedDict), repr=False
    )
    # Paths
    works_path: Optional[Path] = None
    bed_path: Optional[Path] = None
    bam_path: Optional[Path] = None
    found_minigraph: bool = False
    index_links: bool = False  # When True, __post_init__ creates `db_links`
    db_links: Optional[AsyncGfaDatabase] = None  # Instantiated in __post_init__
    # Statistics counters
    segment_count: int = 0
    total_segment_length: int = 0
    link_count: int = 0
    degrees: defaultdict = field(default_factory=lambda: defaultdict(int), repr=False)  # segment_id -> degree
    walks_count: int = 0
    max_walk_rank: int = 0  # Max number of segments in a walk
    sum_rank0_length: int = 0  # Sum of lengths of first segments of walks
    input_genome_size: int = 0
    walks_info: List[Dict[str, Any]] = field(default_factory=list, repr=False)
    inverted_links_count: int = 0
    negative_links_count: int = 0
    self_links_count: int = 0
    isolated_segments: set = field(default_factory=set, repr=False)  # Segment IDs
    # Internal operational attributes (initialized in __post_init__)
    shared_executor: Optional[ThreadPoolExecutor] = None
    progress: Optional[Progress] = None
    line_type_counts: Counter = field(default_factory=Counter)
    header_gfa_file: Optional[Path] = None
    stats_file: Optional[Path] = None
    db_file_path: Optional[Path] = None
    RE_ORIENTED_SEG_GT_LT: re.Pattern = field(init=False, repr=False)
    RE_ORIENTED_SEG_PLUS_MINUS: re.Pattern = field(init=False, repr=False)
def __post_init__(self) -> None:
    """
    Initialize paths, directories, helpers and the link database, then run GFA processing.

    Steps, in order:
    1. Check that ``gfa_path`` is an existing regular file.
    2. Derive ``gfa_name`` by stripping ``.gfa.gz`` / ``.gfa``, falling back to ``Path.stem``.
    3. Create ``{gfa_name}_GraTools_INDEX`` with its ``bed_files`` and ``bam_files``
       subdirectories, and set the derived output file paths.
    4. Compile the oriented-segment regexes. Create the shared thread pool, the link
       database (only when ``index_links`` is True) and the progress bar.
    5. Call ``run()``, ``save_header()``, ``tag_bam()``, ``save_samples_chrom()`` and
       ``compute_statistics()``, then log the total elapsed time.

    Raises
    ------
    FileNotFoundError
        If ``gfa_path`` does not exist or is not a regular file.
    """
    start_time = datetime.now()
    # is_file() is also False for missing paths, so this covers both conditions.
    if not self.gfa_path.is_file():
        self.logger.error(
            f"GFA file '{self.gfa_path}' does not exist or is not a file."
        )
        raise FileNotFoundError(f"GFA file not found: {self.gfa_path}")  # Fail fast
    # Determine GFA name
    name_to_process = self.gfa_path.name
    if name_to_process.endswith(".gfa.gz"):
        self.gfa_name = name_to_process.removesuffix(".gfa.gz")
    elif name_to_process.endswith(".gfa"):
        self.gfa_name = name_to_process.removesuffix(".gfa")
    else:
        self.gfa_name = self.gfa_path.stem  # Fallback, might include partial extensions
        self.logger.warning(
            f"GFA file '{self.gfa_path.name}' has an unusual extension. Using stem: '{self.gfa_name}'.")
    # Setup working paths
    self.works_path = self.gfa_path.parent / f"{self.gfa_name}_GraTools_INDEX"
    self.works_path.mkdir(exist_ok=True, parents=True)
    self.bed_path = self.works_path / "bed_files"
    self.bed_path.mkdir(exist_ok=True, parents=True)
    self.bam_path = self.works_path / "bam_files"
    self.bam_path.mkdir(exist_ok=True, parents=True)
    self.bam_segments_file = self.bam_path / f"{self.gfa_name}.bam"
    self.header_gfa_file = self.works_path / f"header_{self.gfa_name}.txt"
    self.stats_file = self.works_path / f"stats_{self.gfa_name}.txt"
    self.db_file_path = self.works_path / f"links_{self.gfa_name}.db"
    # Precompile regexes used while parsing
    self.RE_ORIENTED_SEG_GT_LT = re.compile(r"([<>])([^<>]+)")
    self.RE_ORIENTED_SEG_PLUS_MINUS = re.compile(r"(.+?)([+-])")
    # Non-positive thread counts fall back to the executor's default worker count.
    self.shared_executor = ThreadPoolExecutor(
        max_workers=self.threads if self.threads > 0 else None)
    if self.index_links:
        self.db_links = AsyncGfaDatabase(
            db_file=self.db_file_path,
            # Seconds: used as the aiosqlite connect timeout and, x1000, as PRAGMA busy_timeout (ms).
            timeout=5.0,
        )
    self.progress = Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "{task.completed}/{task.total} lines",
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        refresh_per_second=1,  # Higher refresh can be costly
        transient=True,  # Progress bar disappears after completion
        console=shared_console
    )
    # Main parsing, then post-processing that relies on the parsed data.
    self.run()
    self.save_header()
    self.tag_bam()
    self.save_samples_chrom()
    self.compute_statistics()
    # Measured from the start of __post_init__ (start_time must not be reset here).
    elapsed_time = datetime.now() - start_time
    self.logger.info(f"GFA processing and indexing for '{self.gfa_name}' completed in: {elapsed_time}")
def save_samples_chrom(self) -> None:
    """
    Write the sample / chromosome / fragment table collected from GFA walk (W) lines.

    The file ``samples_chrom.txt`` is (re)created inside ``self.works_path``. It holds one
    tab-separated row per stored fragment: the sample name, the chromosome name, then the
    ``"start<TAB>stop"`` string recorded for that fragment. Samples and chromosomes are
    emitted in natural sort order so repeated runs produce identical files.
    An I/O failure is logged and swallowed rather than raised.
    """
    destination = self.works_path / "samples_chrom.txt"
    self.logger.info(f"Saving sample-chromosome fragment data to '{destination.name}'")
    try:
        with open(destination, "w", encoding="utf-8") as handle:
            samples_to_chroms = self.dict_samples_chrom
            for sample_id in natsorted(samples_to_chroms.keys()):
                chrom_to_fragments = samples_to_chroms[sample_id]
                for chrom_id in natsorted(chrom_to_fragments.keys()):
                    # Each stored fragment is already formatted as "start\tstop".
                    handle.writelines(
                        f"{sample_id}\t{chrom_id}\t{span}\n"
                        for span in chrom_to_fragments[chrom_id]
                    )
    except IOError as err:
        self.logger.error(f"Failed to save samples_chrom.txt file '{destination}': {err}")
def _process_single_bed(self, bed_path: Path, task_id: TaskID) -> None:
    """
    Sort one BED file in place with ``BedTool.sort`` and advance a Rich progress task.

    The sorted output is first written to a sibling ``<stem>.sorted.bed`` file, which then
    replaces ``bed_path`` via ``Path.replace``. The original file is therefore only
    overwritten once sorting has succeeded. On any exception the error is logged (not
    raised), the temporary file is removed, and the task description is switched to an
    error message instead of advancing. The method is submitted to a thread pool by the
    BED-sorting step.

    Parameters
    ----------
    bed_path : Path
        BED file to sort in place.
    task_id : TaskID
        Rich progress task that is advanced by one on success.
    """
    self.logger.debug(f"Sorting BED file: {bed_path.name}")
    tmp_path: Optional[Path] = None
    try:
        tmp_path = bed_path.with_suffix(".sorted.bed")
        BedTool(bed_path.as_posix()).sort(output=tmp_path.as_posix())
        tmp_path.replace(bed_path)
        self.logger.debug(f"Finished sorting BED file: {bed_path.name}")
    except Exception as e:  # Bedtools can raise various errors
        self.logger.error(f"Error sorting BED file '{bed_path.name}': {e}", exc_info=True)
        # Do not leave a partially written temporary file behind.
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError as cleanup_error:
                self.logger.warning(
                    f"Could not remove temporary file '{tmp_path.name}': {cleanup_error}"
                )
        if self.progress:
            self.progress.update(task_id, description=f"[red]Error sorting {bed_path.name}")
        return
    if self.progress:
        self.progress.update(task_id, advance=1)
async def _read_gfa_lines_async(self, read_hint_chars: int = 1 << 20):
    """
    Asynchronously yield the lines of ``self.gfa_path`` (plain text or ``.gz``), decoded as UTF-8.

    Lines are read in batches with ``readlines(read_hint_chars)``, so blocking disk I/O and
    gzip decompression both happen off the event loop:

    * gzip files are opened and read through ``asyncio.to_thread``;
    * plain files are opened with ``aiofiles``, whose ``readlines`` runs in an executor.

    Previously, gzip lines were iterated synchronously on the event loop, and plain files
    paid one executor round-trip per line. Each line is yielded unchanged, including its
    trailing newline.

    If the file cannot be opened, the error is logged and the generator ends without
    yielding. The file is closed when iteration completes, fails, or the generator is
    closed.

    Parameters
    ----------
    read_hint_chars : int, optional
        Size hint passed to ``readlines``. Each batch stops once roughly this many
        characters have been read, and always contains at least one whole line.
        Non-positive values fall back to the default of 1 MiB.

    Yields
    ------
    str
        One raw GFA line at a time.
    """
    if read_hint_chars <= 0:
        read_hint_chars = 1 << 20
    is_gzipped = self.gfa_path.name.endswith(".gz")
    if is_gzipped:
        self.logger.debug(f"Reading gzipped GFA file: {self.gfa_path.name} using gzip in thread.")
        try:
            # gzip.open is synchronous: open it in a worker thread.
            f = await asyncio.to_thread(gzip.open, self.gfa_path, "rt", encoding="utf-8")
        except Exception as e:
            self.logger.error(f"Failed to open gzipped file {self.gfa_path.name} with gzip: {e}")
            return  # Stop generation
    else:
        self.logger.debug(f"Reading plain text GFA file: {self.gfa_path.name} using aiofiles.")
        try:
            f = await aiofiles.open(self.gfa_path, mode="r", encoding="utf-8")
        except Exception as e:
            self.logger.error(f"Failed to open plain text file {self.gfa_path.name} with aiofiles: {e}")
            return  # Stop generation
    try:
        while True:
            if is_gzipped:
                # Decompression + decoding happen in the worker thread, not on the loop.
                batch = await asyncio.to_thread(f.readlines, read_hint_chars)
            else:
                batch = await f.readlines(read_hint_chars)
            if not batch:  # readlines() returns [] at EOF
                break
            for line in batch:
                yield line
    finally:
        self.logger.debug(f"Closing GFA file: {self.gfa_path.name}")
        if is_gzipped:
            await asyncio.to_thread(f.close)
        else:
            await f.close()
[docs]
async def parse_gfa(self) -> None:
"""
Parses the GFA file line by line, processing headers, segments, links, and walks.
Segments are written to a BAM file. Links are (optionally) stored in an async SQLite DB.
Walk information is used to generate data for BED files, written by AsyncBedWriter.
"""
# Standard GFA header for output BAM if input GFA has no SQ lines.
# Pysam's AlignmentFile requires a header. If GFA S lines imply @SQ, it would be better.
# For now, a minimal header.
bam_header = {
'HD': {'VN': '1.8', 'SO': 'unsorted'}, # SO: Coordinate or unsorted, depending on S lines
'SQ': []
# Placeholder for sequence dictionary. Ideally, populate from S lines if reference context is known.
}
# Pre-calculate total lines for progress bar (can be slow for very large files)
total_lines = 0
try:
open_fn = gzip.open if self.gfa_path.name.endswith(".gz") else open
with open_fn(self.gfa_path, "rt", encoding="utf-8") as f_count:
total_lines = sum(1 for _ in f_count)
self.logger.info(f"Starting GFA parsing for: '{self.gfa_name}' with {total_lines} lines")
except Exception as e:
self.logger.warning(f"Could not count lines in GFA file for progress bar: {e}. Progress may be inaccurate.")
total_lines = 1 # Avoid division by zero if progress bar uses it
# Initialize AsyncBedWriter
bed_writer = AsyncBedWriter(
bed_dir=self.bed_path,
batch_size=100000, # Tunable parameter
progress=self.progress # Pass Rich progress instance if needed by BedWriter
)
bed_writer.start() # Start its writer loop
# Connect to AsyncGfaDatabase if indexing links
if self.index_links and self.db_links:
await self.db_links.connect()
links_batch: List[LinkInfo] = [] # Use List for type hinting, will store LinkInfo
link_batch_size = 100000 # How many links to batch before writing to DB
# Add task to Rich Progress
parse_task_id = self.progress.add_task(
"GFA parsing...", total=total_lines, visible=True
)
# Open BAM file for writing segments
# pysam.AlignmentFile is synchronous. Writes happen in the main async thread.
# If BAM writing is slow, it can block the event loop.
# OPTIMIZATION_SUGGESTION: For very high throughput GFA S lines,
# consider batching AlignedSegment objects and writing them in a separate thread
# using `await asyncio.to_thread(bam_file_out.write, batched_segment)`.
try:
with AlignmentFile(
self.bam_segments_file.as_posix(), "wb", header=bam_header, threads=1
# Use self.threads for multi-threading if pysam supports for write
) as bam_file_out, self.progress: # `with self.progress` starts/stops the Rich Progress context
current_line_type = None
lines_processed_count = 0
async for line_content in self._read_gfa_lines_async():
lines_processed_count += 1
self.progress.advance(parse_task_id) # Advance progress for each line
line_content = line_content.strip() # Strip newline and surrounding whitespace
if not line_content or line_content.startswith("#"): # Skip empty or comment lines
continue
fields = line_content.split("\t")
line_type = fields[0]
if line_type != current_line_type:
current_line_type = line_type
self.progress.update(
parse_task_id,
description=f"GFA parsing: {current_line_type}"
)
self.line_type_counts[line_type] += 1
match line_type:
case "H": # Header line
self.header_gfa.append(line_content)
# Basic header parsing for version and reference sample (if present)
# Example H line: H VN:Z:1.0 RS:Z:ref_sample
for field in fields[1:]:
if field.startswith("VN:Z:"):
self.version = field[5:]
elif field.startswith(
"RS:Z:"): # Reference Sample tag (custom or specific GFA variant?)
self.sample_reference = field[5:]
case "S": # Segment line
self._parse_segment(fields, bam_file_out) # Pass fields list
case "L": # Links line
self.link_count += 1
if self.index_links and self.db_links:
link_info = self._parse_link_line(fields) # Pass fields list
if link_info:
links_batch.append(link_info)
if len(links_batch) >= link_batch_size:
await self.db_links.batch_insert_links(list(links_batch)) # Pass a copy
links_batch.clear()
await asyncio.sleep(0) # Yield control after a DB batch
case "W": # Walk line
# W <sample_name> <haplotype_index> <chr_name> <chr_start> <chr_stop> <segment_path>
# For GFA2 OrderedGroup (OG lines), parsing would be different. Assuming GFA1 W.
parse_result = self._parse_walk_line(fields) # Pass fields list
if parse_result:
sample_name, bed_lines_for_sample = parse_result
if bed_lines_for_sample: # Ensure there are lines to enqueue
bed_writer.enqueue(sample_name, bed_lines_for_sample)
# Yield control if many walks are processed rapidly
if self.walks_count % 10 == 0: # Every 2 walks
await asyncio.sleep(0)
# Other GFA line types (P - Path, C - Containment, E - Edge (GFA2), U - Gap (GFA2), etc.)
# Add parsing logic here if needed. For now, they are counted by line_type_counts.
# Periodically yield control to event loop, especially if line processing is fast
if lines_processed_count % 100 == 0: # Every 100 lines
await asyncio.sleep(0)
except Exception as e:
self.logger.error(f"Error during GFA parsing main loop: {e}", exc_info=True)
# Ensure progress bar is stopped or marked as errored if possible
self.progress.update(parse_task_id, description=f"[red]Error during GFA parsing: {e}",
completed=lines_processed_count)
# Proceed to shutdown, resources might be in an inconsistent state.
finally:
# Finalize operations after loop (or if error occurred)
self.progress.remove_task(parse_task_id) # Or stop if not transient
self.logger.info(f"Finished reading GFA file. Processed {lines_processed_count} lines.")
if self.found_minigraph:
self.logger.warning(
"Sample 'MINIGRAPH' (or similar) detected and ignored. "
"See Cactus issue https://github.com/ComparativeGenomicsToolkit/cactus/pull/1050 for context."
)
# Flush any remaining links to the database
if self.index_links and self.db_links and links_batch:
self.logger.debug(f"Inserting final batch of {len(links_batch)} links into database.")
await self.db_links.batch_insert_links(list(links_batch))
# Shutdown AsyncGfaDatabase (waits for queue, creates indexes, closes connection)
if self.db_links:
self.logger.info(
f"Closing links database and saved to '{self.db_file_path.name if self.db_file_path else 'N/A'}'.")
await self.db_links.close()
# Shutdown AsyncBedWriter (waits for queue, flushes all buffers, closes files)
self.logger.info("Shutting down BED writer...")
await bed_writer.shutdown()
[docs]
def run(self):
    """
    Synchronous entry point: parse the GFA asynchronously, then sort the BED files.

    A fresh asyncio event loop is created and installed. It picks up the globally
    configured event-loop policy and uses ``self.shared_executor`` as its default
    executor. ``parse_gfa`` is run to completion on this loop.

    Afterwards, whether or not parsing raised, pending async generators are finalized,
    the shared executor is shut down, and the loop is closed. The per-sample BED files
    are then sorted in place by ``_sort_bed_files``. Errors are logged, not raised.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    if self.shared_executor:
        loop.set_default_executor(self.shared_executor)
    try:
        try:
            loop.run_until_complete(self.parse_gfa())
            self.logger.info("Async GFA parsing complete. Proceeding to sort BED files.")
        finally:
            # Release the loop and the executor even if parse_gfa raised.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                if self.shared_executor:
                    self.logger.debug("Shutting down shared ThreadPoolExecutor.")
                    self.shared_executor.shutdown(wait=True)
                loop.close()
        self._sort_bed_files()
    except Exception as e:
        self.logger.error(f"Error during GFA.run execution: {e}", exc_info=True)
    finally:
        self.logger.debug("GFA.run() finished.")

def _sort_bed_files(self) -> None:
    """Sort in place, in parallel, every non-empty ``<sample>.bed`` file of the walk samples."""
    # The BED files are located by sample name (keys of dict_samples_chrom, filled from
    # W lines); presumably this mirrors AsyncBedWriter's `<sample>.bed` naming — verify
    # if that writer changes.
    beds_to_sort: List[Path] = []
    if self.bed_path and self.bed_path.exists():
        for sample_name in self.dict_samples_chrom.keys():
            potential_bed_file = self.bed_path / f"{sample_name}.bed"
            if potential_bed_file.exists() and potential_bed_file.stat().st_size > 0:
                beds_to_sort.append(potential_bed_file)
    if not beds_to_sort:
        self.logger.info("No BED files found or specified for sorting.")
        return

    self.logger.info(f"Found {len(beds_to_sort)} BED files to sort.")
    sort_task_id = self.progress.add_task(
        f"Sorting {len(beds_to_sort)} BED files",
        total=len(beds_to_sort),
        visible=True,
    )
    # Like the shared executor, a non-positive thread count means "let the pool decide":
    # ThreadPoolExecutor(max_workers=0) would raise ValueError.
    bed_sort_workers = min(self.threads, len(beds_to_sort)) if self.threads > 0 else None
    with ThreadPoolExecutor(max_workers=bed_sort_workers) as bed_sort_executor:
        with self.progress:
            futures = [
                bed_sort_executor.submit(self._process_single_bed, bed_file_path, sort_task_id)
                for bed_file_path in beds_to_sort
            ]
            # Without a timeout, wait() only returns once every future has finished.
            done, _ = wait(futures)
        for future in done:
            error = future.exception()
            if error is not None:
                self.logger.error(f"Task error while sorting BED file: {error}", exc_info=error)
    self.progress.remove_task(sort_task_id)
    self.logger.info("All BED file sorting tasks complete.")
def _parse_walk_line(self, fields: List[str]) -> Optional[Tuple[str, List[str]]]:
    """
    Parse a GFA Walk (W) line into BED lines and update walk/segment bookkeeping.

    Expected layout: ``W <sample> <hap_idx> <chr> <start> <stop> <path>``. ``<path>`` is
    either ``>s1<s2>s3`` style (``>`` = forward, ``<`` = reverse) or ``s1+s2-s3+`` style.

    Each segment of the walk yields one BED line::

        chr  start  end  <orient><seg_id>  <start>:<stop>  hap_idx

    Coordinates start at the walk's ``<start>`` value and advance by each segment's length
    from ``self.dict_segments_size``. Unknown segments count as 0 bp, with a warning.

    The walk is skipped (``None`` returned, no state changed) in three cases:

    * there are fewer than 7 fields;
    * ``<start>`` is not an integer;
    * ``<path>`` contains no segments.

    Samples whose name contains "MINIGRAPH" (case-insensitive) are skipped as well, after
    setting ``self.found_minigraph``.

    For a parsed walk, the following state is updated:

    * ``dict_samples_chrom`` (the ``"start<TAB>stop"`` fragment);
    * ``dict_segments_samples`` (``"sample;chrom;hap"`` per segment);
    * ``walks_count``, ``max_walk_rank``, ``sum_rank0_length``;
    * ``input_genome_size`` and ``walks_info``.

    Parameters
    ----------
    fields : List[str]
        Tab-separated fields of the W line.

    Returns
    -------
    Optional[Tuple[str, List[str]]]
        ``(sample_name, bed_lines)`` where each BED line ends with a newline, or ``None``.
    """
    try:
        sample_name = fields[1]
        haplotype_index_str = fields[2]
        chromosome_name = fields[3]
        chromosome_start_str = fields[4]  # Base convention depends on the GFA producer
        chromosome_stop_str = fields[5]
        gfa_path_specification = fields[6]
    except IndexError:
        self.logger.warning(f"Malformed W line (not enough fields): '{fields}'. Skipping.")
        return None

    if "MINIGRAPH" in sample_name.upper():
        self.found_minigraph = True
        self.logger.debug(f"Skipping W line for 'MINIGRAPH' sample: {fields}")
        return None

    # Validate the start coordinate here: an uncaught ValueError would abort the whole parse.
    try:
        walk_start = int(chromosome_start_str)
    except ValueError:
        self.logger.warning(
            f"Malformed W line (non-integer start '{chromosome_start_str}'): '{fields}'. Skipping."
        )
        return None

    if ">" in gfa_path_specification or "<" in gfa_path_specification:
        oriented_segments_in_path = [
            ("+" if orient == ">" else "-", seg_id)
            for orient, seg_id in self.RE_ORIENTED_SEG_GT_LT.findall(gfa_path_specification)
        ]
    else:
        oriented_segments_in_path = [
            (orient, seg_id)
            for seg_id, orient in self.RE_ORIENTED_SEG_PLUS_MINUS.findall(gfa_path_specification)
        ]
    if not oriented_segments_in_path:
        self.logger.warning(f"Could not parse path spec in W line: '{gfa_path_specification}'. Skipping.")
        return None

    # The walk is usable: only now record its fragment, so skipped walks leave no trace.
    self.dict_samples_chrom[sample_name].setdefault(chromosome_name, []).append(
        f"{chromosome_start_str}\t{chromosome_stop_str}"
    )

    segment_sizes = self.dict_segments_size
    segment_samples = self.dict_segments_samples
    num_segments_in_walk = len(oriented_segments_in_path)
    self.walks_count += 1
    self.max_walk_rank = max(self.max_walk_rank, num_segments_in_walk)
    self.sum_rank0_length += segment_sizes.get(oriented_segments_in_path[0][1], 0)

    sample_chromosome_haplotype_id = f"{sample_name};{chromosome_name};{haplotype_index_str}"
    fragment_id = f"{chromosome_start_str}:{chromosome_stop_str}"
    bed_lines_for_this_walk: List[str] = []
    position = walk_start
    for segment_orientation_char, segment_id in oriented_segments_in_path:
        segment_samples[segment_id].append(sample_chromosome_haplotype_id)
        segment_length = segment_sizes.get(segment_id, 0)
        if segment_length == 0:
            self.logger.warning(
                f"Segment '{segment_id}' in walk for sample '{sample_name}' has zero length "
                f"or size not found. BED entry may be inaccurate."
            )
        segment_end = position + segment_length
        bed_lines_for_this_walk.append(
            f"{chromosome_name}\t{position}\t{segment_end}\t"
            f"{segment_orientation_char}{segment_id}\t{fragment_id}\t{haplotype_index_str}\n"
        )
        position = segment_end

    current_walk_total_length_bp = position - walk_start  # Sum of segment lengths
    self.input_genome_size += current_walk_total_length_bp
    self.walks_info.append({
        "PathName": f"{sample_name}_{chromosome_name}_{haplotype_index_str}",
        "SequenceLength_bp": current_walk_total_length_bp,
        "NumberOfSegments": num_segments_in_walk,
    })
    return sample_name, bed_lines_for_this_walk
def _parse_link_line(self, fields: List[str]) -> Optional[LinkInfo]:
    """
    Parse a GFA Link (L) line into a ``LinkInfo`` and update link statistics.

    Expected layout: ``L <seg1> <orient1> <seg2> <orient2> <overlap>``. The overlap field
    is not used.

    Updates ``self_links_count``, ``inverted_links_count``, ``negative_links_count``,
    ``degrees`` and ``isolated_segments``.

    ``self.link_count`` is deliberately NOT incremented here. ``parse_gfa`` already
    increments it for every L line before calling this method, so counting again would
    double the link total whenever link indexing is enabled.

    Parameters
    ----------
    fields : List[str]
        Tab-separated fields of the L line.

    Returns
    -------
    Optional[LinkInfo]
        Parsed link, with orientations as +1/-1 (the ``orient_key`` fields repeat them).
        Returns ``None``, logging a warning and leaving statistics untouched, when the line
        has fewer than 5 fields or an orientation other than ``+``/``-``.
    """
    try:
        seg_id_1 = fields[1]
        orient_char_1 = fields[2]
        seg_id_2 = fields[3]
        orient_char_2 = fields[4]
    except IndexError:
        self.logger.warning(f"Malformed L line (not enough fields): '{fields}'. Skipping.")
        return None
    # Previously any non-'+' value was silently treated as '-'.
    if orient_char_1 not in ("+", "-") or orient_char_2 not in ("+", "-"):
        self.logger.warning(f"Malformed L line (invalid orientation): '{fields}'. Skipping.")
        return None

    orient_seg_1_numeric = 1 if orient_char_1 == "+" else -1
    orient_seg_2_numeric = 1 if orient_char_2 == "+" else -1

    if seg_id_1 == seg_id_2:
        self.self_links_count += 1
    if orient_seg_1_numeric != orient_seg_2_numeric:  # One forward, one reverse
        self.inverted_links_count += 1
    if orient_seg_1_numeric == orient_seg_2_numeric == -1:  # Both reverse
        self.negative_links_count += 1

    self.degrees[seg_id_1] += 1
    self.degrees[seg_id_2] += 1
    # A linked segment is no longer isolated.
    self.isolated_segments.discard(seg_id_1)
    self.isolated_segments.discard(seg_id_2)

    return LinkInfo(
        seg_id_1, orient_seg_1_numeric, orient_seg_1_numeric,
        seg_id_2, orient_seg_2_numeric, orient_seg_2_numeric,
    )
def _parse_segment(self, fields: List[str], bam_file_out: AlignmentFile) -> None:
    """
    Parse a GFA Segment (S) line, update segment statistics and write it to the BAM file.

    Expected layout: ``S <seg_id> <sequence> [TAG:TYPE:VALUE ...]``.

    The sequence may be ``*`` (sequence not stored, as allowed by GFA1). The segment
    length is then taken from the ``LN:i:<length>`` tag, or 0 if that tag is absent or
    not an integer, and the BAM record is written without sequence, qualities or CIGAR.
    Previously ``*`` was stored as a 1-bp sequence.

    Side effects:

    * sets ``self.dict_segments_size[seg_id]``;
    * increments ``self.segment_count`` and ``self.total_segment_length``;
    * adds ``seg_id`` to ``self.isolated_segments``, from which link parsing later
      discards linked segments.

    Optional tags are stored comma-joined in the ``SU:Z`` BAM tag. BAM write errors are
    logged, not raised.

    Parameters
    ----------
    fields : List[str]
        Tab-separated fields of the S line.
    bam_file_out : pysam.AlignmentFile
        BAM file opened for writing.
    """
    try:
        seg_id = fields[1]
        sequence = fields[2]
    except IndexError:
        self.logger.warning(f"Malformed S line (not enough fields): '{fields}'. Skipping.")
        return
    optional_tags_list = fields[3:]  # "TAG:TYPE:VALUE" strings

    has_sequence = sequence != "*"
    if has_sequence:
        seg_length = len(sequence)
    else:
        seg_length = 0
        for tag in optional_tags_list:
            if tag.startswith("LN:i:"):
                try:
                    seg_length = int(tag[5:])
                except ValueError:
                    self.logger.warning(f"Segment '{seg_id}' has an invalid LN tag '{tag}'; using length 0.")
                break

    self.dict_segments_size[seg_id] = seg_length
    self.segment_count += 1
    self.total_segment_length += seg_length
    self.isolated_segments.add(seg_id)  # Removed later if a link touches it

    segment_for_bam = AlignedSegment()
    segment_for_bam.query_name = seg_id
    if has_sequence:
        segment_for_bam.query_sequence = sequence
        # NOTE(review): '*' decodes to Phred 9 (offset 33), not the BAM "missing quality"
        # marker; kept for output compatibility — confirm this is intended.
        segment_for_bam.query_qualities = qualitystring_to_array("*" * seg_length)
    segment_for_bam.flag = 4  # Unmapped
    segment_for_bam.reference_id = -1  # No reference
    segment_for_bam.reference_start = 0
    segment_for_bam.mapping_quality = 0
    if has_sequence and seg_length > 0:
        segment_for_bam.cigartuples = [(0, seg_length)]  # One match op spanning the sequence
    if optional_tags_list:
        segment_for_bam.set_tag("SU", ",".join(optional_tags_list), value_type='Z')

    # Synchronous pysam write (runs on the caller's thread).
    try:
        bam_file_out.write(segment_for_bam)
    except Exception as e:
        self.logger.error(f"Error writing segment '{seg_id}' to BAM: {e}", exc_info=True)
[docs]
def tag_bam(self) -> None:
    """
    Add sample-walk (SW) information to the segments BAM through ``GratoolsBam.tag``.

    ``self.dict_segments_samples`` (segment id -> list of ``"sample;chrom;haplotype"``
    strings collected from W lines) is passed to ``GratoolsBam.tag``. That call is expected
    to overwrite ``self.bam_segments_file`` in place, and a warning is logged if it reports
    a different path. Nothing is done when the BAM file is unset or missing. Errors are
    logged, not raised.
    """
    bam_file = self.bam_segments_file
    # Check before dereferencing `.name`: the attribute may be None.
    if not bam_file or not bam_file.exists():
        self.logger.error("BAM segments file not found. Cannot perform tagging.")
        return
    self.logger.info(f"Initiating BAM tagging for '{bam_file.name}'.")
    if not self.dict_segments_samples:
        # Tagging still proceeds: the downstream step may expect the BAM to be processed anyway.
        self.logger.warning("dict_segments_samples is empty. No SW tags will be added.")

    bam_manager = GratoolsBam(
        bam_path=bam_file,
        threads=self.threads,
        logger=self.logger,
        tagging=False,  # NOTE(review): semantics of `tagging` are defined by GratoolsBam — confirm
    )
    try:
        updated_bam_path = bam_manager.tag(dict_segments_samples=self.dict_segments_samples)
        # Compare as Path so a str return equal to the expected path is not reported.
        if updated_bam_path is None or Path(updated_bam_path) != bam_file:
            self.logger.warning(f"BAM tagging returned a new path '{updated_bam_path}', "
                                f"but expected overwrite of '{bam_file.name}'. Check GratoolsBam.tag logic.")
    except Exception as e:
        self.logger.error(f"Error during BAM tagging process: {e}", exc_info=True)
def _get_connected_component(self, start_seg: str, cursor: sqlite3.Cursor) -> Set[str]:
    """
    Return every segment connected to ``start_seg`` through the ``links`` table.

    Links are treated as undirected: ``(a, b)`` connects ``a`` to ``b`` and ``b`` to ``a``.
    The traversal is a recursive CTE with a single recursive SELECT.

    The previous form had two UNIONed recursive SELECTs. SQLite only accepts that from
    version 3.34.0 onwards, and on older versions the resulting error was masked by the
    fallback below. ``UNION`` discards already-visited segments, so cycles terminate.

    Parameters
    ----------
    start_seg : str
        Segment ID where the traversal starts.
    cursor : sqlite3.Cursor
        Open cursor on a database containing a ``links(seg_id_1, seg_id_2, ...)`` table.

    Returns
    -------
    Set[str]
        The component's segment IDs, always including ``start_seg``. On an SQLite error
        the error is logged and ``{start_seg}`` is returned.
    """
    # Indexes on seg_id_1 and on seg_id_2 let SQLite serve the OR join with index lookups.
    query = """
        WITH RECURSIVE connected_component(segment_id) AS (
            VALUES(?)
            UNION
            SELECT CASE
                       WHEN links.seg_id_1 = connected_component.segment_id THEN links.seg_id_2
                       ELSE links.seg_id_1
                   END
            FROM links
            JOIN connected_component
              ON links.seg_id_1 = connected_component.segment_id
              OR links.seg_id_2 = connected_component.segment_id
        )
        SELECT segment_id FROM connected_component;
    """
    try:
        cursor.execute(query, (start_seg,))
        return {row[0] for row in cursor.fetchall()}
    except sqlite3.Error as e:
        self.logger.error(f"SQLite error getting connected component for '{start_seg}': {e}", exc_info=True)
        return {start_seg}
[docs]
def compute_statistics(self) -> Dict[str, Any]:
    """
    Compute summary statistics for the parsed GFA graph and save them to a file.

    Statistics are grouped into six categories: graph overview, segment
    lengths, link/degree properties, walk (path) metrics, segment sharing and
    depth across samples, and graph structure. Connected-component metrics
    are only computed when ``self.index_links`` is set, the links SQLite file
    exists, and that file contains a ``links`` table; otherwise they stay at 0.

    Progress is reported on ``self.progress`` as a 7-step task. When
    ``self.stats_file`` is set, the statistics are written to it as a
    tab-separated table with ``Category``, ``Metric`` and ``Value`` columns,
    where values are pre-formatted as text.

    Returns
    -------
    Dict[str, Any]
        Mapping of category name -> {metric name -> raw (unformatted) value}.
        It is only filled in once every step has succeeded. On an unexpected
        exception the error is logged, the progress task is marked as failed,
        and an empty dictionary is returned (a header-only table is still
        written to ``self.stats_file``).
    """
    import numbers  # stdlib ABCs, so numpy scalars are formatted like int/float

    self.logger.info("Computing GFA graph statistics...")
    current_progress = self.progress
    # Six computation steps plus the final "Formatting & Saving" step.
    total_steps_stats = 7
    stats_task_id = current_progress.add_task(
        "[bold blue]Computing graph statistics...", total=total_steps_stats
    )
    # Assigned only after all steps succeed, so it stays empty if one of them fails.
    stats_results: Dict[str, Any] = {}
    with current_progress:  # keeps the progress display running
        try:
            # Step 1: basic graph metrics.
            current_progress.update(
                stats_task_id,
                description="[cyan]Stats: Basic Graph Metrics",
                advance=0,  # the step is only counted once it completes
            )
            avg_segment_length = (
                self.total_segment_length / self.segment_count
                if self.segment_count
                else 0.0
            )
            # Average degree of an undirected graph: 2 * edges / nodes.
            avg_degree = (
                (2 * self.link_count) / self.segment_count if self.segment_count else 0.0
            )
            max_degree = max(self.degrees.values()) if self.degrees else 0
            # Segment-to-link (node-to-edge) ratio.
            ne_ratio = self.segment_count / self.link_count if self.link_count else 0.0
            graph_size_bp = self.total_segment_length  # total length of all segments
            # input_genome_size is the summed segment length along all paths, so it
            # can exceed graph_size_bp when segments are shared between paths.
            compress_ratio = self.input_genome_size / graph_size_bp if graph_size_bp else 0.0
            current_progress.update(stats_task_id, advance=1)

            # Step 2: segment length statistics.
            current_progress.update(stats_task_id, description="[cyan]Stats: Segment Lengths")
            node_lengths = list(self.dict_segments_size.values())
            median_node_length = float(median(node_lengths)) if node_lengths else 0.0
            avg_top5_percent_length = 0.0
            median_top5_percent_length = 0.0
            if node_lengths:
                # The longest 5% of segments (always at least one segment).
                num_top_5_percent = max(1, int(0.05 * len(node_lengths)))
                top_5_percent_nodes_lengths = sorted(node_lengths, reverse=True)[:num_top_5_percent]
                avg_top5_percent_length = float(mean(top_5_percent_nodes_lengths))
                median_top5_percent_length = float(median(top_5_percent_nodes_lengths))
            # Length histogram over [0, 500), [500, 1000), [1000, 2000), [2000, inf].
            # Computed even when there are no segments (all counts are then 0), so
            # the distribution is always defined.
            length_bins = [0, 500, 1000, 2000, float("inf")]
            bin_labels = ["0-499bp", "500-999bp", "1000-1999bp", "2000+bp"]
            hist_counts, _ = histogram(node_lengths, bins=length_bins)
            segment_length_bin_counts = dict(zip(bin_labels, map(int, hist_counts)))
            current_progress.update(stats_task_id, advance=1)

            # Step 3: similarity (distinct samples per segment) and depth
            # (total occurrences per segment).
            current_progress.update(stats_task_id, description="[cyan]Stats: Segment Similarity/Depth")
            similarities = []
            depths = []
            for sample_occurrences_list in self.dict_segments_samples.values():
                # Entries are "sample;chrom;haplo" strings; the sample is the first field.
                unique_samples_for_seg = {occ.split(";")[0] for occ in sample_occurrences_list}
                similarities.append(len(unique_samples_for_seg))
                depths.append(len(sample_occurrences_list))
            sim_mean = float(mean(similarities)) if similarities else 0.0
            sim_median = float(median(similarities)) if similarities else 0.0
            sim_std = float(std(similarities)) if similarities else 0.0
            depth_mean = float(mean(depths)) if depths else 0.0
            depth_median = float(median(depths)) if depths else 0.0
            depth_std = float(std(depths)) if depths else 0.0
            current_progress.update(stats_task_id, advance=1)

            # Step 4: density of an undirected graph, 2*L / (N*(N-1)).
            current_progress.update(stats_task_id, description="[cyan]Stats: Graph Density")
            if self.segment_count > 1:
                graph_density = (2 * self.link_count) / (
                    self.segment_count * (self.segment_count - 1)
                )
            else:
                graph_density = 0.0  # undefined for graphs with 0 or 1 segment
            current_progress.update(stats_task_id, advance=1)

            # Step 5: connectivity (requires the SQLite links database).
            current_progress.update(stats_task_id, description="[cyan]Stats: Connectivity (DB query)")
            num_connected_components = 0
            largest_cc_size_bp = 0  # total length (bp) of the largest component
            num_disconnected_components_count = 0  # components other than the largest
            total_length_disconnected_bp = 0
            if self.index_links and self.db_links and self.db_links.db_file and self.db_links.db_file.exists():
                conn_sqlite = None
                try:
                    # A plain synchronous connection is used for these read queries.
                    conn_sqlite = sqlite3.connect(self.db_links.db_file.as_posix(), timeout=0.5)
                    cursor = conn_sqlite.cursor()
                    # WAL mode is persistent, so this is a no-op if the writer already set it.
                    cursor.execute("PRAGMA journal_mode=WAL;")
                    # Only affects how commits are flushed; this connection only reads.
                    cursor.execute("PRAGMA synchronous=OFF;")
                    # A negative cache_size is in KiB: this allows up to ~8,000,000 KiB
                    # (about 7.6 GiB) of page cache for this connection.
                    cursor.execute("PRAGMA cache_size = -8000000;")
                    cursor.execute("PRAGMA temp_store = MEMORY;")
                    # Without a links table every per-segment component query below
                    # would fail and log its own error, so check for it once up front.
                    cursor.execute(
                        "SELECT 1 FROM sqlite_master WHERE type = 'table' AND name = 'links';"
                    )
                    if cursor.fetchone() is None:
                        self.logger.warning(
                            f"Links database '{self.db_links.db_file}' has no 'links' table; "
                            "skipping connectivity statistics."
                        )
                    else:
                        segment_sizes = self.dict_segments_size
                        all_components: List[Set[str]] = []
                        visited_segments_for_cc: Set[str] = set()
                        # Query one component per segment that is not already covered
                        # by a previously found component.
                        for seg_id_start_cc in segment_sizes:
                            if seg_id_start_cc in visited_segments_for_cc:
                                continue
                            component = self._get_connected_component(seg_id_start_cc, cursor)
                            if component:
                                all_components.append(component)
                                visited_segments_for_cc.update(component)
                        if all_components:
                            num_connected_components = len(all_components)
                            component_lengths_bp = [
                                sum(segment_sizes.get(s, 0) for s in comp)
                                for comp in all_components
                            ]
                            largest_cc_size_bp = max(component_lengths_bp)
                            # Every component except the (first) largest one counts as
                            # "disconnected" from the main graph.
                            largest_comp_idx = component_lengths_bp.index(largest_cc_size_bp)
                            num_disconnected_components_count = num_connected_components - 1
                            total_length_disconnected_bp = sum(
                                comp_len
                                for i, comp_len in enumerate(component_lengths_bp)
                                if i != largest_comp_idx
                            )
                except sqlite3.Error as e:
                    self.logger.error(f"SQLite error during connectivity stats: {e}", exc_info=True)
                finally:
                    if conn_sqlite is not None:
                        conn_sqlite.close()
            else:
                self.logger.info(
                    "Link indexing not enabled or DB not available; "
                    "skipping database-dependent connectivity statistics."
                )
            current_progress.update(stats_task_id, advance=1)

            # Step 6: remaining counts taken directly from the parsed data.
            current_progress.update(stats_task_id, description="[cyan]Stats: Finalizing")
            dead_ends_count = sum(1 for degree_val in self.degrees.values() if degree_val == 1)
            num_isolated_segments_count = len(self.isolated_segments)
            current_progress.update(stats_task_id, advance=1)

            stats_results = {
                "Graph Overview": {
                    "GFA File Name": self.gfa_name,
                    "GFA Version": self.version,
                    "Total Segments (S lines)": self.segment_count,
                    "Total Links (L lines)": self.link_count,
                    "Total Walks (W lines)": self.walks_count,
                    "Unique Samples in Walks": len(self.dict_samples_chrom),
                },
                "Segment Statistics": {
                    "Total Segment Length (bp)": self.total_segment_length,
                    "Average Segment Length (bp)": avg_segment_length,
                    "Median Segment Length (bp)": median_node_length,
                    "Avg Length of Top 5% Longest Segments (bp)": avg_top5_percent_length,
                    "Median Length of Top 5% Longest Segments (bp)": median_top5_percent_length,
                    "Segment Length Distribution": segment_length_bin_counts,
                },
                "Link Statistics": {
                    "Max Segment Degree": max_degree,
                    "Average Segment Degree": avg_degree,
                    "Self-Links (S1 -> S1)": self.self_links_count,
                    "Inverted Links (S1+ -> S2-)": self.inverted_links_count,
                    "Both Negative Links (S1- -> S2-)": self.negative_links_count,
                },
                "Path (Walk) Statistics": {
                    "Total Length of All Paths (bp, input_genome_size)": self.input_genome_size,
                    "Graph Compression Ratio (input_genome_size / graph_size_bp)": compress_ratio,
                    "Max Segments in a Single Walk": self.max_walk_rank,
                    "Sum of First Segment Lengths in Walks": self.sum_rank0_length,
                },
                "Segment Sharing & Depth": {
                    "Avg Unique Samples per Segment (Similarity Mean)": sim_mean,
                    "Median Unique Samples per Segment (Similarity Median)": sim_median,
                    "StdDev Unique Samples per Segment (Similarity Std)": sim_std,
                    "Avg Occurrences per Segment (Depth Mean)": depth_mean,
                    "Median Occurrences per Segment (Depth Median)": depth_median,
                    "StdDev Occurrences per Segment (Depth Std)": depth_std,
                },
                "Graph Structure": {
                    "Graph Density": graph_density,
                    "Segments/Links Ratio": ne_ratio,
                    "Dead-End Segments (degree 1)": dead_ends_count,
                    "Isolated Segments (degree 0)": num_isolated_segments_count,
                    "Number of Connected Components (CCs)": num_connected_components,
                    "Largest CC Size (bp)": largest_cc_size_bp,
                    "Number of Disconnected CCs (excluding largest)": num_disconnected_components_count,
                    "Total Length of Disconnected CCs (bp)": total_length_disconnected_bp,
                },
            }
            current_progress.update(stats_task_id, description="[green]Stats: Formatting & Saving", advance=1)
        except Exception as e:
            self.logger.error(f"Error during statistics computation: {e}", exc_info=True)
            current_progress.update(stats_task_id, description=f"[red]Error computing stats: {e}")
            # stats_results is still the empty dict here: it is only assigned above
            # after every step has succeeded.

        # Flatten {category: {metric: value}} into (category, metric, value) rows.
        flat_stats_for_df = []
        for category, metrics_dict in stats_results.items():
            if isinstance(metrics_dict, dict):
                flat_stats_for_df.extend(
                    (category, metric_name, value) for metric_name, value in metrics_dict.items()
                )
            else:  # not produced by the structure above; kept for robustness
                flat_stats_for_df.append((category, "Value", metrics_dict))
        df_stats = DataFrame(flat_stats_for_df, columns=["Category", "Metric", "Value"])

        def format_value(x):
            """Render a single statistic as display text for the TSV output."""
            if isinstance(x, dict):  # binned counts, e.g. the length distribution
                return ", ".join(f"{k}: {v}" for k, v in x.items())
            if isinstance(x, numbers.Integral):  # int/bool and numpy integer scalars
                return f"{int(x):,}"  # thousands separator
            if isinstance(x, numbers.Real):  # float and numpy floating scalars
                x = float(x)
                # Scientific notation for very small or very large magnitudes.
                if 0 < abs(x) < 1e-3 or abs(x) > 1e6:
                    return f"{x:.2e}"
                return f"{x:.2f}"
            return str(x)

        df_stats["Value"] = df_stats["Value"].apply(format_value)
        if self.stats_file:
            try:
                df_stats.to_csv(self.stats_file, sep="\t", index=False)
                self.logger.info(f"Saved GFA statistics to '{self.stats_file.name}'")
            except IOError as e:
                self.logger.error(f"Failed to save statistics CSV '{self.stats_file}': {e}")
    return stats_results