# Standard library imports
import asyncio
import gzip
import logging
import sqlite3
import subprocess
import tempfile
import re
from collections import Counter, OrderedDict, defaultdict, namedtuple
from concurrent.futures import ThreadPoolExecutor, wait, as_completed
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Set, TextIO
import heapq
from statistics import mean, median
from itertools import chain
from copy import copy
# Third-party imports
import aiofiles
import aiosqlite
from numpy import histogram
import numpy as np
from pandas import set_option, DataFrame, merge
import uvloop
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from pybedtools import BedTool, cleanup
from pysam import AlignedSegment, AlignmentFile, index, qualitystring_to_array, view
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
TaskID,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
from rich import box
from natsort import natsorted
# Local application/library specific imports
from .logger_config import shared_console # Assuming shared_console is used for rich printing
from .useful_function import RC_TRANS, reverse_complement_string
# Both statements below are import-time side effects that change process-wide state.

# pandas: render floats with two decimal places in DataFrame/Series reprs.
set_option("display.precision", 2)

# asyncio: event loops created through the policy from here on are uvloop loops.
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
LinkInfo = namedtuple(
    "LinkInfo",
    (
        "seg_id_1",
        "orient_seg_1",
        "orient_key_seg_1",
        "seg_id_2",
        "orient_seg_2",
        "orient_key_seg_2",
    ),
)
# A bare string after the assignment is a discarded expression; attach the text to
# the class so that help(LinkInfo) and IDEs actually show it.
LinkInfo.__doc__ = """Information about a link between two segments.

Attributes
----------
seg_id_1 : str
    Identifier of the first segment.
orient_seg_1 : int
    Orientation of the first segment (+1 or -1).
orient_key_seg_1 : int
    Orientation key for the first segment (often the same as orient_seg_1).
seg_id_2 : str
    Identifier of the second segment.
orient_seg_2 : int
    Orientation of the second segment (+1 or -1).
orient_key_seg_2 : int
    Orientation key for the second segment (often the same as orient_seg_2).
"""
@dataclass
class SubGraph:
    """
    Extract, for one sample, the part of a pangenome graph covering a query region.

    The sample BED file (`bed_path`) is intersected with the region(s) of
    interest. The overlapping records become GFA walk (W) and link (L) lines,
    segment (S) lines are fetched from the BAM file, and FASTA records can
    finally be assembled from the walks.

    After the three coordinate columns, BED records are read as:
    column 4 = segment ID prefixed by its orientation (e.g. ``+12``),
    column 5 = chromosome fragment ID, column 6 = haplotype index.

    Attributes
    ----------
    bam_path : Path
        BAM file handed to `GratoolsBam` to build segment (S) lines.
    bed_path : Path
        Sample BED file (``<name>.bed`` or ``<name>.bed.gz``).
    logger : logging.Logger
        Logger for all messages. Defaults to the "GraTools" logger.
    sample_name : Optional[str]
        Set by `__post_init__` from the `bed_path` file name. Default None.
    sample_name_query : Optional[str]
        Name of the query sample. When it equals `sample_name`, the query-specific
        bookkeeping (first/last query segment, trimming offsets) applies. Default None.
    chromosome_query : Optional[str]
        Query chromosome. Default None.
    start_query : Optional[int]
        Query start. Default None.
    stop_query : Optional[int]
        Query stop. Default None.
    offset_first : int
        Bases trimmed from the start of the query walk. Recomputed per walk by
        `build_fasta` for the query sample. Default 0.
    offset_last : int
        Bases trimmed from the end of the query walk. Recomputed per walk by
        `build_fasta` for the query sample. Default 0.
    add_start_bases_first_segment : int
        Not read by any method of this class. Default 0.
    intersect_bed : Optional[BedTool]
        Not read by any method of this class. Default None.
    segment_id_set : Set[str]
        Segment IDs without orientation. Rebuilt by `build_walks`; read by
        `filter_bed_with_awk` and `build_segments`.
    segment_id_first_query : Optional[str]
        ID of the first query-sample segment seen by `build_walks`. Default None.
    segment_id_first_strand : Optional[str]
        Strand ('+' or '-') of that segment. Default None.
    segment_id_last_query : Optional[str]
        ID of the last query-sample segment seen by `build_walks`. Default None.
    segment_id_last_strand : Optional[str]
        Strand ('+' or '-') of that segment. Default None.
    segment_id_first : Optional[str]
        Copy of `segment_id_first_query`, set at the same time. Default None.
    segment_id_last : Optional[str]
        Copy of `segment_id_last_query`, set at the same time. Default None.
    works_path : Optional[Path]
        Directory for scratch files and merged-region outputs. Required by
        `filter_bed_with_awk` and therefore by `get_chr_pos`. Default None.
    merge : Optional[int]
        ``-d`` distance of the second ``bedtools merge`` in `get_chr_pos`. Default None.
    build_fasta_flag : Optional[bool]
        Not read by any method of this class. Default False.
    gfa_walk_list : List[str]
        Accumulated GFA W lines.
    gfa_link_list : List[str]
        Accumulated GFA L lines.
    gfa_segment_list : List[str]
        GFA S lines returned by `GratoolsBam.build_segments`.
    dict_segments_samples : defaultdict
        Second value returned by `GratoolsBam.build_segments`; presumably
        segment ID -> sample identifiers (TODO confirm).
    dict_segments_sequence : defaultdict
        Segment ID -> sequence, returned by `GratoolsBam.build_segments` and
        read by `build_fasta`.
    sequences_list : List[SeqRecord]
        FASTA records appended by `build_fasta`.
    progress_dict : Optional[Dict]
        Receives ``{"progress", "total", "description"}`` snapshots keyed by
        `task_id`. Default None (no progress reporting).
    task_id : Optional[TaskID]
        Key used in `progress_dict`. Default None.
    regions : Optional[List[Dict[str, Any]]]
        Regions as dicts with 'chromosome', 'start' and 'stop' keys. Default None.
    intersected_results_by_regions : Optional[BedTool]
        Output of `compute_intersection`: each line is a BED record followed by
        the region columns, the region index last. Default None.
    """

    bam_path: Path
    bed_path: Path
    logger: logging.Logger = field(
        default_factory=lambda: logging.getLogger("GraTools")
    )
    sample_name: Optional[str] = None
    sample_name_query: Optional[str] = None
    chromosome_query: Optional[str] = None
    start_query: Optional[int] = None
    stop_query: Optional[int] = None
    offset_first: int = 0
    offset_last: int = 0
    add_start_bases_first_segment: int = 0
    intersect_bed: Optional[BedTool] = None
    segment_id_set: Set[str] = field(default_factory=set, repr=False)
    segment_id_first_query: Optional[str] = None
    segment_id_first_strand: Optional[str] = None
    segment_id_last_query: Optional[str] = None
    segment_id_last_strand: Optional[str] = None
    segment_id_first: Optional[str] = None  # Mirror of segment_id_first_query
    segment_id_last: Optional[str] = None  # Mirror of segment_id_last_query
    works_path: Optional[Path] = None
    merge: Optional[int] = None
    build_fasta_flag: Optional[bool] = False
    gfa_walk_list: List[str] = field(default_factory=list, repr=False)
    gfa_link_list: List[str] = field(default_factory=list, repr=False)
    gfa_segment_list: List[str] = field(default_factory=list, repr=False)
    dict_segments_samples: defaultdict = field(
        default_factory=lambda: defaultdict(list), repr=False
    )
    dict_segments_sequence: defaultdict = field(
        default_factory=lambda: defaultdict(str), repr=False
    )
    sequences_list: List[SeqRecord] = field(default_factory=list, repr=False)
    progress_dict: Optional[Dict] = None
    task_id: Optional[TaskID] = None
    regions: Optional[List[Dict[str, Any]]] = None
    intersected_results_by_regions: Optional[BedTool] = None

    def __post_init__(self) -> None:
        """
        Derive `sample_name` from the BED file name.

        ``<name>.bed.gz`` and ``<name>.bed`` both yield ``<name>``. Any other
        file name is logged as an error and leaves `sample_name` unchanged.
        """
        # Accept plain strings too; the code below relies on the Path API.
        self.bed_path = Path(self.bed_path)
        file_name = self.bed_path.name
        # Test the real file-name ending (a substring test on the joined suffixes
        # also accepted names such as "x.bed.bak" without stripping anything).
        for bed_suffix in (".bed.gz", ".bed"):
            if file_name.endswith(bed_suffix):
                self.sample_name = file_name.removesuffix(bed_suffix)
                return
        self.logger.error(f"The file {self.bed_path} is not a recognized BED file type.")

    def _update_progress(self, progress: int, description: str, total: int = 7) -> None:
        """Publish a progress snapshot under `task_id`; no-op when tracking is disabled."""
        # `is None` rather than truthiness: an empty mapping is still a valid target,
        # and a None mapping must never be subscripted.
        if self.progress_dict is None or self.task_id is None:
            return
        self.progress_dict[self.task_id] = {
            "progress": progress,
            "total": total,
            "description": description,
        }

    def compute_intersection(self) -> None:
        """
        Intersect the sample BED file with every region in `self.regions`.

        If `self.regions` is empty, a single region built from `chromosome_query`,
        `start_query` and `stop_query` is used. All regions go into one BedTool
        whose 4th column holds the region index, so one ``intersect -wa -wb`` call
        serves every region. Each output line is the BED record followed by the
        region columns, the region index last.

        The result is stored in `self.intersected_results_by_regions`. It stays
        None when the regions are invalid or a pybedtools step fails; those
        failures are logged, not raised.
        """
        # Never let a result from an earlier call leak into build_walks().
        self.intersected_results_by_regions = None

        if not self.regions:  # Default region if none provided
            self.regions = [
                {
                    "chromosome": self.chromosome_query,
                    "start": self.start_query,
                    "stop": self.stop_query,
                }
            ]
        required_keys = ("chromosome", "start", "stop")
        regions_are_valid = isinstance(self.regions, list) and all(
            isinstance(region, dict) and all(key in region for key in required_keys)
            for region in self.regions
        )
        if not regions_are_valid:
            self.logger.error(
                "Invalid regions provided. Must be a list of dictionaries, "
                "each with 'chromosome', 'start', and 'stop' keys."
            )
            return

        try:
            all_bed = BedTool(self.bed_path.as_posix())
        except Exception as e:  # pybedtools can raise various exceptions
            self.logger.error(f"Failed to load BED file '{self.bed_path}': {e}")
            return

        regions_text = "\n".join(
            f"{region['chromosome']}\t{region['start']}\t{region['stop']}\t{idx}"
            for idx, region in enumerate(self.regions)
        )
        try:
            # Building the regions BedTool sits inside the try as well, so any
            # failure it raises is logged like the intersection itself.
            regions_bed = BedTool(regions_text, from_string=True)
            intersected = all_bed.intersect(regions_bed, wa=True, wb=True).saveas()
            self.intersected_results_by_regions = intersected
            if intersected.count() == 0:
                # An empty intersection is a legitimate outcome, not an error.
                self._update_progress(
                    7, f"[red]'{self.sample_name}': No intersection found for region, aborted."
                )
                self.logger.warning(
                    f"'{self.sample_name}': No intersection found for region {self.regions}"
                )
        except Exception as e:
            self.logger.error(f"Error during single intersection: {e}")

    def build_walks(self) -> None:
        """
        Build GFA walks (W lines) and links (L lines) from the intersected BED records.

        Records are grouped by (chromosome, haplotype index, region index) and then
        by chromosome fragment ID. Each group becomes one W line spanning the
        smallest start and largest stop of its records. Consecutive segments of a
        group yield L lines (see `_build_links`). `segment_id_set` is rebuilt from
        scratch. For the query sample, the first and last segments seen are
        recorded in `segment_id_first_*` / `segment_id_last_*`.
        """
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA walks and links")
        self.segment_id_set = set()  # Reset for current build
        self._update_progress(5, f"[cyan]'{self.sample_name}': Building walks and links")

        if self.intersected_results_by_regions is None:
            # compute_intersection() failed or was not run: there is nothing to iterate.
            self.logger.warning(
                f"'{self.sample_name}': No intersection results available, no walks built."
            )
            return

        is_query_sample = self.sample_name == self.sample_name_query
        # (chromosome, haplotype, region index) -> fragment ID -> walk data. Tuple keys
        # avoid re-splitting a "chr:hap:idx" string, which breaks on names containing ':'.
        walk_dict: Dict[Tuple[str, str, int], Dict[str, Dict[str, Any]]] = defaultdict(
            lambda: defaultdict(
                lambda: {"seg_list": [], "start": float("inf"), "stop": float("-inf")}
            )
        )
        for interval in self.intersected_results_by_regions:
            region_idx = int(interval[-1])  # last column = region index written by compute_intersection
            oriented_segment_id = interval[3]  # e.g. "+12" / "-7"
            strand, segment_id = oriented_segment_id[0], oriented_segment_id[1:]
            chromosome_fragment_id = interval[4]
            haplotype_index = interval[5]

            self.segment_id_set.add(str(segment_id))

            if is_query_sample:
                if not self.segment_id_first_query:  # First query segment ever seen
                    self.segment_id_first_query = segment_id
                    self.segment_id_first_strand = strand
                    self.segment_id_first = segment_id
                # Always overwrite so the last record seen wins.
                self.segment_id_last_query = segment_id
                self.segment_id_last_strand = strand
                self.segment_id_last = segment_id

            fragment = walk_dict[(interval.chrom, haplotype_index, region_idx)][chromosome_fragment_id]
            fragment["seg_list"].append(oriented_segment_id)
            # Per-fragment bounds (the defaults of +/-inf exist for exactly this):
            # the first record of a region may belong to another haplotype/fragment.
            fragment["start"] = min(fragment["start"], interval.start)
            fragment["stop"] = max(fragment["stop"], interval.stop)

        # GFA W paths use '>' for forward and '<' for reverse; only the leading
        # character is the orientation, the rest is the segment ID verbatim.
        gfa_arrow = {"+": ">", "-": "<"}
        for (chr_name, haplotype_idx, _region_idx), fragments in walk_dict.items():
            for data in fragments.values():
                ordered_segments = data["seg_list"]
                self._build_links(ordered_segments)
                gfa_path_string = "".join(
                    gfa_arrow.get(seg[:1], seg[:1]) + seg[1:] for seg in ordered_segments
                )
                # TODO: save to file for concat after on gratools
                self.gfa_walk_list.append(
                    "\t".join(
                        [
                            "W",
                            self.sample_name,
                            haplotype_idx,
                            chr_name,
                            str(data["start"]),
                            str(data["stop"]),
                            gfa_path_string,
                        ]
                    )
                )

    def _build_links(self, segments_order: List[str]) -> None:
        """
        Append one GFA L line per pair of consecutive segments.

        Parameters
        ----------
        segments_order : List[str]
            Ordered segment IDs, each prefixed with its orientation
            (e.g. ["+seg1", "-seg2", "+seg3"]). Fewer than two yields no link.
        """
        for source, target in zip(segments_order, segments_order[1:]):
            # L <from> <from_orient> <to> <to_orient> <overlap>; overlap is always 0M here.
            self.gfa_link_list.append(
                f"L\t{source[1:]}\t{source[0]}\t{target[1:]}\t{target[0]}\t0M"
            )

    def build_segments(self) -> None:
        """
        Build GFA segment (S) lines for the IDs in `self.segment_id_set`.

        `GratoolsBam.build_segments` (fed with `bam_path`) returns the S lines and
        the two segment dictionaries stored on this instance. With no segment IDs,
        those three attributes are reset to empty containers instead.
        """
        self._update_progress(6, f"[cyan]'{self.sample_name}': Building segments")
        self.logger.info(f"'{self.sample_name}': Building subgraph GFA segments")
        if not self.segment_id_set:
            self.logger.warning(f"'{self.sample_name}': No segment IDs for segment lines building. Skipping.")
            self.gfa_segment_list = []
            self.dict_segments_samples = defaultdict(list)
            self.dict_segments_sequence = defaultdict(str)
            return
        bam_handler = GratoolsBam(bam_path=self.bam_path, logger=self.logger)
        (
            self.gfa_segment_list,
            self.dict_segments_samples,
            self.dict_segments_sequence,
        ) = bam_handler.build_segments(list_segments=self.segment_id_set)

    def filter_bed_with_awk(self) -> Tuple[BedTool, Path, Path]:
        """
        Keep only the BED rows whose segment ID is in `self.segment_id_set`.

        The segment ID is column 4 without its leading orientation character. The
        IDs are written to ``<works_path>/<sample>_segment_ids.txt``, and awk copies
        matching rows to ``<works_path>/<sample>_filtered.bed``. A gzipped BED is
        decompressed on the fly with ``gzip -dc``, because awk reads plain text only.

        Returns
        -------
        Tuple[BedTool, Path, Path]
            The filtered BED, its path, and the path of the ID file. The caller
            is responsible for deleting both files.

        Raises
        ------
        ValueError
            If `works_path` is not set.
        subprocess.CalledProcessError
            If awk or gzip fails. The scratch files are removed first.
        """
        if self.works_path is None:
            raise ValueError(f"'{self.sample_name}': works_path must be set to filter the BED file.")
        self.logger.info(f"'{self.sample_name}': Starting awk filtering for {len(self.segment_id_set)} IDs.")
        segment_id_file = self.works_path / f"{self.sample_name}_segment_ids.txt"
        filtered_bed_path = self.works_path / f"{self.sample_name}_filtered.bed"

        with open(segment_id_file, "w") as id_handle:
            id_handle.writelines(f"{seg_id}\n" for seg_id in self.segment_id_set)

        # 1st file (IDs): remember each ID. 2nd input (BED): print rows whose
        # column 4, minus its orientation char, is a remembered ID.
        awk_program = "NR==FNR {ids[$1]; next} substr($4, 2) in ids"
        try:
            # Argument lists, no shell: paths are never interpreted by /bin/sh.
            with open(filtered_bed_path, "w") as out_f:
                if self.bed_path.name.endswith(".gz"):
                    with subprocess.Popen(
                        ["gzip", "-dc", str(self.bed_path)], stdout=subprocess.PIPE
                    ) as gunzip:
                        subprocess.run(
                            ["awk", awk_program, str(segment_id_file), "-"],  # "-" = stdin
                            stdin=gunzip.stdout,
                            stdout=out_f,
                            check=True,
                        )
                    if gunzip.returncode != 0:
                        raise subprocess.CalledProcessError(gunzip.returncode, gunzip.args)
                else:
                    subprocess.run(
                        ["awk", awk_program, str(segment_id_file), str(self.bed_path)],
                        stdout=out_f,
                        check=True,
                    )
        except Exception:
            segment_id_file.unlink(missing_ok=True)
            filtered_bed_path.unlink(missing_ok=True)
            raise

        self.logger.info(f"'{self.sample_name}': Filtered file with awk written to: {filtered_bed_path}")
        return BedTool(filtered_bed_path.as_posix()), filtered_bed_path, segment_id_file

    def get_chr_pos(self, progress_dict: Optional[Dict], task_id: Optional[TaskID]) -> None:
        """
        Locate this sample's copies of `self.segment_id_set`, then build its subgraph.

        The steps are:

        1. Filter the BED to the wanted segments (`filter_bed_with_awk`).
        2. Merge the filtered rows into regions: first without ``-d``, then with
           ``-d self.merge``. Both results are saved under `works_path`.
        3. Intersect the BED with those regions.
        4. Rebuild walks, links and segments.

        Scratch files are removed on every exit path.

        Parameters
        ----------
        progress_dict : Optional[Dict]
            Mapping that receives progress snapshots keyed by `task_id`, or None.
        task_id : Optional[TaskID]
            Key under which progress is published, or None.
        """
        self.progress_dict = progress_dict
        self.task_id = task_id
        self._update_progress(2, f"[cyan]'{self.sample_name}': Searching corresponding chr pos")

        filtered_bed_all, filtered_bed_path, segment_id_file = self.filter_bed_with_awk()
        try:
            # Explicit count(): do not rely on BedTool truthiness to detect emptiness.
            if filtered_bed_all.count() == 0:
                self._update_progress(7, f"[red]'{self.sample_name}': No segment in BED.")
                self.logger.warning(
                    f"'{self.sample_name}': No matching segment found in BED file '{self.bed_path}' "
                    f"for segment IDs: {list(self.segment_id_set)[:5]}..."  # Log a few example IDs
                )
                return
            self._update_progress(3, f"'{self.sample_name}': BED filtering successful, proceeding to merge..")

            # Outputs are BED-formatted despite the .csv extension (BedTool.saveas).
            file_output_all = self.works_path / f"{self.bam_path.stem}_{self.sample_name}_regions_all.csv"
            file_output_collapse = (
                self.works_path / f"{self.bam_path.stem}_{self.sample_name}_regions_collapsed_d{self.merge}.csv"
            )
            # Carried columns: 4 = oriented segment IDs (collapsed into a list),
            # 5 = chromosome fragment ID, 6 = haplotype index (distinct values).
            # First merge: no -d, so only overlapping/abutting rows are joined.
            merged_all_bed = filtered_bed_all.merge(c=[4, 5, 6], o=["collapse", "distinct", "distinct"])
            merged_all_bed.saveas(file_output_all)
            # Second merge: additionally join rows up to `self.merge` bp apart.
            merged_collapse_bed = merged_all_bed.merge(
                d=self.merge, c=[4, 5, 6], o=["collapse", "distinct", "distinct"]
            )
            merged_collapse_bed.saveas(file_output_collapse)

            self.logger.info(f"'{self.sample_name}': compute regions base on merge dataframe")
            self.regions = [
                {"chromosome": feature.chrom, "start": feature.start, "stop": feature.stop}
                for feature in merged_collapse_bed
            ]
        finally:
            segment_id_file.unlink(missing_ok=True)
            filtered_bed_path.unlink(missing_ok=True)

        if not self.regions:
            # Stop here: continuing would make compute_intersection() fall back to
            # the query coordinates, which belong to another sample's BED.
            self.logger.warning(f"'{self.sample_name}': No regions found for processing.")
            self._update_progress(7, f"[red]'{self.sample_name}':No regions found for processing")
            return
        self.logger.debug(
            f"'{self.sample_name}': Found {len(self.regions)} regions for processing. Example: {self.regions[0]}."
        )
        self._update_progress(4, f"[cyan]'{self.sample_name}': Found corresponding chr pos")

        self.compute_intersection()
        self.build_walks()  # rebuilds self.segment_id_set
        self.build_segments()  # consumes self.segment_id_set
        self.intersected_results_by_regions = None  # release the intersection BedTool
        self._update_progress(7, f"[green]'{self.sample_name}': End processing")

    @staticmethod
    def _parse_walk_path(path_spec: str) -> List[Tuple[str, str]]:
        """
        Split a walk path into ``(orientation, segment_id)`` pairs; orientation is '+' or '-'.

        Both the GFA W-line form ``>s1<s2`` and the ``s1+s2-`` form are accepted.
        Returns an empty list when nothing can be parsed, including an empty path.
        """
        # Slicing (not indexing) so an empty path does not raise IndexError.
        if path_spec[:1] in (">", "<"):
            return [
                ("+" if arrow == ">" else "-", segment_id)
                for arrow, segment_id in re.findall(r"([<>])([^<>]+)", path_spec)
            ]
        return [
            (orient, segment_id)
            for segment_id, orient in re.findall(r"(.+?)([+-])", path_spec)
        ]

    def _trim_query_segment(
        self,
        segment_seq: str,
        segment_id: str,
        orient: str,
        segment_idx: int,
        last_idx: int,
        walk_start: int,
        walk_stop: int,
        effective_start: int,
        effective_stop: int,
    ) -> Tuple[str, int, int]:
        """
        Trim one segment of a query-sample walk so the walk matches the query bounds.

        Only the first/last query segments recorded by `build_walks` are touched,
        and only when they sit at an end of the walk:

        * first query segment at the walk start, same strand: drop `offset_first`
          bases from its start;
        * first query segment at the walk end, opposite strand: drop `offset_first`
          bases from its end;
        * last query segment at the walk end, same strand: drop `offset_last` bases
          from its end;
        * last query segment at the walk start, opposite strand: drop `offset_last`
          bases from its start.

        A single-segment walk is both start and end, so a query lying inside one
        segment is trimmed on both sides. An offset that is not smaller than the
        remaining sequence empties the segment.

        Returns
        -------
        Tuple[str, int, int]
            The possibly trimmed sequence, and the updated effective walk start
            and stop coordinates.
        """
        at_walk_start = segment_idx == 0
        at_walk_end = segment_idx == last_idx

        if segment_id == self.segment_id_first_query and self.offset_first > 0:
            same_strand = orient == self.segment_id_first_strand
            if at_walk_start and same_strand:
                if self.offset_first < len(segment_seq):
                    segment_seq = segment_seq[self.offset_first:]
                    effective_start = walk_start + self.offset_first
                else:
                    self.logger.debug(
                        f"Offset_first ({self.offset_first}bp) consumes entire first segment '{segment_id}' ({len(segment_seq)}bp).")
                    segment_seq = ""
            elif at_walk_end and not same_strand:
                if self.offset_first < len(segment_seq):
                    segment_seq = segment_seq[:-self.offset_first]
                    effective_stop = walk_stop - self.offset_first
                else:
                    segment_seq = ""

        # Checked independently of the block above: the first and last query
        # segments may be the same segment.
        if segment_id == self.segment_id_last_query and self.offset_last > 0:
            same_strand = orient == self.segment_id_last_strand
            if at_walk_end and same_strand:
                if self.offset_last < len(segment_seq):
                    segment_seq = segment_seq[:-self.offset_last]
                    effective_stop = walk_stop - self.offset_last
                else:
                    self.logger.error(
                        f"Offset_last ({self.offset_last}bp) consumes entire last segment '{segment_id}' ({len(segment_seq)}bp).")
                    segment_seq = ""
            elif at_walk_start and not same_strand:
                if self.offset_last < len(segment_seq):
                    segment_seq = segment_seq[self.offset_last:]
                    effective_start = walk_start + self.offset_last
                else:
                    segment_seq = ""

        return segment_seq, effective_start, effective_stop

    def build_fasta(self) -> None:
        """
        Build FASTA records from the GFA walks in `self.gfa_walk_list`.

        Each walk's segment sequences are concatenated into one `SeqRecord`
        appended to `self.sequences_list`; segments on the '-' strand are
        reverse-complemented first.

        For the query sample (`sample_name == sample_name_query`), `offset_first`
        and `offset_last` are recomputed per walk from `start_query`/`stop_query`.
        Boundary segments are trimmed with them (see `_trim_query_segment`), and
        the record ID uses the trimmed coordinates.

        Segments without a known sequence are skipped with a warning. Malformed
        or unparsable walk lines are skipped with a log message.
        """
        num_walks = len(self.gfa_walk_list)
        self.logger.info(f"'{self.sample_name}': Building FASTA for {num_walks} walks.")
        if not self.dict_segments_sequence:
            self.logger.warning(f"'{self.sample_name}': Segment sequences dictionary is empty. Cannot build FASTA.")
            return
        if not self.gfa_walk_list:
            self.logger.info(f"'{self.sample_name}': No GFA walks to process for FASTA generation.")
            return

        is_query_sample = self.sample_name == self.sample_name_query
        newly_created_records: List[SeqRecord] = []
        for walk_line in self.gfa_walk_list:
            try:
                (
                    _w_char,
                    walk_sample_name,
                    haplotype_index,
                    chromosome_name,
                    walk_start_str,
                    walk_stop_str,
                    gfa_path_specification,
                ) = walk_line.split("\t")
                # Parsed inside the try so a non-numeric coordinate is reported
                # as a malformed line instead of aborting the whole build.
                walk_genomic_start = int(walk_start_str)
                walk_genomic_stop = int(walk_stop_str)
            except ValueError as e:
                self.logger.error(
                    f"'{self.sample_name}': Malformed walk line during FASTA build: '{walk_line}'. Error: {e}",
                    exc_info=True)
                continue

            effective_start = walk_genomic_start
            effective_stop = walk_genomic_stop
            if is_query_sample:
                if self.start_query is not None and self.stop_query is not None:
                    # Positive values = how far the walk extends beyond the query on each side.
                    self.offset_first = self.start_query - walk_genomic_start
                    self.offset_last = walk_genomic_stop - self.stop_query
                else:
                    self.logger.warning(
                        f"'{self.sample_name}': Query sample set, but start/stop_query not fully set. Offsets may be incorrect.")

            oriented_segments = self._parse_walk_path(gfa_path_specification)
            if not oriented_segments:
                self.logger.warning(
                    f"'{self.sample_name}': Could not parse path specification '{gfa_path_specification}' in walk line: '{walk_line}'. Skipping this walk for FASTA building.")
                continue

            last_idx = len(oriented_segments) - 1
            self.logger.debug(
                f"'{self.sample_name}': Processing walk on {chromosome_name} ({walk_genomic_start}-{walk_genomic_stop}) "
                f"with {len(oriented_segments)} segments for FASTA."
            )

            seq_fragments: List[str] = []
            for segment_idx, (orient, segment_id) in enumerate(oriented_segments):
                if segment_id not in self.dict_segments_sequence:
                    self.logger.warning(
                        f"'{self.sample_name}': Sequence for segment '{segment_id}' in walk not found. Skipping this segment.")
                    continue
                segment_seq = self.dict_segments_sequence[segment_id]
                if orient == "-":
                    segment_seq = reverse_complement_string(segment_seq)
                if is_query_sample:
                    segment_seq, effective_start, effective_stop = self._trim_query_segment(
                        segment_seq,
                        segment_id,
                        orient,
                        segment_idx,
                        last_idx,
                        walk_genomic_start,
                        walk_genomic_stop,
                        effective_start,
                        effective_stop,
                    )
                if segment_seq:
                    seq_fragments.append(segment_seq)

            if not seq_fragments:
                self.logger.debug(
                    f"'{self.sample_name}': No sequence fragment generated for walk on {chromosome_name} "
                    f"(ID: {walk_sample_name}-{haplotype_index}). Skipping FASTA record for this walk.")
                continue

            if is_query_sample:
                # Query records are named after the query region, with trimmed coordinates.
                q_sample = self.sample_name_query or "query_sample"
                q_chrom = self.chromosome_query or chromosome_name
                fasta_seq_id = f"{q_sample}-{q_chrom}:{effective_start}-{effective_stop}"
                fasta_description = ""
            else:
                fasta_seq_id = f"{walk_sample_name}-{chromosome_name}:{effective_start}-{effective_stop}"
                # SEQ_LEN is the genomic span of the walk, not len() of the sequence.
                seq_actual_len = effective_stop - effective_start
                query_region_str = "N/A"
                if (
                    self.sample_name_query
                    and self.chromosome_query
                    and self.start_query is not None
                    and self.stop_query is not None
                ):
                    query_region_str = f"{self.sample_name_query}-{self.chromosome_query}:{self.start_query}-{self.stop_query}"
                fasta_description = f"SEQ_LEN={seq_actual_len}|QUERY={query_region_str}"

            fasta_seq_id = fasta_seq_id.strip()
            newly_created_records.append(
                SeqRecord(
                    Seq("".join(seq_fragments)),
                    id=fasta_seq_id,
                    name=fasta_seq_id,
                    description=fasta_description,
                )
            )

        self.sequences_list.extend(newly_created_records)
        self.logger.info(
            f"'{self.sample_name}': Finished FASTA building. Added {len(newly_created_records)} new records. "
            f"Total records in sequences_list: {len(self.sequences_list)}.")
class AsyncGfaDatabase:
    """
    Manage an asynchronous SQLite database for storing and querying GFA link data.

    It uses `aiosqlite` for non-blocking operations within an asyncio event loop.
    All inserts go through an internal FIFO queue consumed by a single background
    writer task, so only one coroutine ever writes to SQLite.

    Attributes
    ----------
    db_file : Path
        Path to the SQLite database file.
    timeout : float
        Maximum timeout (in seconds) for SQLite lock acquisition.
    logger : logging.Logger
        Logger instance.
    _conn : Optional[aiosqlite.Connection]
        Shared SQLite connection (or None if not connected).
    _write_queue : asyncio.Queue
        Asynchronous queue of link batches to insert (max size 1000). A ``None``
        item is the sentinel that tells the writer task to stop.
    _sql_task : Optional[asyncio.Task]
        Background task consuming the queue and writing to the database.
    _shutdown : bool
        Set to True by `close()`; reset to False when `connect()` starts a new
        writer task.
    """

    # Column order shared by the table schema and every batch tuple.
    _INSERT_LINK_SQL = (
        "INSERT INTO links ("
        " seg_id_1, orient_seg_1, orient_key_seg_1,"
        " seg_id_2, orient_seg_2, orient_key_seg_2"
        ") VALUES (?, ?, ?, ?, ?, ?);"
    )

    def __init__(self, db_file: Path, timeout: float = 30.0):
        """
        Initialize the instance without opening the connection.

        The schema is created on the first call to `connect()`.

        Args:
            db_file (Path): Path to the SQLite file to be used as backend.
            timeout (float, optional): Maximum wait time for SQLite locks, in
                seconds (default 30.0s).
        """
        self.db_file: Path = db_file
        self.timeout: float = timeout  # In seconds
        self._shutdown: bool = False
        self._conn: Optional[aiosqlite.Connection] = None
        # Bounded: producers wait in `batch_insert_links` once 1000 batches are pending.
        self._write_queue: asyncio.Queue[Optional[List[Tuple[str, int, int, str, int, int]]]] = asyncio.Queue(
            maxsize=1000)
        self._sql_task: Optional[asyncio.Task] = None
        self.logger = logging.getLogger("GraTools")

    async def connect(self) -> None:
        """
        Ensure the connection is open and the background writer task is running.

        On the first call this opens the SQLite connection, configures PRAGMA
        settings and creates the `links` table if it does not exist. On every
        call it (re)starts the writer task when none is running. A writer that
        has stopped (after `close()` or a cancellation) is therefore replaced,
        rather than leaving queued batches unconsumed. Calling this method
        repeatedly is safe.
        """
        if self._conn is None:
            self.logger.debug(f"Connecting to the SQLite database: {self.db_file.name}")
            self._conn = await aiosqlite.connect(
                self.db_file.as_posix(), timeout=self.timeout  # seconds to wait on a locked database
            )
            # synchronous=OFF is faster but risks corruption on a crash;
            # WAL lets readers proceed while the writer inserts.
            await self._conn.execute("PRAGMA synchronous = OFF;")
            await self._conn.execute("PRAGMA journal_mode = WAL;")
            # busy_timeout is expressed in milliseconds.
            await self._conn.execute(f"PRAGMA busy_timeout = {int(self.timeout * 1000)};")
            await self._conn.execute("""
                CREATE TABLE IF NOT EXISTS links (
                    seg_id_1 TEXT,
                    orient_seg_1 INTEGER, -- +1 for '+', -1 for '-'
                    orient_key_seg_1 INTEGER, -- Often same as orient_seg_1, purpose might be specific
                    seg_id_2 TEXT,
                    orient_seg_2 INTEGER,
                    orient_key_seg_2 INTEGER
                );
            """)
            await self._conn.commit()
        if self._sql_task is None or self._sql_task.done():
            # A fresh writer starts a fresh session, even after a previous close().
            self._shutdown = False
            self._sql_task = asyncio.create_task(self._sql_writer())
            self.logger.debug("AsyncGfaDatabase connected and writer task started/confirmed.")

    async def _rollback_quietly(self, reason: str) -> None:
        """Roll back an open transaction; rollback failures are logged, never raised."""
        if self._conn is None or not self._conn.in_transaction:
            return
        try:
            await self._conn.rollback()
            self.logger.warning(f"SQL writer rolled back transaction due to {reason}.")
        except Exception as rb_err:  # keep the writer alive whatever the failure
            self.logger.error(f"Error during SQLite rollback: {rb_err}")

    async def _sql_writer(self) -> None:
        """
        Consume link batches from the queue and insert them, one commit per batch.

        The loop blocks on the queue and stops only when it dequeues the ``None``
        sentinel (sent by `close()`) or when the task is cancelled. The queue is
        FIFO, so every batch enqueued before the sentinel is written first.
        `task_done()` is called exactly once per dequeued item, which keeps
        `self._write_queue.join()` reliable. A batch that fails to insert is
        logged and dropped after its transaction is rolled back; later batches
        are still processed.
        """
        if self._conn is None:
            raise RuntimeError("Connection must be initialized before writer starts")
        self.logger.debug("SQL writer task started.")
        processed_data_batches = 0
        while True:
            # Block without a timeout. The sentinel ends the loop, and polling with
            # wait_for() could lose an item dequeued just as the timeout fires
            # (Python < 3.12). That item would never be marked done.
            batch = await self._write_queue.get()
            try:
                if batch is None:  # Sentinel to stop the writer
                    self.logger.debug("SQL writer received end. Preparing to stop.")
                    break
                if not batch:  # Empty batch: nothing to insert
                    continue
                processed_data_batches += 1
                self.logger.debug(
                    f"SQL writer processing data batch {processed_data_batches} of size {len(batch)}. "
                    f"Queue size: {self._write_queue.qsize()}"
                )
                await self._conn.executemany(self._INSERT_LINK_SQL, batch)
                await self._conn.commit()
            except sqlite3.Error as e:
                self.logger.error(
                    f"SQLite error during processing of batch #{processed_data_batches}: {e}",
                    exc_info=True)
                await self._rollback_quietly("SQLite error")
            except Exception as e:
                self.logger.error(
                    f"Unexpected error in SQL writer (batch #{processed_data_batches}): {e}",
                    exc_info=True)
                await self._rollback_quietly("unexpected error")
            finally:
                # Runs for data batches, empty batches and the sentinel alike.
                self._write_queue.task_done()
        self.logger.debug("SQL writer task has stopped.")

    async def batch_insert_links(
        self, links: List[Tuple[str, int, int, str, int, int]]
    ) -> None:
        """
        Enqueue a batch of links for non-blocking insertion.

        Connects (and restarts a stopped writer) if needed. The call waits only
        when the queue is full.

        Args:
            links : List of tuples, each representing a link:
                (seg_id_1, orient_seg_1, orient_key_seg_1,
                 seg_id_2, orient_seg_2, orient_key_seg_2).
        """
        await self.connect()  # idempotent; also restarts a stopped writer task
        if not links:  # Never enqueue an empty batch
            return
        self.logger.debug(
            f"SQL writer adding batch of {len(links)} links to write queue. Current queue size: {self._write_queue.qsize()}"
        )
        # Snapshot the list so the caller may reuse or clear it once this returns.
        await self._write_queue.put(list(links))

    async def create_indexes(self) -> None:
        """
        Create indexes on seg_id_1 and seg_id_2 to accelerate queries.

        Intended to run after all data insertions are complete. SQLite errors
        are logged, not raised.
        """
        if not self._conn:
            await self.connect()
        self.logger.info("Creating indexes on links table.")
        try:
            await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_seg_id_1 ON links(seg_id_1);")
            await self._conn.execute("CREATE INDEX IF NOT EXISTS idx_seg_id_2 ON links(seg_id_2);")
            await self._conn.commit()
        except sqlite3.Error as e:
            self.logger.error(f"Failed to create indexes: {e}", exc_info=True)

    async def query_links_by_segment(self, segment_id: str) -> List[Tuple[Any, ...]]:
        """
        Retrieve all links where `segment_id` appears as seg_id_1 or seg_id_2.

        Args:
            segment_id : The ID of the target segment.

        Returns:
            List of full link rows from the database; empty on SQLite error
            (the error is logged).
        """
        if not self._conn:
            await self.connect()
        self.logger.debug(f"Querying links for segment: {segment_id}")
        try:
            async with self._conn.execute(
                "SELECT * FROM links WHERE seg_id_1 = ? OR seg_id_2 = ?;",
                (segment_id, segment_id),
            ) as cursor:
                results = await cursor.fetchall()
            self.logger.debug(f"Found {len(results)} links for segment {segment_id}.")
            return results
        except sqlite3.Error as e:
            self.logger.error(f"Error querying links for segment {segment_id}: {e}", exc_info=True)
            return []

    async def test_query_links(
        self, segment_id: str
    ) -> List[Tuple[str, str, int, int]]:
        """
        Retrieve and categorize links related to a given segment.

        - "before": links where `segment_id` is seg_id_2.
        - "after": links where `segment_id` is seg_id_1.

        Args:
            segment_id : The segment to analyze.

        Returns:
            Naturally sorted list of tuples
            (connected_segment_id, "before"/"after", orient_seg_1, orient_seg_2).
            On SQLite error, only the rows fetched before the error are used.
        """
        if not self._conn:
            await self.connect()
        self.logger.debug(f"Testing query links for segment: {segment_id}")
        before_rows = []
        after_rows = []
        try:
            # "before" links: segment_id is the destination.
            async with self._conn.execute(
                "SELECT seg_id_1, orient_seg_1, orient_seg_2 FROM links WHERE seg_id_2 = ?;",
                (segment_id,),
            ) as cur_before:
                before_rows = await cur_before.fetchall()
            self.logger.debug(f"Found {len(before_rows)} 'before' links for {segment_id}.")
            # "after" links: segment_id is the source.
            async with self._conn.execute(
                "SELECT seg_id_2, orient_seg_1, orient_seg_2 FROM links WHERE seg_id_1 = ?;",
                (segment_id,),
            ) as cur_after:
                after_rows = await cur_after.fetchall()
            self.logger.debug(f"Found {len(after_rows)} 'after' links for {segment_id}.")
        except sqlite3.Error as e:
            self.logger.error(f"Error during test_query_links for {segment_id}: {e}", exc_info=True)
        merged = natsorted(
            [(seg_id, "before", o1, o2) for seg_id, o1, o2 in before_rows]
            + [(seg_id, "after", o1, o2) for seg_id, o1, o2 in after_rows]
        )
        self.logger.debug(f"Total merged links for {segment_id}: {len(merged)}")
        return merged

    async def find_children_and_grandchildren(
        self, node_id: str
    ) -> Dict[str, List[str]]:
        """
        Find direct successors (children) and second-degree successors (grandchildren) of a segment.

        A child is a `seg_id_2` whose link has `node_id` as `seg_id_1`; a
        grandchild is a child of a child.

        Args:
            node_id : The starting segment ID.

        Returns:
            ``{"children": [...], "grandchildren": [...]}``. Children keep
            database order (duplicates possible). Grandchildren are deduplicated
            and sorted. On SQLite error, whatever was fetched so far is returned.
        """
        if not self._conn:
            await self.connect()
        self.logger.debug(f"Finding children and grandchildren for node: {node_id}")
        children: List[str] = []
        grandchildren: List[str] = []
        try:
            async with self._conn.execute(
                "SELECT seg_id_2 FROM links WHERE seg_id_1 = ?;", (node_id,)
            ) as cur_children:
                children = [row[0] for row in await cur_children.fetchall()]
            self.logger.debug(f"Found {len(children)} children for {node_id}.")
            if children:
                # A subquery selects the same rows as "IN (<each child>)", but it
                # cannot exceed SQLite's limit on bound parameters for
                # highly connected nodes.
                async with self._conn.execute(
                    "SELECT seg_id_2 FROM links WHERE seg_id_1 IN "
                    "(SELECT seg_id_2 FROM links WHERE seg_id_1 = ?);",
                    (node_id,),
                ) as cur_gc:
                    grandchildren_tuples = await cur_gc.fetchall()
                # Several children may lead to the same grandchild.
                grandchildren = sorted({row[0] for row in grandchildren_tuples})
                self.logger.debug(
                    f"Found {len(grandchildren)} unique grandchildren for {node_id} (from {len(grandchildren_tuples)} raw)."
                )
        except sqlite3.Error as e:
            self.logger.error(f"Error finding children/grandchildren for {node_id}: {e}", exc_info=True)
        return {"children": children, "grandchildren": grandchildren}

    async def close(self) -> None:
        """
        Shut the database down cleanly.

        1. Mark the instance as shutting down.
        2. If the writer task is running, enqueue the ``None`` sentinel behind
           all pending batches. Then wait, without a timeout, for the task to
           finish, so every batch queued before this call is written. If no
           writer is running, nothing is enqueued: there is no consumer, and
           waiting on the queue would block forever.
        3. Create the indexes once all writes are done.
        4. Close the SQLite connection and reset `_conn` to None.

        Failures of the writer task are logged, not raised. If this coroutine
        is cancelled while waiting, the cancellation propagates and the writer
        keeps running.
        """
        self.logger.debug("Initiating AsyncGfaDatabase shutdown...")
        self._shutdown = True
        writer = self._sql_task
        if writer is not None and not writer.done():
            self.logger.debug("Putting sentinel on SQL write queue.")
            await self._write_queue.put(None)
            self.logger.debug("Waiting for SQL queue to drain...")
            # asyncio.wait() neither re-raises the writer's exception nor
            # cancels the writer when close() itself is cancelled.
            await asyncio.wait({writer})
            if writer.cancelled():
                self.logger.info("SQL writer task was cancelled before draining the queue.")
            elif writer.exception() is not None:
                err = writer.exception()
                self.logger.error(f"SQL writer task failed during close: {err}", exc_info=err)
            else:
                self.logger.debug("SQL queue drained successfully (all items processed).")
        elif writer is not None:
            self.logger.debug("SQL writer task was already done.")
        else:
            self.logger.debug("No SQL writer task to shut down or task was never started.")
        if self._conn is not None:  # Only if a connection was ever established
            await self.create_indexes()
            try:
                await self._conn.close()
            except Exception as e:
                self.logger.error(f"Error closing SQLite connection: {e}", exc_info=True)
            finally:
                self._conn = None
[docs]
class AsyncBedWriter:
    """
    Asynchronous BED file writer using `aiofiles`.

    Lines are buffered per sample and appended in batches to one
    ``<sample>.bed`` file per sample inside `bed_dir`. Producers call
    `enqueue` / `enqueue_single_line`. A single background task, launched by
    `start()`, consumes the queue. `shutdown()` drains the queue and flushes
    every buffer.

    Attributes
    ----------
    bed_dir : Path
        Output directory for .bed files.
    batch_size : int
        Number of buffered lines per sample that triggers an automatic flush.
    progress : Optional[Progress]
        Rich Progress instance, if provided (stored; not used by this class' methods).
    logger : logging.Logger
        Logger instance.
    _queue : asyncio.Queue
        Queue of ``(sample_name, lines_or_line)`` items (max size 1,000,000).
        ``None`` is the sentinel that stops the writer loop.
    _shutdown : bool
        Set by `shutdown()`; reset by `start()`.
    _task : Optional[asyncio.Task]
        The background asyncio task running `_writer_loop`.
    """

    def __init__(self, bed_dir: Path, batch_size: int = 1,
                 progress: Optional[Progress] = None):
        """
        Args:
            bed_dir: Directory receiving one ``<sample>.bed`` file per sample.
            batch_size: Buffered lines per sample that trigger an automatic flush.
            progress: Optional Rich progress instance, stored as `self.progress`.
        """
        self.bed_dir: Path = bed_dir
        self.batch_size: int = batch_size
        # Items: (sample, list_of_lines) from `enqueue`, (sample, line) from
        # `enqueue_single_line`, or the None sentinel pushed by `shutdown`.
        self._queue: asyncio.Queue[Optional[Tuple[str, Any]]] = asyncio.Queue(maxsize=1000000)
        self._shutdown: bool = False
        self._task: Optional[asyncio.Task] = None
        self.progress: Optional[Progress] = progress  # Can be None
        self.logger = logging.getLogger("GraTools")

    def start(self) -> None:
        """Start the writer loop as a background task if it is not already running."""
        if self._task is not None and not self._task.done():
            self.logger.debug("AsyncBedWriter task is already running.")
            return
        self._shutdown = False  # Reset shutdown flag if restarting
        self.logger.debug("Starting AsyncBedWriter loop task.")
        self._task = asyncio.create_task(self._writer_loop())
        self.logger.debug("AsyncBedWriter task launched.")

    async def enqueue(self, sample: str, lines: List[str]) -> None:
        """
        Queue a list of lines for `sample`, waiting while the queue is full.

        Waiting applies backpressure to the producer instead of dropping the
        lines when the queue is full.

        Args:
            sample (str): The sample name (used for filename).
            lines (List[str]): Lines to write to the BED file. Each string must
                carry its own trailing newline; lines are written verbatim.
        """
        if not lines:
            self.logger.debug(f"AsyncBedWriter: enqueue called with empty lines for sample '{sample}'. Skipping.")
            return
        self.logger.debug(
            f"BED writer enqueuing {len(lines)} lines for sample '{sample}'. Current queue size: {self._queue.qsize()}"
        )
        # Snapshot the list so the caller may reuse or clear it once this returns.
        await self._queue.put((sample, list(lines)))

    async def enqueue_single_line(self, sample: str, line: str):
        """
        Asynchronously add a single line to the write queue.

        This waits if the queue is full, providing backpressure.
        """
        if not line:
            return
        try:
            # `await put` waits while the queue is full: this is the backpressure.
            await self._queue.put((sample, line))
        except asyncio.CancelledError:
            # The calling task was cancelled: log and propagate.
            self.logger.warning(f"Enqueue operation for sample '{sample}' was cancelled.")
            raise

    async def _writer_loop(self) -> None:
        """
        Consume the queue and append lines to per-sample BED files in batches.

        Lines are buffered per sample. A buffer is flushed once it holds at
        least `batch_size` lines. The loop ends when it dequeues the ``None``
        sentinel pushed by `shutdown()`, and every remaining buffer is then
        flushed. On cancellation, buffers are flushed before the cancellation
        is re-raised. `task_done()` is called exactly once per dequeued item.
        """
        self.logger.debug("BED writer loop started.")
        buffers: Dict[str, List[str]] = defaultdict(list)
        try:
            self.bed_dir.mkdir(parents=True, exist_ok=True)
        except OSError as e:
            self.logger.error(f"Failed to create BED output directory {self.bed_dir}: {e}. Writer loop cannot proceed.")
            return
        try:
            while True:
                # Block without a timeout. The sentinel ends the loop, and polling
                # with wait_for() could lose an item dequeued as the timeout fires
                # (Python < 3.12).
                item = await self._queue.get()
                sample = ""
                try:
                    if item is None:
                        self.logger.debug("BED writer received shutdown sentinel.")
                        break
                    sample, payload = item
                    buffer = buffers[sample]
                    # enqueue_single_line() sends one str, and list.extend() would
                    # split it into characters, so append it as one line.
                    if isinstance(payload, str):
                        buffer.append(payload)
                    else:
                        buffer.extend(payload)
                    if len(buffer) >= self.batch_size:
                        self.logger.debug(
                            f"BED writer flushing BED buffer for sample '{sample}' (size: {len(buffer)})."
                        )
                        await self._flush_sample(sample, buffer)
                        # Start afresh; lines of a failed flush are dropped (already logged).
                        buffers[sample] = []
                except Exception as e:
                    self.logger.error(f"Error in BED writer loop processing item for '{sample}': {e}",
                                      exc_info=True)
                finally:
                    self._queue.task_done()
        except asyncio.CancelledError:
            self.logger.info("BED writer loop was cancelled.")
            await self._flush_all(buffers)
            raise
        self.logger.debug("BED writer loop finished. Performing final flush of all buffers.")
        await self._flush_all(buffers)
        self.logger.debug("BED writer loop has stopped.")

    async def _flush_sample(self, sample: str, buffer_lines: List[str]) -> None:
        """
        Append `buffer_lines` to ``<bed_dir>/<sample>.bed`` and clear the list.

        The list is cleared only after a successful write. I/O errors are
        logged, not raised.
        """
        if not buffer_lines:
            self.logger.debug(f"No lines to flush for sample '{sample}'.")
            return
        path = self.bed_dir / f"{sample}.bed"
        try:
            async with aiofiles.open(path, mode="a", encoding="utf-8") as f:
                # Lines already end with their own newline characters.
                await f.writelines(buffer_lines)
            buffer_lines.clear()
        except IOError as e:
            self.logger.error(
                f"IOError flushing sample '{sample}' to '{path}': {e}", exc_info=True
            )
        except Exception as e:
            self.logger.error(
                f"Unexpected error flushing sample '{sample}' to '{path}': {e}", exc_info=True
            )

    async def _flush_all(self, buffers: Dict[str, List[str]]) -> None:
        """Write every non-empty buffer and drop entries whose buffer ended up empty."""
        self.logger.debug(f"Performing flush of all remaining buffers ({len(buffers)} total).")
        # Iterate over a snapshot because entries are deleted from `buffers`.
        for sample_name, buf_list in list(buffers.items()):
            if buf_list:
                await self._flush_sample(sample_name, buf_list)
            if not buf_list and sample_name in buffers:
                del buffers[sample_name]
        self.logger.debug("Flush of all buffers complete.")

    async def shutdown(self) -> None:
        """
        Stop the writer loop once it has drained the queue, and wait for it.

        A sentinel is enqueued behind every pending item, so all data queued
        before this call is written and flushed before the loop exits. Writer
        failures are logged, not raised. If this coroutine is cancelled while
        waiting, the cancellation propagates and the writer keeps draining.
        """
        self._shutdown = True
        task = self._task
        if task is None:
            self.logger.debug(
                "No BED writer task to shut down (was never started or already None)."
            )
            return
        if task.done():
            self.logger.debug("BED writer task was already done prior to shutdown call.")
            return
        self.logger.debug("Waiting for BED writer task to complete...")
        await self._queue.put(None)
        # asyncio.wait() does not re-raise the writer's exception and does not
        # cancel the writer if this coroutine is cancelled.
        await asyncio.wait({task})
        if task.cancelled():
            self.logger.info("BED writer task was externally cancelled during its shutdown process.")
        elif task.exception() is not None:
            err = task.exception()
            self.logger.error(
                f"Exception while waiting for BED writer task during shutdown: {err}", exc_info=err
            )
        else:
            self.logger.debug("BED writer task completed successfully after shutdown signal.")
[docs]
@dataclass
class GFA:
"""
Manage parsing of a GFA (Graphical Fragment Assembly) file, compute statistics,
and generate related files such as BAM-containing segments files and BED files for path per sample.
Fill an asynchronous database for links and asynchronous BED writer.
Attributes
----------
gfa_path : Path
Path to the input GFA file (can be .gfa or .gfa.gz).
threads : int, optional
Number of threads for operations like BAM file processing. Default is 1.
logger : logging.Logger
Logger object. Default is a logger named "GraTools".
gfa_name : Optional[str]
Name of the GFA file derived from `gfa_path` (without extensions). Auto-initialized.
version : Optional[str]
GFA version extracted from the header (e.g., "1.0"). Auto-initialized.
header_gfa : List[str]
List of header lines (H lines) from the GFA file. Auto-initialized.
sample_reference : Optional[str]
Reference sample name, potentially from GFA header (RS tag). Auto-initialized.
bam_segments_file : Optional[Path]
Path to the BAM file where segments (S lines) will be written. Auto-initialized.
dict_samples_chrom : defaultdict[str, OrderedDict[str, List[str]]]
Map sample names to an OrderedDict of chromosome names,
which in turn map to a list of "start\\tstop" fragment strings derived from Walk (W) lines. Auto-initialized.
dict_segments_size : defaultdict[str, int]
Map segment IDs and their length (in base pairs). Auto-initialized.
dict_segments_samples : defaultdict[str, List[str]]
Map segment IDs to a list of sample identifiers ("sample;chromosome;haplotype")
that contain the segment. Auto-initialized.
dict_samples_bed : defaultdict[str, OrderedDict[str, Path]]
(Note: This attribute is not directly populated by the current parsing logic.
It was likely intended to track the paths of generated BED files.
The `AsyncBedWriter` internally manages these paths. This attribute might be
redundant or used for post-processing tracking). Auto-initialized.
works_path : Optional[Path]
Path to the working directory (e.g., ".../{gfa_name}_GraTools-IMPORT"). Auto-initialized.
bed_path : Optional[Path]
Path to the subdirectory for BED files within `works_path`. Auto-initialized.
bam_path : Optional[Path]
Path to the subdirectory for BAM files within `works_path`. Auto-initialized.
found_minigraph : bool
Flag indicating if a sample named 'MINIGRAPH' (case-insensitive) was found in Walk lines. Defaults to False.
import_links : bool, optional
If True, GFA links (L lines) are stored in an SQLite database. Defaults to True.
db_links : Optional[AsyncGfaDatabase]
Asynchronous database handler for GFA links. Auto-initialized if `import_links` is True.
segment_count : int
Total number of segments (S lines) processed. Defaults to 0.
total_segment_length : int
Sum of lengths of all segments. Defaults to 0.
link_count : int
Total number of links (L lines) processed. Defaults to 0.
degrees : defaultdict[str, int]
Maps segment IDs to their degree (number of links connected). Defaults to an empty defaultdict.
walks_count : int
Total number of walks (W lines) processed. Defaults to 0.
max_walk_rank : int
Maximum number of segments in any single walk. Defaults to 0.
sum_rank0_length : int
Sum of lengths of the first segments of all walks. Defaults to 0.
input_genome_size : int
Cumulative size of all paths (sum of segment lengths along each walk). Defaults to 0.
walks_info : List[Dict[str, Any]]
List of dictionaries, each containing info for a walk (Path name, Sequence length, Num Segments). Defaults to an empty list.
inverted_links_count : int
Count of links where orientations differ (e.g., S1+ -> S2-). Defaults to 0.
negative_links_count : int
Count of links where both segments have negative orientation (S1- -> S2-). Defaults to 0.
self_links_count : int
Count of links where a segment links to itself (S1 -> S1). Defaults to 0.
isolated_segments : Set[str]
Set of segment IDs that have no links connected to them. Initialized with all segments, then linked ones removed. Defaults to an empty set.
shared_executor : Optional[ThreadPoolExecutor]
Executor for running synchronous tasks in threads. Auto-initialized.
progress : Optional[Progress]
Rich Progress instance for displaying progress. Auto-initialized.
line_type_counts : Counter
Counts of each GFA line type (H, S, L, W, P, C, E, U). Auto-initialized.
header_gfa_file : Optional[Path]
Path to where the GFA header is saved. Auto-initialized.
stats_file : Optional[Path]
Path to where GFA statistics are saved. Auto-initialized.
db_file_path : Optional[Path]
Path to the SQLite database file for links. Auto-initialized.
RE_ORIENTED_SEG_GT_LT : re.Pattern
Compiled regular expression for parsing oriented segments using '>' and '<'. Auto-initialized.
RE_ORIENTED_SEG_PLUS_MINUS : re.Pattern
Compiled regular expression for parsing oriented segments using '+' and '-'. Auto-initialized.
disable_progress_flag: bool
If True, progress bars are disabled. Defaults to False.
"""
gfa_path: Path
threads: int = 1
logger: logging.Logger = field(
default_factory=lambda: logging.getLogger("GraTools")
)
disable_progress_flag: Optional[bool] = False
gfa_name: Optional[str] = None
version: Optional[str] = None
header_gfa: List[str] = field(default_factory=list, repr=True)
sample_reference: Optional[str] = None
bam_segments_file: Optional[Path] = None
# GFA samples
processed_sample_names: Set[str] = field(init=False, repr=False)
samples_chrom_file_writer: Optional[TextIO] = field(init=False, repr=False, default=None)
samples_chrom_file_path: Optional[Path] = field(init=False, repr=False, default=None)
# Dictionaries populated during parsing
dict_samples_chrom: defaultdict = field(
default_factory=lambda: defaultdict(OrderedDict), repr=False
# Key: sample, Value: OrderedDict[chrom, list_of_fragments]
)
dict_segments_size: defaultdict = field(
default_factory=lambda: defaultdict(int), repr=False # Key: seg_id, Value: length
)
dict_segments_samples: defaultdict = field(
default_factory=lambda: defaultdict(list), repr=False # Key: seg_id, Value: list of ID to map to "sample;chrom;haplo"
)
# dict_samples_bed is related to AsyncBedWriter output, not directly populated here in GFA parsing
# It seems more like a tracker of where AsyncBedWriter IS writing files.
# The GFA class _generates_ data for AsyncBedWriter.
dict_samples_bed: defaultdict = field(
# This seems to track which BED files are expected/created by AsyncBedWriter.
default_factory=lambda: defaultdict(OrderedDict), repr=False
# Key: sample, Value: OrderedDict[chrom, path_to_bed_for_chrom_walk]
)
# Paths
works_path: Optional[Path] = None
bed_path: Optional[Path] = None
bam_path: Optional[Path] = None
found_minigraph: bool = False
import_links: bool = False # Flag to control link importing
db_links: Optional[AsyncGfaDatabase] = None # Instantiated in post_init
# Statistics counters
segment_count: int = 0
total_segment_length: int = 0
link_count: int = 0
degrees: defaultdict = field(default_factory=lambda: defaultdict(int), repr=False) # segment_id -> degree
walks_count: int = 0
max_walk_rank: int = 0 # Max number of segments in a walk
sum_rank0_length: int = 0 # Sum of lengths of first segments of walks
input_genome_size: int = 0
walks_info: List[Dict[str, Any]] = field(default_factory=list, repr=False)
inverted_links_count: int = 0
negative_links_count: int = 0
self_links_count: int = 0
isolated_segments: set = field(default_factory=set, repr=False) # Store segment IDs
# Internal operational attributes (initialized in post_init)
shared_executor: Optional[ThreadPoolExecutor] = None
progress: Optional[Progress] = None
line_type_counts: Counter = field(default_factory=Counter)
header_gfa_file: Optional[Path] = None
stats_file: Optional[Path] = None
db_file_path: Optional[Path] = None
RE_ORIENTED_SEG_GT_LT: re.Pattern = field(init=False, repr=False)
RE_ORIENTED_SEG_PLUS_MINUS: re.Pattern = field(init=False, repr=False)
def __post_init__(self) -> None:
    """
    Initialize paths, directories, helpers and the links database, then run the import.

    Steps, in order:
      1. Check that `gfa_path` is an existing regular file.
      2. Derive `gfa_name` by stripping ``.gfa.gz`` or ``.gfa`` (falling back
         to the path stem, with a warning).
      3. Create ``<gfa_name>_GraTools-IMPORT`` next to the GFA file, with its
         ``bed_files`` and ``bam_files`` subdirectories, and set every output path.
      4. Compile the oriented-segment regexes and create the thread pool.
         Create the links database only when `import_links` is True. Create
         the Rich progress bar, disabled when `disable_progress_flag` is truthy.
      5. Build a reusable, pre-configured unmapped `AlignedSegment` template.
      6. Call `run()`, then `save_header()`, `tag_bam()`,
         `sort_file_in_place()` and `compute_statistics()`.

    Raises:
        FileNotFoundError: If `gfa_path` does not exist or is not a file.
    """
    # is_file() is also False for a missing path, so this one check covers both cases.
    if not self.gfa_path.is_file():
        self.logger.error(
            f"GFA file '{self.gfa_path}' does not exist or is not a file."
        )
        raise FileNotFoundError(f"GFA file not found: {self.gfa_path}")  # Fail fast

    # Derive the GFA name by removing the known extensions.
    file_name = self.gfa_path.name
    for suffix in (".gfa.gz", ".gfa"):
        if file_name.endswith(suffix):
            self.gfa_name = file_name.removesuffix(suffix)
            break
    else:
        self.gfa_name = self.gfa_path.stem  # Fallback, might include partial extensions
        self.logger.warning(
            f"GFA file '{self.gfa_path.name}' has an unusual extension. Using stem: '{self.gfa_name}'.")

    # Working directory layout.
    self.works_path = self.gfa_path.parent / f"{self.gfa_name}_GraTools-IMPORT"
    self.works_path.mkdir(exist_ok=True, parents=True)
    self.bed_path = self.works_path / "bed_files"
    self.bed_path.mkdir(exist_ok=True, parents=True)
    self.bam_path = self.works_path / "bam_files"
    self.bam_path.mkdir(exist_ok=True, parents=True)
    self.bam_segments_file = self.bam_path / f"{self.gfa_name}.bam"
    self.header_gfa_file = self.works_path / f"header_{self.gfa_name}.txt"
    self.stats_file = self.works_path / f"stats_{self.gfa_name}.txt"
    self.db_file_path = self.works_path / f"links_{self.gfa_name}.db"
    self.samples_chrom_file_path = self.works_path / "samples_chrom.txt"

    # Regexes are compiled once here rather than on every parsed line.
    self.RE_ORIENTED_SEG_GT_LT = re.compile(r"([<>])([^<>]+)")
    self.RE_ORIENTED_SEG_PLUS_MINUS = re.compile(r"(.+?)([+-])")

    # max_workers=None lets ThreadPoolExecutor choose its default worker count.
    self.shared_executor = ThreadPoolExecutor(
        max_workers=self.threads if self.threads > 0 else None)

    if self.import_links:
        self.db_links = AsyncGfaDatabase(
            db_file=self.db_file_path,
            # Seconds; AsyncGfaDatabase also turns it into a 5000 ms SQLite busy_timeout.
            timeout=5.0,
        )

    self.progress = Progress(
        TextColumn("[bold blue]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.1f}%",
        "{task.completed}/{task.total} lines",
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        refresh_per_second=1,  # Higher refresh can be costly
        transient=True,  # Progress bar disappears after completion
        console=shared_console,
        disable=bool(self.disable_progress_flag),  # Honor the documented opt-out
    )
    self.processed_sample_names = set()

    # Create the reusable segment once, with every constant field pre-set.
    segment_template = AlignedSegment()
    segment_template.flag = 4  # Unmapped
    segment_template.reference_id = -1
    segment_template.reference_start = 0
    segment_template.mapping_quality = 0
    segment_template.template_length = 0
    self.reusable_segment = segment_template

    # Main parsing, followed by post-processing that relies on its outputs.
    self.run()
    self.save_header()
    self.tag_bam()  # Assumes self.bam_segments_file is created by parse_gfa (via _parse_segment)
    self.sort_file_in_place()  # sorted file generated by _parse_walk_line
    self.compute_statistics()  # Uses data populated during parsing
@staticmethod
def _count_lines_fast(gfa_path: Path) -> int:
    """
    Count the lines of a plain or gzip-compressed (``.gz``) text file.

    The file is read in binary mode in 4 MiB chunks and ``\\n`` bytes are counted.
    A last line that is not terminated by ``\\n`` is counted too, so the result
    agrees with the number of ``\\n``-separated lines a line iterator produces.

    Parameters
    ----------
    gfa_path : Path
        File to count; a name ending in ``.gz`` selects gzip decompression.

    Returns
    -------
    int
        Number of lines, or 0 (after logging a warning) if the file cannot be read.
    """
    opener = gzip.open if gfa_path.name.endswith(".gz") else open
    chunk_size = 4 * 1024 * 1024
    count = 0
    # Start as if the previous byte were a newline: an empty file has no extra line.
    last_byte = b"\n"
    try:
        # Binary mode avoids decoding cost; only newline bytes matter here.
        with opener(gfa_path, "rb") as f:
            while chunk := f.read(chunk_size):
                count += chunk.count(b"\n")
                last_byte = chunk[-1:]
    except Exception as e:
        # Best effort: the count only sizes a progress bar.
        logging.getLogger("GraTools").warning(f"Fast line count failed for {gfa_path}: {e}")
        return 0
    if last_byte != b"\n":
        # Final line without a trailing newline.
        count += 1
    return count
[docs]
def sort_file_in_place(self):
    """
    Sort the sample/chromosome fragment file (``self.samples_chrom_file_path``) in place.

    Each line is ``sample<TAB>chromosome<TAB>start<TAB>end``. Lines are ordered by
    sample, then by chromosome (both compared as text by the external ``sort``
    command, so the ordering follows its locale), then numerically by start.

    The Unix ``sort`` command writes its output over the input file with ``-o``,
    which POSIX explicitly allows. ``sort`` may still use its own temporary files
    for large inputs.

    If the file does not exist (for instance because parsing failed before it was
    created), a warning is logged and nothing is done.

    Raises
    ------
    subprocess.CalledProcessError
        If ``sort`` exits with a non-zero status (logged, then re-raised).
    """
    output_path = self.samples_chrom_file_path
    if not output_path.exists():
        self.logger.warning(f"Cannot sort '{output_path}': file not found.")
        return
    command = [
        "sort",
        "-t", "\t",               # Tab separator
        "-k1,1",                  # Sort by sample (col 1)
        "-k2,2",                  # Then by chromosome (col 2)
        "-k3,3n",                 # Then by start position (col 3) numerically
        "-o", str(output_path),   # Output over the input file
        str(output_path),
    ]
    try:
        subprocess.run(command, check=True)
    except subprocess.CalledProcessError as e:
        self.logger.error(f"Error during sorting: {e}")
        raise
def _process_single_bed(self, bed_path: Path, task_id: TaskID) -> None:
    """
    Sort one BED file in place with ``BedTool.sort`` and report to a Rich progress task.

    Meant to run in a worker thread (``run`` submits it to a ``ThreadPoolExecutor``).
    The sorted records are written to a sibling ``<stem>.sorted.bed`` file, which
    then replaces ``bed_path``.

    Parameters
    ----------
    bed_path : Path
        BED file to sort; it is overwritten with the sorted content.
    task_id : TaskID
        Rich progress task, advanced by 1 on success.

    Notes
    -----
    Errors are logged, not raised. On failure:

    * the task description is set to an error message;
    * the task is not advanced;
    * any leftover temporary file is deleted so it does not pollute the BED directory.
    """
    self.logger.debug(f"Sorting BED file: {bed_path.name}")
    tmp_path = None
    try:
        tmp_path = bed_path.with_suffix(".sorted.bed")
        BedTool(bed_path.as_posix()).sort(output=tmp_path.as_posix())
        tmp_path.replace(bed_path)
        self.logger.debug(f"Finished sorting BED file: {bed_path.name}")
    except Exception as e:  # Bedtools can raise various errors
        self.logger.error(f"Error sorting BED file '{bed_path.name}': {e}", exc_info=True)
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError as cleanup_error:
                self.logger.debug(f"Could not remove temporary file '{tmp_path.name}': {cleanup_error}")
        if self.progress:
            self.progress.update(task_id, description=f"[red]Error sorting {bed_path.name}")
        return
    if self.progress:
        self.progress.update(task_id, advance=1)
async def _read_gfa_lines_async(self, buffer_size: int = 100000):
    """
    Asynchronously yield the lines of ``self.gfa_path`` in lists of up to ``buffer_size``.

    Plain and gzip-compressed (``.gz``) files are supported. Lines are decoded as
    UTF-8 and keep their line terminator. Opening, every chunk read (including gzip
    decompression) and closing run through ``asyncio.to_thread``, so file I/O never
    blocks the event loop. Yielding whole chunks keeps the per-line overhead of the
    async-generator protocol low.

    Parameters
    ----------
    buffer_size : int, optional
        Maximum number of lines per yielded chunk (default 100000). Must be >= 1.

    Yields
    ------
    list of str
        Consecutive lines of the file; only the last chunk may be shorter.

    Raises
    ------
    ValueError
        If ``buffer_size`` is smaller than 1.
    Exception
        Errors raised while opening or reading the file are logged and re-raised,
        so a truncated or unreadable file is not mistaken for a complete one.
    """
    from itertools import islice  # local import: only this reader needs it

    if buffer_size < 1:
        raise ValueError(f"buffer_size must be >= 1, got {buffer_size}")

    is_gzipped = self.gfa_path.name.endswith(".gz")
    open_func = gzip.open if is_gzipped else open
    mode = "rt" if is_gzipped else "r"
    log_msg = f"Reading {'gzipped' if is_gzipped else 'plain text'} GFA file: {self.gfa_path.name}"
    self.logger.debug(f"{log_msg} using a dedicated I/O thread and line buffering.")

    try:
        file_obj = await asyncio.to_thread(open_func, self.gfa_path, mode, encoding="utf-8")
    except Exception as e:
        self.logger.error(f"Failed to open file {self.gfa_path.name} in a thread: {e}")
        raise

    def read_chunk() -> list:
        # Executed in a worker thread: reads (and decompresses) the next chunk of lines.
        return list(islice(file_obj, buffer_size))

    try:
        while True:
            try:
                lines_chunk = await asyncio.to_thread(read_chunk)
            except Exception as e:
                self.logger.error(f"Error while reading from {self.gfa_path.name} in thread: {e}")
                raise
            if not lines_chunk:
                break
            yield lines_chunk
    finally:
        # Ensure the file is closed, also in a thread.
        self.logger.debug(f"Closing GFA file: {self.gfa_path.name}")
        await asyncio.to_thread(file_obj.close)
[docs]
async def parse_gfa(self) -> None:
"""
Parses the GFA file line by line, processing headers, segments, links, and walks.
Segments are written to a BAM file. Links are (optionally) stored in an async SQLite DB.
Walk information is used to generate data for BED files, written by AsyncBedWriter.
"""
# Standard GFA header for output BAM if input GFA has no SQ lines.
# Pysam's AlignmentFile requires a header. If GFA S lines imply @SQ, it would be better.
# For now, a minimal header.
bam_header = {
'HD': {'VN': '1.8', 'SO': 'unsorted'}, # SO: Coordinate or unsorted, depending on S lines
'SQ': []
# Placeholder for sequence dictionary. Ideally, populate from S lines if reference context is known.
}
# Pre-calculate total lines for progress bar (can be slow for very large files)
total_lines = None
if not self.disable_progress_flag:
try:
total_lines = self._count_lines_fast(self.gfa_path)
self.logger.info(f"Starting GFA parsing for: '{self.gfa_name}' with {total_lines} lines")
except Exception as e:
self.logger.warning(f"Could not count lines in GFA file for progress bar: {e}. Progress may be inaccurate.")
else:
self.logger.info(f"Starting GFA parsing for: '{self.gfa_name}' with unknown lines")
# Initialize AsyncBedWriter
bed_writer = AsyncBedWriter(
bed_dir=self.bed_path,
# batch_size=10000, # Tunable parameter
progress=self.progress # Pass Rich progress instance if needed by BedWriter
)
bed_writer.start() # Start its writer loop
# Connect to AsyncGfaDatabase if importing links
if self.import_links and self.db_links:
await self.db_links.connect()
links_batch: List[LinkInfo] = [] # Use List for type hinting, will store LinkInfo
link_batch_size = 100000 # How many links to batch before writing to DB
bam_segment_batch = []
BATCH_SIZE = 10000
# Add task to Rich Progress
parse_task_id = self.progress.add_task(
"GFA parsing...", total=total_lines, visible=not self.disable_progress_flag
)
self.samples_chrom_file_writer = open(self.samples_chrom_file_path, "w", encoding="utf-8")
# Open BAM file for writing segments
# pysam.AlignmentFile is synchronous. Writes happen in the main async thread.
# If BAM writing is slow, it can block the event loop.
# OPTIMIZATION_SUGGESTION: For very high throughput GFA S lines,
# consider batching AlignedSegment objects and writing them in a separate thread
# using `await asyncio.to_thread(bam_file_out.write, batched_segment)`.
try:
with (AlignmentFile(
self.bam_segments_file.as_posix(), "wb", header=bam_header, threads=self.threads
# Use self.threads for multi-threading if pysam supports for write
) as bam_file_out,
self.progress): # `with self.progress` starts/stops the Rich Progress context
current_line_type = None
lines_processed_count = 0
async for lines_chunk in self._read_gfa_lines_async():
for line_content in lines_chunk:
lines_processed_count += 1
line_content = line_content.strip() # Strip newline and surrounding whitespace
if not line_content or line_content.startswith("#"): # Skip empty or comment lines
continue
fields = line_content.split("\t")
line_type = fields[0]
# Mettre ร jour la description moins souvent aussi
if line_type != current_line_type:
current_line_type = line_type
self.progress.update(
parse_task_id,
description=f"GFA parsing: {current_line_type}"
)
self.line_type_counts[line_type] += 1
match line_type:
case "H": # Header line
self.header_gfa.append(line_content)
# Basic header parsing for version and reference sample (if present)
# Example H line: H VN:Z:1.0 RS:Z:ref_sample
for field in fields[1:]:
if field.startswith("VN:Z:"):
self.version = field[5:]
elif field.startswith(
"RS:Z:"): # Reference Sample tag (custom or specific GFA variant?)
self.sample_reference = field[5:]
case "S": # Segment line
segment_for_bam = self._parse_segment(fields, bam_file_out) # Pass fields list
if segment_for_bam:
bam_segment_batch.append(copy(segment_for_bam))
if len(bam_segment_batch) >= BATCH_SIZE:
for seg in bam_segment_batch:
bam_file_out.write(seg)
bam_segment_batch.clear()
case "L": # Links line
self.link_count += 1
if self.import_links and self.db_links:
link_info = self._parse_link_line(fields) # Pass fields list
if link_info:
links_batch.append(link_info)
if len(links_batch) >= link_batch_size:
await self.db_links.batch_insert_links(list(links_batch)) # Pass a copy
links_batch.clear()
case "W": # Walk line
# W <sample_name> <haplotype_index> <chr_name> <chr_start> <chr_stop> <segment_path>
# For GFA2 OrderedGroup (OG lines), parsing would be different. Assuming GFA1 W.
parse_result = await self._parse_walk_line(fields, bed_writer) # Pass fields list
# if parse_result:
# if self.walks_count % 10 == 0: # Every 2 walks
# await asyncio.sleep(0)
# Other GFA line types (P - Path, C - Containment, E - Edge (GFA2), U - Gap (GFA2), etc.)
# Add parsing logic here if needed. For now, they are counted by line_type_counts.
# update progress with number of line read on _read_gfe_lines_async
self.progress.advance(parse_task_id, 100000)
# if self.disable_progress_flag:
# self.logger.info(f"Processed {lines_processed_count} lines")
if bam_segment_batch:
for seg in bam_segment_batch:
bam_file_out.write(seg)
bam_segment_batch.clear()
except Exception as e:
self.logger.error(f"Error during GFA parsing main loop: {e}", exc_info=True)
# Ensure progress bar is stopped or marked as errored if possible
self.progress.update(parse_task_id, description=f"[red]Error during GFA parsing: {e}",
completed=lines_processed_count)
# Proceed to shutdown, resources might be in an inconsistent state.
finally:
# Finalize operations after loop (or if error occurred)
self.progress.remove_task(parse_task_id) # Or stop if not transient
self.logger.info(f"Finished reading GFA file. Processed {lines_processed_count} lines.")
if self.found_minigraph:
self.logger.warning(
"Sample 'MINIGRAPH' (or similar) detected and ignored. "
"See Cactus issue https://github.com/ComparativeGenomicsToolkit/cactus/pull/1050 for context."
)
# Flush any remaining links to the database
if self.import_links and self.db_links and links_batch:
self.logger.debug(f"Inserting final batch of {len(links_batch)} links into database.")
await self.db_links.batch_insert_links(list(links_batch))
# Shutdown AsyncGfaDatabase (waits for queue, creates indexes, closes connection)
if self.db_links:
self.logger.info(
f"Closing links database and saved to '{self.db_file_path.name if self.db_file_path else 'N/A'}'.")
await self.db_links.close()
# Shutdown AsyncBedWriter (waits for queue, flushes all buffers, closes files)
self.logger.info("Shutting down BED writer...")
await bed_writer.shutdown()
if self.samples_chrom_file_writer:
self.samples_chrom_file_writer.close()
self.logger.info(f"Sample-chromosome fragment data saved to '{self.samples_chrom_file_path.name}'.")
[docs]
def run(self):
    """
    Synchronous entry point: run ``parse_gfa`` on a fresh event loop, then sort BED files.

    Steps
    -----
    1. Create a new event loop (its implementation comes from the global event-loop
       policy), make it current, and use ``self.shared_executor`` as its default executor.
    2. Run ``parse_gfa`` to completion.
    3. Always, even if parsing raised:
       - finalise pending async generators;
       - shut down ``self.shared_executor``;
       - close the loop and unset it as the current loop.
    4. If parsing succeeded, sort every non-empty ``<bed_path>/<sample>.bed`` file of
       the samples in ``self.processed_sample_names``, using worker threads
       (``self.threads`` workers at most; ``self.threads <= 0`` lets the executor
       choose its default).

    Exceptions are logged, not propagated.
    """
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    if self.shared_executor:
        loop.set_default_executor(self.shared_executor)
    try:
        try:
            loop.run_until_complete(self.parse_gfa())
        finally:
            # Finalise abandoned async generators (e.g. the GFA reader, which closes
            # its file through the executor) before the executor goes away.
            loop.run_until_complete(loop.shutdown_asyncgens())
            if self.shared_executor:
                self.logger.debug("Shutting down shared ThreadPoolExecutor.")
                self.shared_executor.shutdown(wait=True)
            loop.close()
            asyncio.set_event_loop(None)
        self.logger.info("Async GFA parsing complete. Proceeding to sort BED files.")

        # BED files are named "<sample>.bed" inside self.bed_path.
        beds_to_sort = []
        if self.bed_path and self.bed_path.exists():
            for sample_name in self.processed_sample_names:
                candidate = self.bed_path / f"{sample_name}.bed"
                if candidate.exists() and candidate.stat().st_size > 0:
                    beds_to_sort.append(candidate)

        if not beds_to_sort:
            self.logger.info("No BED files found or specified for sorting.")
        else:
            self.logger.info(f"Found {len(beds_to_sort)} BED files to sort.")
            sort_task_id = self.progress.add_task(
                f"Sorting {len(beds_to_sort)} BED files",
                total=len(beds_to_sort),
                visible=not self.disable_progress_flag,
            )
            # ThreadPoolExecutor rejects max_workers=0; like the shared executor,
            # a non-positive thread count means "use the default".
            bed_sort_workers = min(self.threads, len(beds_to_sort)) if self.threads > 0 else None
            with ThreadPoolExecutor(max_workers=bed_sort_workers) as bed_sort_executor, self.progress:
                futures = [
                    bed_sort_executor.submit(self._process_single_bed, bed_file_path, sort_task_id)
                    for bed_file_path in beds_to_sort
                ]
                # Without a timeout, wait() returns only when every future is done.
                done, _ = wait(futures)
                for future in done:
                    exc = future.exception()
                    if exc is not None:
                        self.logger.error(f"BED sorting task failed: {exc}", exc_info=exc)
            self.progress.remove_task(sort_task_id)
            self.logger.info("All BED file sorting tasks complete.")
    except Exception as e:
        self.logger.error(f"Error during GFA.run execution: {e}", exc_info=True)
    finally:
        self.logger.debug("GFA.run() finished.")
@staticmethod
def _parse_gfa_walk_path_manually_iter(path_spec: str):
    """
    Lazily split a GFA walk string such as ``'>10<25>150'`` into oriented segments.

    Yields
    ------
    Tuple[str, str]
        ``(orientation_char, segment_id)`` pairs, e.g. ``('>', '10')``,
        ``('<', '25')``, ``('>', '150')``.

    Notes
    -----
    A well-formed walk starts with ``>`` or ``<``. Anything else (an empty string,
    or a string not starting with an orientation character) yields nothing, which
    callers treat as "no segment parsed". Empty segment ids created by
    consecutive or trailing delimiters (``'>>10<'``) are skipped.
    """
    if not path_spec or path_spec[0] not in "<>":
        return
    orientation_char = path_spec[0]
    segment_start_index = 1
    for i, char in enumerate(path_spec[1:], start=1):
        if char == ">" or char == "<":
            # End of the current segment id; skip it if empty.
            if i > segment_start_index:
                yield orientation_char, path_spec[segment_start_index:i]
            orientation_char = char
            segment_start_index = i + 1
    # Last segment (not followed by a delimiter).
    if segment_start_index < len(path_spec):
        yield orientation_char, path_spec[segment_start_index:]
async def _parse_walk_line(self, fields: List[str], bed_writer: AsyncBedWriter) -> Optional[str]:
    """
    Parse a GFA Walk (W) line and emit one BED line per visited segment.

    Expected layout: ``W <sample> <hap_idx> <chr> <start> <stop> <walk>``, where
    ``<walk>`` is a sequence of oriented segment ids such as ``>s1<s2>s3``
    (``>`` is written as ``+`` and ``<`` as ``-`` in the BED output).

    Processing order and side effects:

    1. Lines with too few fields, samples whose name contains ``MINIGRAPH``
       (case-insensitive; sets ``self.found_minigraph``) and lines whose
       ``<start>`` is not an integer are skipped before any state is changed.
    2. The sample is added to ``self.processed_sample_names``. A
       ``sample<TAB>chr<TAB>start<TAB>stop`` line is written to
       ``self.samples_chrom_file_writer`` when that writer is set.
    3. If no segment can be parsed from ``<walk>``, a warning is logged and the
       walk is skipped.
    4. For every visited segment:
       - ``"<sample>;<chr>;<hap_idx>"`` is appended to
         ``self.dict_segments_samples[segment]``;
       - a BED line ``chr, seg_start, seg_end, <+|-><segment>, <start>:<stop>, hap_idx``
         is sent, in batches, through ``await bed_writer.enqueue(sample, lines)``.
       Coordinates are accumulated from ``<start>`` using ``self.dict_segments_size``
       (unknown segments count as length 0, with a warning).
    5. Walk statistics are updated: ``walks_count``, ``sum_rank0_length``,
       ``input_genome_size``, ``max_walk_rank`` and ``walks_info``.

    Parameters
    ----------
    fields : List[str]
        Tab-separated fields of the W line (``fields[0] == "W"``).
    bed_writer : AsyncBedWriter
        Receives BED lines via ``await bed_writer.enqueue(sample, lines)``.

    Returns
    -------
    Optional[str]
        The sample name when the walk was processed, otherwise ``None``.
    """
    try:
        sample_name = fields[1]
        haplotype_index_str = fields[2]
        chromosome_name = fields[3]
        chromosome_start_str = fields[4]
        chromosome_stop_str = fields[5]
        gfa_path_specification = fields[6]
    except IndexError:
        self.logger.warning(f"Malformed W line (not enough fields): '{fields}'. Skipping.")
        return None

    if "MINIGRAPH" in sample_name.upper():  # Case-insensitive check
        self.found_minigraph = True
        self.logger.debug(f"Skipping W line for 'MINIGRAPH' sample: {fields}")
        return None

    # Validate the start coordinate before any shared state is modified. Otherwise
    # a ValueError (e.g. for '*') would leave partial updates and propagate to the caller.
    try:
        walk_start = int(chromosome_start_str)
    except ValueError:
        self.logger.warning(
            f"Malformed W line (non-integer start '{chromosome_start_str}'): '{fields}'. Skipping.")
        return None

    self.processed_sample_names.add(sample_name)
    if self.samples_chrom_file_writer:
        self.samples_chrom_file_writer.write(
            f"{sample_name}\t{chromosome_name}\t{chromosome_start_str}\t{chromosome_stop_str}\n")

    # Identifier stored per visited segment (later used for the SW BAM tag).
    sample_chromosome_haplotype_id = f"{sample_name};{chromosome_name};{haplotype_index_str}"

    # Lazy iterator: long walks are never materialised as a list.
    segments_iter = self._parse_gfa_walk_path_manually_iter(gfa_path_specification)
    try:
        first_segment = next(segments_iter)
    except StopIteration:
        self.logger.warning(
            f"Could not parse any segments from path spec in W line: '{gfa_path_specification}'. Skipping.")
        return None

    self.walks_count += 1
    segments_samples = self.dict_segments_samples
    segments_size = self.dict_segments_size
    self.sum_rank0_length += segments_size.get(first_segment[1], 0)

    # Constant parts of every BED line of this walk.
    bed_line_prefix = f"{chromosome_name}\t"
    bed_line_suffix = f"\t{chromosome_start_str}:{chromosome_stop_str}\t{haplotype_index_str}\n"
    batch_size = 1000000
    bed_lines_batch = []
    num_segments_in_walk = 0
    current_walk_total_length_bp = 0
    segment_start = walk_start

    for orientation_char, segment_id in chain((first_segment,), segments_iter):
        num_segments_in_walk += 1
        segments_samples[segment_id].append(sample_chromosome_haplotype_id)
        segment_length = segments_size.get(segment_id, 0)
        if segment_length == 0:
            self.logger.warning(
                f"Segment '{segment_id}' in walk for sample '{sample_name}' has zero length "
                f"or size not found. BED entry may be inaccurate."
            )
        current_walk_total_length_bp += segment_length
        segment_end = segment_start + segment_length
        strand = "+" if orientation_char == ">" else "-"
        bed_lines_batch.append(
            f"{bed_line_prefix}{segment_start}\t{segment_end}\t{strand}{segment_id}{bed_line_suffix}"
        )
        if len(bed_lines_batch) >= batch_size:
            await bed_writer.enqueue(sample_name, bed_lines_batch)
            bed_lines_batch = []
        segment_start = segment_end

    if bed_lines_batch:
        await bed_writer.enqueue(sample_name, bed_lines_batch)

    self.input_genome_size += current_walk_total_length_bp
    self.max_walk_rank = max(self.max_walk_rank, num_segments_in_walk)
    self.walks_info.append({
        "PathName": f"{sample_name}_{chromosome_name}_{haplotype_index_str}",
        "SequenceLength_bp": current_walk_total_length_bp,
        "NumberOfSegments": num_segments_in_walk,
    })
    return sample_name
def _parse_link_line(self, fields: List[str]) -> Optional[LinkInfo]:
    """
    Parse a GFA Link (L) line into a ``LinkInfo`` and update the link statistics.

    For a valid line:

    * ``self_links_count`` is incremented when both ends are the same segment;
    * ``inverted_links_count`` is incremented when the orientations differ;
    * ``negative_links_count`` is incremented when both orientations are ``-``;
    * ``degrees`` of both segments are incremented;
    * both segments are removed from ``isolated_segments``.

    Parameters
    ----------
    fields : List[str]
        Tab-separated fields of the L line,
        ``L <seg1> <orient1> <seg2> <orient2> <overlap>``. The overlap is ignored.

    Returns
    -------
    Optional[LinkInfo]
        ``LinkInfo(seg1, o1, o1, seg2, o2, o2)`` with orientations encoded as
        ``+1``/``-1``. Returns ``None`` (with a warning, no statistics updated) when
        the line has fewer than five fields or an orientation is not ``+``/``-``.
    """
    try:
        seg_id_1 = fields[1]
        orient_char_1 = fields[2]
        seg_id_2 = fields[3]
        orient_char_2 = fields[4]
    except IndexError:
        self.logger.warning(f"Malformed L line (not enough fields): '{fields}'. Skipping.")
        return None

    # GFA orientations are exactly "+" or "-". Reject anything else instead of
    # silently reading it as "-".
    if orient_char_1 not in ("+", "-") or orient_char_2 not in ("+", "-"):
        self.logger.warning(f"Malformed L line (invalid orientation): '{fields}'. Skipping.")
        return None

    orient_seg_1_numeric = 1 if orient_char_1 == "+" else -1
    orient_seg_2_numeric = 1 if orient_char_2 == "+" else -1

    if seg_id_1 == seg_id_2:
        self.self_links_count += 1
    if orient_seg_1_numeric != orient_seg_2_numeric:
        self.inverted_links_count += 1
    if orient_seg_1_numeric == -1 and orient_seg_2_numeric == -1:
        self.negative_links_count += 1

    self.degrees[seg_id_1] += 1
    self.degrees[seg_id_2] += 1
    # Linked segments are no longer isolated.
    self.isolated_segments.discard(seg_id_1)
    self.isolated_segments.discard(seg_id_2)

    # The orient_key fields currently duplicate the numeric orientations.
    return LinkInfo(
        seg_id_1, orient_seg_1_numeric, orient_seg_1_numeric,
        seg_id_2, orient_seg_2_numeric, orient_seg_2_numeric,
    )
def _parse_segment(self, fields: List[str], bam_file_out: AlignmentFile) -> Optional[AlignedSegment]:
    """
    Parse a GFA Segment (S) line into a BAM record and update segment statistics.

    The returned record is ``self.reusable_segment``, overwritten on every call.
    Callers that keep it beyond the next call must copy it. Nothing is written
    here: ``bam_file_out`` is unused and kept only for signature compatibility.
    Fields not set below (flag, reference, ...) keep the values of that template.

    Record content:

    * ``query_name`` is set to the segment id.
    * The sequence is stored, with a quality string of ``'*'`` characters of the
      same length.
    * The CIGAR is ``<length>M`` (none for length 0).
    * Optional GFA tags are joined with ``,`` into the ``SU:Z`` tag.

    GFA allows ``*`` as the sequence when it is not stored. In that case the segment
    length is taken from the ``LN:i:`` tag (0 if absent or invalid), and the record
    carries no sequence and no qualities.

    Parameters
    ----------
    fields : List[str]
        Tab-separated fields of the S line: ``S <seg_id> <sequence> [tags...]``.
    bam_file_out : pysam.AlignmentFile
        Unused.

    Returns
    -------
    Optional[AlignedSegment]
        The filled template, or ``None`` (with a warning) if the line has fewer
        than three fields.
    """
    try:
        seg_id = fields[1]
        sequence = fields[2]
    except IndexError:
        self.logger.warning(f"Malformed S line (not enough fields): '{fields}'. Skipping.")
        return None
    optional_tags_list = fields[3:]  # "TAG:TYPE:VALUE" strings

    has_sequence = sequence != "*"
    if has_sequence:
        seg_length = len(sequence)
    else:
        seg_length = 0
        for tag in optional_tags_list:
            if tag.startswith("LN:i:"):
                try:
                    seg_length = max(0, int(tag[5:]))
                except ValueError:
                    self.logger.warning(f"Segment '{seg_id}' has an invalid LN tag '{tag}'; using length 0.")
                break

    self.dict_segments_size[seg_id] = seg_length
    self.segment_count += 1
    self.total_segment_length += seg_length
    self.isolated_segments.add(seg_id)  # Links will remove it if connected

    segment_to_write = self.reusable_segment
    segment_to_write.query_name = seg_id
    if has_sequence:
        segment_to_write.query_sequence = sequence
        segment_to_write.query_qualities = qualitystring_to_array("*" * seg_length)
    else:
        # No stored sequence: qualities would not match an empty sequence.
        segment_to_write.query_sequence = None
    segment_to_write.cigartuples = [(0, seg_length)] if seg_length > 0 else None
    # Tags from the previous call must be cleared on the reused record.
    segment_to_write.tags = []
    if optional_tags_list:
        segment_to_write.set_tag("SU", ",".join(optional_tags_list), value_type='Z')
    return segment_to_write
[docs]
def tag_bam(self) -> None:
    """
    Add sample-walk information (SW tags) to the segments BAM via ``GratoolsBam.tag``.

    ``GratoolsBam.tag`` receives:

    * ``self.dict_segments_samples`` -- segment id -> list of
      ``"sample;chrom;haplotype"`` strings collected from W lines;
    * ``self.segment_count``.

    It is expected to overwrite ``self.bam_segments_file`` with the tagged version;
    a warning is logged if it returns a different path.

    When ``self.bam_segments_file`` is unset or missing, an error is logged and
    nothing is done. Errors raised by ``GratoolsBam.tag`` are logged, not raised.
    """
    # Validate first: the log message below dereferences `.name`.
    if not self.bam_segments_file or not self.bam_segments_file.exists():
        self.logger.error("BAM segments file not found. Cannot perform tagging.")
        return
    self.logger.info(f"Initiating BAM tagging for '{self.bam_segments_file.name}'.")
    if not self.dict_segments_samples:
        self.logger.warning("dict_segments_samples is empty. No SW tags will be added.")
        # Tagging still runs so the BAM goes through the same processing.

    bam_manager = GratoolsBam(
        bam_path=self.bam_segments_file,
        threads=self.threads,
        logger=self.logger,
        tagging=False,  # NOTE(review): meaning defined by GratoolsBam; confirm.
        disable_progress_flag=self.disable_progress_flag
    )
    try:
        updated_bam_path = bam_manager.tag(dict_segments_samples=self.dict_segments_samples,
                                           nb_segments=self.segment_count)
        # Normalise a str return so an identical path does not trigger a false warning.
        if isinstance(updated_bam_path, str):
            updated_bam_path = Path(updated_bam_path)
        if updated_bam_path != self.bam_segments_file:
            self.logger.warning(f"BAM tagging returned a new path '{updated_bam_path}', "
                                f"but expected overwrite of '{self.bam_segments_file.name}'. Check GratoolsBam.tag logic.")
    except Exception as e:
        self.logger.error(f"Error during BAM tagging process: {e}", exc_info=True)
def _get_connected_component(self, start_seg: str, cursor: sqlite3.Cursor) -> Set[str]:
    """
    Return the ids of all segments connected to ``start_seg`` through the ``links`` table.

    Links are treated as undirected edges: a row ``(seg_id_1, seg_id_2)`` connects
    both segments regardless of their order. The traversal is a single recursive
    CTE executed on the given cursor. ``UNION`` removes duplicates, so cycles
    terminate.

    The recursive step is ONE recursive SELECT that joins on either end of a link.
    The previous form used two recursive SELECTs combined with ``UNION``, which
    SQLite only accepts from version 3.34.0 onward.

    Parameters
    ----------
    start_seg : str
        Segment id where the traversal starts.
    cursor : sqlite3.Cursor
        Open cursor on a database containing a ``links(seg_id_1, seg_id_2, ...)`` table.

    Returns
    -------
    Set[str]
        Ids in the component, always including ``start_seg`` (even without links).
        On a SQLite error the error is logged and ``{start_seg}`` is returned.
    """
    # With separate indexes on seg_id_1 and seg_id_2, SQLite can answer the OR
    # join using both indexes.
    query = """
        WITH RECURSIVE connected_component(segment_id) AS (
            VALUES(?)
            UNION
            SELECT CASE
                       WHEN links.seg_id_1 = connected_component.segment_id THEN links.seg_id_2
                       ELSE links.seg_id_1
                   END
            FROM links
            JOIN connected_component
              ON links.seg_id_1 = connected_component.segment_id
              OR links.seg_id_2 = connected_component.segment_id
        )
        SELECT segment_id FROM connected_component;
    """
    try:
        cursor.execute(query, (start_seg,))
        return {row[0] for row in cursor.fetchall()}
    except sqlite3.Error as e:
        self.logger.error(f"SQLite error getting connected component for '{start_seg}': {e}", exc_info=True)
        return {start_seg}
def _compute_segment_stats(self) -> Dict[str, Any]:
    """
    Compute segment-length statistics from ``self.dict_segments_size``.

    Returns
    -------
    Dict[str, Any]
        Empty if no segment is known. Otherwise:

        * ``"Median Segment Length (bp)"``.
        * ``"Avg Length of Top 5% Longest Segments (bp)"`` and
          ``"Median Length of Top 5% Longest Segments (bp)"``, computed over the
          ``max(1, int(0.05 * N))`` longest segments.
        * ``"Segment Length Distribution"``: counts in the bins
          ``0-499``, ``500-999``, ``1000-1999`` and ``2000+`` bp.
    """
    self.logger.debug("Starting segment length statistics computation...")
    stats = {}
    if not self.dict_segments_size:
        return stats

    # The median needs the full list of lengths.
    node_lengths = list(self.dict_segments_size.values())
    stats["Median Segment Length (bp)"] = float(median(node_lengths))

    # heapq.nlargest is O(N log k) and avoids sorting the whole list.
    num_top_5_percent = max(1, int(0.05 * len(node_lengths)))
    top_lengths = heapq.nlargest(num_top_5_percent, node_lengths)
    if top_lengths:
        stats["Avg Length of Top 5% Longest Segments (bp)"] = float(mean(top_lengths))
        stats["Median Length of Top 5% Longest Segments (bp)"] = float(median(top_lengths))

    # numpy bins are half-open [a, b) except the last one, matching the labels.
    length_bins = [0, 500, 1000, 2000, float("inf")]
    bin_labels = ["0-499bp", "500-999bp", "1000-1999bp", "2000+bp"]
    hist_counts, _ = histogram(node_lengths, bins=length_bins)
    stats["Segment Length Distribution"] = dict(zip(bin_labels, map(int, hist_counts)))
    self.logger.debug("Finished segment length statistics.")
    return stats
def _compute_similarity_stats(self) -> Dict[str, Any]:
    """
    Compute segment sharing (similarity) and occurrence (depth) statistics.

    Each value of ``self.dict_segments_samples`` is a list of
    ``"sample;chrom;haplotype"`` strings, one per occurrence of the segment in
    a walk. For each segment:

    * similarity = number of distinct samples (text before the first ``;``);
    * depth = number of occurrences.

    Returns
    -------
    Dict[str, Any]
        Empty if no segment has occurrences. Otherwise the mean, median and
        population standard deviation (``numpy.std``) of similarity and of depth.
    """
    self.logger.debug("Starting segment similarity/depth statistics computation...")
    stats = {}
    if not self.dict_segments_samples:
        return stats

    similarities = []  # Distinct samples per segment
    depths = []        # Occurrences per segment across all walks
    for occurrences in self.dict_segments_samples.values():
        depths.append(len(occurrences))
        similarities.append(len({occ.split(";", 1)[0] for occ in occurrences}))

    # numpy.std is the population std: it is 0.0 for a single value, so no guard is needed.
    if similarities:
        similarities_np = np.array(similarities)
        stats["Avg Unique Samples per Segment (Similarity Mean)"] = float(np.mean(similarities_np))
        stats["Median Unique Samples per Segment (Similarity Median)"] = float(np.median(similarities_np))
        stats["StdDev Unique Samples per Segment (Similarity Std)"] = float(np.std(similarities_np))
    if depths:
        depths_np = np.array(depths)
        stats["Avg Occurrences per Segment (Depth Mean)"] = float(np.mean(depths_np))
        stats["Median Occurrences per Segment (Depth Median)"] = float(np.median(depths_np))
        stats["StdDev Occurrences per Segment (Depth Std)"] = float(np.std(depths_np))
    self.logger.debug("Finished segment similarity/depth statistics.")
    return stats
def _compute_connectivity_stats(self) -> dict[str, Any]:
    """
    Worker method computing graph connectivity statistics from the links SQLite database.

    Every segment ID in ``self.dict_segments_size`` is used as a seed for
    ``self._get_connected_component``. Segments already assigned to a
    component are skipped. Component sizes are measured in base pairs by
    summing ``self.dict_segments_size`` over each component's members.
    The method opens its own read-only connection so it can run in a worker
    thread. It is often the most time-consuming statistics task.

    Returns
    -------
    dict[str, Any]
        Keys "Number of Connected Components (CCs)", "Largest CC Size (bp)",
        "Number of Disconnected CCs (excluding largest)" and
        "Total Length of Disconnected CCs (bp)". Exactly one largest component
        is excluded from the disconnected figures, even when several share the
        largest size. All values stay 0 in four cases:
        link import is disabled; the database file is missing;
        no component is found; or a SQLite error occurs (logged, not raised).
    """
    self.logger.debug("Starting connectivity statistics computation...")
    stats = {
        "Number of Connected Components (CCs)": 0,
        "Largest CC Size (bp)": 0,
        "Number of Disconnected CCs (excluding largest)": 0,
        "Total Length of Disconnected CCs (bp)": 0
    }
    if not (self.import_links and self.db_links and self.db_links.db_file and self.db_links.db_file.exists()):
        self.logger.info("Link importing disabled or DB not found; skipping connectivity stats.")
        return stats

    conn_sqlite = None
    try:
        # Each worker thread opens its own connection, in read-only mode for safety.
        # as_uri() percent-encodes characters such as '?', '#' and '%'. Interpolating
        # the raw path would let SQLite parse them as URI syntax and open the wrong file.
        db_uri = f"{self.db_links.db_file.resolve().as_uri()}?mode=ro"
        conn_sqlite = sqlite3.connect(db_uri, uri=True, timeout=5.0)
        cursor = conn_sqlite.cursor()
        cursor.execute("PRAGMA query_only = ON;")

        all_components_nodes = []
        visited_segments = set()
        # Seed a component search from every segment not yet reached.
        for seg_id in self.dict_segments_size:
            if seg_id in visited_segments:
                continue
            component = self._get_connected_component(seg_id, cursor)
            if component:
                all_components_nodes.append(component)
                visited_segments.update(component)

        if not all_components_nodes:
            return stats

        # Size of each component in base pairs (unknown segments count as 0 bp).
        component_lengths_bp = [
            sum(self.dict_segments_size.get(s, 0) for s in comp) for comp in all_components_nodes
        ]
        largest_cc_size_bp = max(component_lengths_bp)
        stats["Number of Connected Components (CCs)"] = len(component_lengths_bp)
        stats["Largest CC Size (bp)"] = largest_cc_size_bp
        # Exclude exactly ONE largest component. Filtering on `length < largest`
        # would also drop every other component tied for the largest size.
        stats["Number of Disconnected CCs (excluding largest)"] = len(component_lengths_bp) - 1
        stats["Total Length of Disconnected CCs (bp)"] = sum(component_lengths_bp) - largest_cc_size_bp
    except sqlite3.Error as e:
        self.logger.error(f"SQLite error during connectivity stats: {e}", exc_info=True)
    finally:
        if conn_sqlite:
            conn_sqlite.close()
    self.logger.debug("Finished connectivity statistics.")
    return stats
[docs]
def compute_statistics(self) -> dict[str, Any]:
    """
    Orchestrate the computation of GFA graph statistics and save the results.

    The worker methods ``_compute_segment_stats``, ``_compute_similarity_stats``
    and ``_compute_connectivity_stats`` run in a thread pool of
    ``self.threads`` workers. The calling thread then computes cheaper metrics
    from counters already stored on the instance (counts, lengths, degrees)
    and merges them in. Finally, all metrics are flattened into a
    (Category, Metric, Value) table and formatted for display. When
    ``self.stats_file`` is set, the table is written to it as tab-separated text.

    A worker that raises is logged and its category is skipped; fast metrics
    for that category are still added. A failure to write the stats file is
    logged, not raised. The progress task is always removed, even on error.

    Returns
    -------
    dict[str, Any]
        Mapping of category name -> {metric name: raw (unformatted) value}.
    """
    self.logger.info("Computing GFA graph statistics in parallel...")
    final_stats: dict[str, Any] = {}
    # Heavy computations, each dispatched to a worker thread.
    tasks_to_run = {
        "Segment Statistics": self._compute_segment_stats,
        "Segment Sharing & Depth": self._compute_similarity_stats,
        "Graph Structure": self._compute_connectivity_stats,  # Usually the longest task
    }
    total_steps = len(tasks_to_run) + 1  # +1 for final aggregation and formatting
    stats_task_id = self.progress.add_task(
        "[bold blue]Computing graph statistics...", total=total_steps, visible=not self.disable_progress_flag
    )

    def format_value(x):
        """Render a raw metric value as a display string for the output table."""
        if isinstance(x, dict):
            return ", ".join(f"{k}: {v}" for k, v in x.items())
        # bool subclasses int: check it first so flags print True/False, not 1/0.
        if isinstance(x, (bool, np.bool_)):
            return str(x)
        # NumPy integer scalars are not int subclasses; format them like ints.
        if isinstance(x, (int, np.integer)):
            return f"{x:,}"
        if isinstance(x, (float, np.floating)):
            return f"{x:.2e}" if (0 < abs(x) < 1e-3 or abs(x) > 1e6) else f"{x:.2f}"
        return str(x)

    try:
        with self.progress, ThreadPoolExecutor(max_workers=self.threads) as executor:
            future_to_category = {executor.submit(func): category for category, func in tasks_to_run.items()}
            # Collect results in completion order; one failing category must not abort the others.
            for future in as_completed(future_to_category):
                category = future_to_category[future]
                try:
                    final_stats[category] = future.result()
                    self.logger.info(f"✓ Statistics for '{category}' computed successfully.")
                except Exception as e:
                    self.logger.error(f"Error computing stats for category '{category}': {e}", exc_info=True)
                finally:
                    # Advance the progress bar for each finished task, successful or not.
                    self.progress.update(stats_task_id, advance=1)

        # --- Final aggregation of cheap metrics (calling thread) ---
        self.logger.info("Aggregating final statistics...")
        segment_count = self.segment_count
        link_count = self.link_count

        # 1. Graph overview
        final_stats["Graph Overview"] = {
            "GFA File Name": self.gfa_name,
            "GFA Version": self.version,
            "Total Segments (S lines)": segment_count,
            "Total Links (L lines)": link_count,
            "Total Walks (W lines)": self.walks_count,
            "Unique Samples in Walks": len(self.processed_sample_names),
        }

        # 2. Segment and link statistics. setdefault merges into the worker's result
        # for this category, or starts a fresh dict if that worker failed.
        segment_stats = final_stats.setdefault("Segment Statistics", {})
        segment_stats["Total Segment Length (bp)"] = self.total_segment_length
        segment_stats["Average Segment Length (bp)"] = (
            self.total_segment_length / segment_count if segment_count else 0.0
        )
        final_stats["Link Statistics"] = {
            "Max Segment Degree": max(self.degrees.values()) if self.degrees else 0,
            # Each link contributes to the degree of both of its endpoints.
            "Average Segment Degree": 2 * link_count / segment_count if segment_count else 0.0,
            "Self-Links (S1 -> S1)": self.self_links_count,
            "Inverted Links (S1+ -> S2-)": self.inverted_links_count,
            "Both Negative Links (S1- -> S2-)": self.negative_links_count,
        }

        # 3. Path (walk) statistics
        graph_size_bp = self.total_segment_length
        final_stats["Path (Walk) Statistics"] = {
            "Total Length of All Paths (bp, input_genome_size)": self.input_genome_size,
            "Graph Compression Ratio (input_genome_size / graph_size_bp)": (
                self.input_genome_size / graph_size_bp if graph_size_bp else 0.0
            ),
            "Max Segments in a Single Walk": self.max_walk_rank,
            "Sum of First Segment Lengths in Walks": self.sum_rank0_length,
        }

        # 4. Graph structure. Density = 2E / (V (V - 1)), the undirected simple-graph formula.
        graph_density = (
            2 * link_count / (segment_count * (segment_count - 1)) if segment_count > 1 else 0.0
        )
        dead_ends_count = sum(1 for degree in self.degrees.values() if degree == 1)
        final_stats.setdefault("Graph Structure", {}).update({
            "Graph Density": graph_density,
            "Segments/Links Ratio": segment_count / link_count if link_count else 0.0,
            "Dead-End Segments (degree 1)": dead_ends_count,
            "Isolated Segments (degree 0)": len(self.isolated_segments),
        })

        self.progress.update(stats_task_id, advance=1, description="[green]Formatting & Saving")

        # --- Flatten, format and write; categories and metrics sorted for a stable output order ---
        flat_stats_for_df = []
        for category in sorted(final_stats):
            metrics_dict = final_stats[category]
            if not isinstance(metrics_dict, dict):
                continue
            for metric_name in sorted(metrics_dict):
                flat_stats_for_df.append((category, metric_name, metrics_dict[metric_name]))

        df_stats = DataFrame(flat_stats_for_df, columns=["Category", "Metric", "Value"])
        df_stats["Value"] = df_stats["Value"].apply(format_value)

        if self.stats_file:
            try:
                df_stats.to_csv(self.stats_file, sep="\t", index=False)
                self.logger.info(f"Saved GFA statistics to '{self.stats_file.name}'")
            except OSError as e:
                self.logger.error(f"Failed to save statistics CSV '{self.stats_file}': {e}")
    finally:
        # Always drop the task so a failure above does not leave a stale progress bar.
        self.progress.remove_task(stats_task_id)
    return final_stats