Source code for gratools.Gratools
# Standard library imports
import csv
import gzip
import logging
import multiprocessing # For Manager and ProcessPoolExecutor
import subprocess
import sys # For sys.argv
from collections import OrderedDict, defaultdict # Keep only if directly used by Gratools methods
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed # For parallelize_samples
from dataclasses import dataclass, field
from pathlib import Path
from shutil import rmtree, which
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List, Optional, Tuple, Set # Added Set, Tuple
# Third-party imports
import pandas as pd
from pybedtools import BedTool
from Bio import SeqIO, SeqRecord # For generate_fasta
from rich.align import Align # For display_segment_by_depth
from rich.console import Console, Group # For parallelize_samples progress
from rich.live import Live # For parallelize_samples progress
from rich.panel import Panel # For parallelize_samples progress
from rich.progress import ( # For parallelize_samples progress
BarColumn,
Progress,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table # For display methods
from rich import box
from natsort import natsorted, natsort_keygen
from pybedtools import cleanup
# Local application/library specific imports
# Assuming these are correctly in .Graph module or same directory
from .Graph import GFA, GratoolsBam, SubGraph
from .__init__ import __version__, header_tool # Application version and header
from .logger_config import configure_logger, shared_console, update_logger_file_suffix
# Pandas display options, applied process-wide at import time.
# Passing None as the limit is understood to disable pandas' row/column
# truncation so DataFrames are printed in full (see pandas options docs).
pd.set_option("display.max_rows", None)
pd.set_option("display.max_columns", None)
[docs]
def flatten(list_of_lists: List[List[Any]]) -> List[Any]:
    """
    Concatenate the sublists of ``list_of_lists`` into one flat list.

    Only one level of nesting is removed; items that are themselves lists
    are kept as-is.

    Parameters
    ----------
    list_of_lists : List[List[Any]]
        A list whose elements are lists.

    Returns
    -------
    List[Any]
        A new list holding every item of every sublist, in order.
    """
    merged: List[Any] = []
    for chunk in list_of_lists:
        merged.extend(chunk)
    return merged
[docs]
@dataclass
class Gratools:
    """
    Main class for the GraTools toolkit, orchestrating GFA file processing,
    subgraph extraction, and various analyses on genomic graph data.

    It handles GFA importing (delegating to the GFA class), manages input
    parameters, and provides an interface for operations like subgraph
    extraction, FASTA generation, and statistical analysis of graph components.

    Attributes
    ----------
    gfa_path : Path
        Path to the input GFA file.
    threads : int, optional
        Number of threads for parallelizable operations. Defaults to 1.
        Overridden by ``meta["num_threads"]`` when that key is present.
    outdir : Optional[Path], optional
        Base output directory. Results go to ``{outdir}/GraTools-output_{gfa_name}``;
        if None, the directory containing `gfa_path` is used as the base.
    logger : Optional[logging.Logger]
        Logger instance. Auto-configured in `__post_init__`.
    gfa_name : Optional[str]
        Name of the GFA file, derived from `gfa_path` without extensions. Auto-initialized.
    bam_segments_file : Optional[Path]
        Path to the BAM file containing GFA segments, located within the import directory. Auto-initialized.
    dict_samples_chrom : defaultdict[str, OrderedDict[str, List[Tuple[str, str]]]]
        Maps sample names to an OrderedDict of chromosome names, which maps to a list of
        (start_fragment, stop_fragment) string tuples. Populated from `samples_chrom.txt`.
    works_path : Optional[Path]
        Path to the main GraTools output directory for the current run
        (``GraTools-output_{gfa_name}`` under the base output directory). Auto-initialized.
    bed_path : Optional[Path]
        Path to the BED files subdirectory within the GFA import directory. Auto-initialized.
    bam_path : Optional[Path]
        Path to the BAM files subdirectory within the GFA import directory. Auto-initialized.
    samples_chrom_path : Optional[Path]
        Path to the `samples_chrom.txt` file within the GFA import directory. Auto-initialized.
    dict_gfa_graph_object : Dict[str, SubGraph]
        Dictionary mapping sample names to their corresponding `SubGraph` objects
        after extraction. Defaults to an empty dict.
    sample_name_query : Optional[str]
        Name of the primary sample for query operations (e.g., subgraph extraction). Defaults to None.
    chromosome_query : Optional[str]
        Chromosome identifier for query operations. Defaults to None.
    start_query : int
        Start position for query operations (0-based). Defaults to 0.
    stop_query : Optional[int]
        Stop position for query operations. If None and a query chromosome is set,
        it is defaulted to the chromosome size during validation. Defaults to None.
    suffix : Optional[str]
        Custom suffix for output files. If None, a default suffix based on query
        parameters is generated. Auto-initialized.
    build_fasta_flag : bool
        Flag to enable FASTA file generation during subgraph extraction. Defaults to False.
    gzip_gfa : bool
        Flag indicating if the input GFA file is gzipped. Auto-detected (not an init argument).
    merge : Optional[int]
        Merge distance (`-d` for `bedtools merge`) for BED region processing.
        If -1 and a query region is set, it is replaced by 100% of the query
        region size; if -1 without a full query region, by 0. Defaults to None.
    meta : Dict[str, Any]
        Dictionary for meta-parameters like verbosity, log path, threads, passed
        from CLI or config. Defaults to an empty dict.
    import_links : bool
        Flag passed to the GFA class to control whether GFA links are imported
        during GFA parsing. Defaults to False.
    debug : bool
        Flag to enable debug mode (e.g. tracebacks in error logs). Defaults to False;
        switched on when the configured verbosity is DEBUG.
    disable_progress_flag : bool
        Flag to control progress bar visibility. Defaults to False.
    import_path : Optional[Path]
        Path to the GFA import directory (`{gfa_name}_GraTools-IMPORT`), created
        next to the GFA file. Auto-initialized.
    header_gfa_file : Optional[Path]
        Path to the saved GFA header file within the import directory. Auto-initialized.
    stats_gfa_file : Optional[Path]
        Path to the saved GFA statistics file within the import directory. Auto-initialized.
    sub_graph_query : Optional[SubGraph]
        SubGraph object for the primary query sample. Initialized in `extract_sub_graph`.
    _cached_chromosome_data : Optional[pd.DataFrame]
        Internal cache for data read from `samples_chrom_path` to avoid redundant parsing.
    """
    gfa_path: Path
    threads: int = 1
    outdir: Optional[Path] = None
    logger: Optional[logging.Logger] = None
    gfa_name: Optional[str] = None
    bam_segments_file: Optional[Path] = None
    dict_samples_chrom: defaultdict = field(default_factory=lambda: defaultdict(OrderedDict), repr=False)
    works_path: Optional[Path] = None
    bed_path: Optional[Path] = None
    bam_path: Optional[Path] = None
    samples_chrom_path: Optional[Path] = None
    dict_gfa_graph_object: Dict[str, SubGraph] = field(default_factory=dict, repr=False)
    sample_name_query: Optional[str] = None
    chromosome_query: Optional[str] = None
    start_query: int = 0  # Default to 0, assuming 0-based start
    stop_query: Optional[int] = None
    suffix: Optional[str] = None
    build_fasta_flag: bool = False
    gzip_gfa: bool = field(init=False, default=False)  # Auto-detected from the file name
    merge: Optional[int] = None
    meta: Dict[str, Any] = field(default_factory=dict, repr=False)
    import_links: bool = False  # Passed to GFA class
    debug: bool = False
    disable_progress_flag: bool = False
    # Attributes initialized in __post_init__ or later methods (not accepted by __init__)
    import_path: Optional[Path] = field(init=False, default=None)
    header_gfa_file: Optional[Path] = field(init=False, default=None)
    stats_gfa_file: Optional[Path] = field(init=False, default=None)
    sub_graph_query: Optional[SubGraph] = field(init=False, default=None, repr=False)
    _cached_chromosome_data: Optional[pd.DataFrame] = field(init=False, default=None, repr=False)
def __post_init__(self):
"""
Initializes paths, directories, logging, and performs essential checks and setup.
This method is automatically called after the dataclass is initialized.
"""
self._derive_gfa_name_and_type()
self._setup_paths()
self._configure_logging() # Configure logger early
self._log_initial_info()
self.disable_progress_flag = self.meta.get("disable_progress_flag", False)
self.logger.info("Progress bar is " + ("disabled" if self.disable_progress_flag else "enabled") + f".")
self._check_dependencies_and_import() # This might call GFA() which uses logger
self._read_samples_chrom_data()
self._validate_query_parameters()
self._validate_bed_files_exist()
self._process_merge_parameter()
self._finalize_suffix()
if self.suffix: # Suffix might have changed, update logger if needed
# Ensure logger is not None before updating
if self.logger:
self.logger = update_logger_file_suffix(self.logger, self.suffix)
self.logger.info(f"Output file suffix set to: '{self.suffix}'")
else: # Should not happen if _configure_logging was successful
print("CRITICAL: Logger not initialized before suffix finalization.", file=sys.stderr)
def _derive_gfa_name_and_type(self) -> None:
"""Derives GFA name and detects if it's gzipped from gfa_path."""
if not self.gfa_path.is_file():
# Logger might not be configured yet if this fails very early.
# Using basic print to stderr for critical early failures.
print(f"ERROR: Input GFA file not found at specified path: {self.gfa_path}", file=sys.stderr)
raise FileNotFoundError(f"GFA file not found: {self.gfa_path}")
name = self.gfa_path.name
if name.endswith(".gfa.gz"):
self.gfa_name = name.removesuffix(".gfa.gz")
self.gzip_gfa = True
elif name.endswith(".gfa"):
self.gfa_name = name.removesuffix(".gfa")
self.gzip_gfa = False
else:
# Fallback for non-standard GFA extensions
self.gfa_name = Path(name).stem # e.g., "my.gfa" from "my.gfa.tar.gz" -> "my.gfa.tar" -> "my"
# For "file.gfa.gz", Path(name).stem is "file.gfa"
# To get "file" from "file.gfa.gz": Path(Path(name).stem).stem
self.gzip_gfa = ".gz" in self.gfa_path.suffixes
# Early warning if logger isn't ready
print(
f"WARNING: GFA file '{name}' has non-standard extension. Inferred gfa_name='{self.gfa_name}', gzip={self.gzip_gfa}.",
file=sys.stderr)
def _setup_paths(self) -> None:
"""Sets up working directory paths based on gfa_name and outdir."""
if not self.gfa_name:
print("CRITICAL ERROR: gfa_name not derived prior to path setup. Aborting.", file=sys.stderr)
raise ValueError("gfa_name is essential for path setup and was not derived.")
# Determine base output directory
base_outdir = Path(self.outdir) if self.outdir else self.gfa_path.parent
self.works_path = base_outdir / f"GraTools-output_{self.gfa_name}"
try:
self.works_path.mkdir(parents=True, exist_ok=True)
except OSError as e:
print(f"CRITICAL ERROR: Failed to create working directory {self.works_path}: {e}", file=sys.stderr)
raise
# Paths related to the GFA import (potentially shared across runs)
self.import_path = self.gfa_path.parent / f"{self.gfa_name}_GraTools-IMPORT"
self.bed_path = self.import_path / "bed_files"
self.bam_path = self.import_path / "bam_files"
# Specific import files (derived from import_path and gfa_name)
self.bam_segments_file = self.bam_path / f"{self.gfa_name}.bam"
self.header_gfa_file = self.import_path / f"header_{self.gfa_name}.txt"
self.stats_gfa_file = self.import_path / f"stats_{self.gfa_name}.txt"
self.samples_chrom_path = self.import_path / "samples_chrom.txt"
def _configure_logging(self) -> None:
    """
    Build the logger used by this instance and store it in `self.logger`.

    The log directory is ``meta["log_file_directory"]`` or, failing that,
    `works_path`. The verbosity comes from ``meta["log_verbosity_level"]``
    (default ``"INFO"``); a ``DEBUG`` verbosity also sets `self.debug`.
    If `configure_logger` returns a falsy value, a basic fallback logger
    is installed instead.
    """
    settings = self.meta
    log_dir = settings.get("log_file_directory") or self.works_path
    if not log_dir:  # Should not happen once works_path is set
        print("CRITICAL: Log path directory could not be determined.", file=sys.stderr)

    # Make sure the directory exists; a failure here is reported but not fatal.
    try:
        Path(log_dir).mkdir(parents=True, exist_ok=True)
    except OSError as e:
        print(f"WARNING: Failed to create log directory {log_dir}: {e}. Logs might not be saved.",
              file=sys.stderr)

    level_name = settings.get("log_verbosity_level", "INFO").upper()
    if level_name == "DEBUG":
        self.debug = True  # Other parts of the code read this flag

    # Suffix of the log file name: meta override first, then the current suffix.
    own_suffix = self.suffix if self.suffix else ""
    file_suffix = settings.get("log_suffix_base", own_suffix)

    self.logger = configure_logger("GraTools", Path(log_dir), level_name, file_suffix)
    if self.logger:
        return

    # configure_logger gave nothing back: fall back to plain stdlib logging.
    print("CRITICAL: Logger configuration failed.", file=sys.stderr)
    logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)-7s | %(message)s')
    self.logger = logging.getLogger("GraTools_fallback")
def _log_initial_info(self) -> None:
    """
    Print the tool header and log the start-up context.

    Besides logging, this method also sets `self.threads` from
    ``meta["num_threads"]`` when that key is present. It requires
    `self.logger` to be configured already.
    """
    shared_console.print(f"\n{header_tool}\n")  # Header goes to the terminal, not the log
    emit = self.logger.info
    emit(f"Starting GraTools v{__version__}")
    emit(f"Command line: {' '.join(sys.argv)}")
    verbosity_label = self.meta.get('log_verbosity_level', 'INFO').upper()
    emit(f"Effective verbosity: '{verbosity_label}'")

    # Effective thread count: meta wins, otherwise keep the current value.
    self.threads = self.meta.get("num_threads", self.threads)
    log_file_dir = self.meta.get("log_file_directory") or self.works_path
    for line in (
        f"Number of threads: {self.threads}",
        f"Input GFA file: '{self.gfa_path}'",
        f"Derived GFA Name: '{self.gfa_name}' (Gzipped: {self.gzip_gfa})",
        f"Main output directory (works_path): {self.works_path}",
        f"GFA import directory (import_path): {self.import_path}",
        f"Log file directory: {log_file_dir}",
    ):
        emit(line)
def _check_dependencies_and_import(self) -> None:
"""Checks for external dependencies (bedtools) and ensures GFA import files exist."""
self._check_bedtools_availability()
self._ensure_gratools_import_files() # This might trigger GFA parsing
def _read_samples_chrom_data(self) -> None:
"""Reads sample-chromosome data from the `samples_chrom.txt` file."""
if not self.samples_chrom_path or not self.samples_chrom_path.exists():
self.logger.error(
f"Samples chromosome file '{self.samples_chrom_path}' not found. "
"This file is essential for many operations and is created during GFA importing."
)
# Consider raising an error if this file is critical for subsequent operations.
return
self.logger.debug(f"Reading samples and chromosomes from '{self.samples_chrom_path.name}'.")
try:
with open(self.samples_chrom_path, "r", encoding="utf-8") as file:
for line_num, line_content in enumerate(file, 1):
parts = line_content.strip().split("\t")
if len(parts) == 4: # Expected: sample, chrom, fragment_start, fragment_stop
sample, chrom, frag_start, frag_stop = parts
# Ensure sub-dictionaries are OrderedDicts if order matters
if chrom not in self.dict_samples_chrom[sample]:
self.dict_samples_chrom[sample][chrom] = [] # Initialize list for this chrom
self.dict_samples_chrom[sample][chrom].append((frag_start, frag_stop))
else:
self.logger.warning(
f"Malformed line #{line_num} in '{self.samples_chrom_path.name}': "
f"Expected 4 fields, got {len(parts)}. Line: '{line_content.strip()}'"
)
except IOError as e:
self.logger.error(f"Error reading samples_chrom file '{self.samples_chrom_path}': {e}", exc_info=self.debug)
except Exception as e: # Catch other potential errors during parsing
self.logger.error(f"Unexpected error parsing '{self.samples_chrom_path}': {e}", exc_info=self.debug)
if not self.dict_samples_chrom:
self.logger.warning(
f"No sample-chromosome data loaded from '{self.samples_chrom_path.name}'. File might be empty or incorrectly formatted.")
def _validate_query_parameters(self) -> None:
"""Validates query-specific parameters against loaded GFA data."""
if not self.sample_name_query:
# No query sample specified, so no further validation needed for query params.
return
self._check_sample_in_gfa(self.sample_name_query)
if not self.chromosome_query:
# No query chromosome specified, cannot validate start/stop.
return
self._check_chromosome_in_gfa(self.sample_name_query, self.chromosome_query)
chrom_size = self.get_chromosome_size(self.sample_name_query, self.chromosome_query)
if chrom_size is None: # Chromosome not found or size indeterminate
self.logger.error(
f"Could not determine size for query chromosome '{self.chromosome_query}' in sample '{self.sample_name_query}'. "
"Stop query validation might be unreliable or skipped."
)
return # Cannot validate stop_query
if self.stop_query is None: # If stop_query not provided, default to chromosome size
self.logger.info(
f"Query stop position not specified. Defaulting to end of chromosome '{self.chromosome_query}' ({chrom_size}bp).")
self.stop_query = chrom_size
elif self.stop_query > chrom_size:
self.logger.warning(
f"Query stop position ({self.stop_query}) exceeds chromosome '{self.chromosome_query}' size ({chrom_size}). "
f"Adjusting stop_query to {chrom_size}."
)
self.stop_query = chrom_size
# Validate start vs stop (start must be less than stop for a valid region)
if self.start_query >= self.stop_query:
self.logger.error(
f"Invalid query range: start position ({self.start_query}) must be less than stop position ({self.stop_query}) "
f"for chromosome '{self.chromosome_query}'. Please correct query parameters."
)
def _validate_bed_files_exist(self) -> None:
"""Validates the presence of BED files for each sample found in the GFA import."""
if not self.bed_path or not self.bed_path.exists():
self.logger.warning(f"BED files directory '{self.bed_path}' does not exist. Cannot validate BED files.")
return
self.logger.debug(f"Checking for existence of BED files in '{self.bed_path}' for importing samples.")
for sample_name in self.dict_samples_chrom.keys(): # Samples derived from GFA import
bed_file_path = self.bed_path / f"{sample_name}.bed"
if not bed_file_path.exists():
self.logger.warning(
f"BED file for sample '{sample_name}' not found at '{bed_file_path}'. "
"This may cause issues if sample-specific operations are performed."
)
elif bed_file_path.stat().st_size == 0:
self.logger.warning(
f"BED file for sample '{sample_name}' at '{bed_file_path}' is empty."
)
def _process_merge_parameter(self) -> None:
"""
Processes the `self.merge` parameter. If it's -1 (auto-calculate) and a query region
is defined, it calculates merge as 10% of the query region size.
Logs the effective merge value.
"""
log_message = None
if self.merge == -1: # Sentinel for auto-calculation
if self.sample_name_query and self.chromosome_query and \
self.start_query is not None and self.stop_query is not None and \
self.stop_query > self.start_query:
query_size = self.stop_query - self.start_query
self.merge = max(0, int((query_size * 100) / 100)) # 100% of query size, ensure non-negative
log_message = (
f"Auto-calculated merge distance for `bedtools merge -d`: {self.merge} bp "
f"(100% of query region size {query_size}bp)."
)
else:
# Cannot auto-calculate, reset merge to a safe default (e.g., 0 or None if bedtools handles it)
# If merge=0 is equivalent to no -d, that's fine. Otherwise, set to None.
# Bedtools merge -d 0 usually means only abutting features.
self.merge = 0
log_message = (
"Merge distance auto-calculation (-1) requested, but query region is not fully defined. "
f"Defaulting merge distance to {self.merge} bp (abutting features only)."
)
elif self.merge is not None and self.merge < 0:
self.logger.error(f"Invalid negative merge distance provided ({self.merge}bp).")
elif self.merge is not None: # User provided a valid merge distance
log_message = f"User-defined merge distance for `bedtools merge -d`: {self.merge} bp."
# If self.merge is None, no specific logging needed here unless it's an issue for downstream.
if log_message and self.logger: # Ensure logger is initialized
self.logger.info(log_message)
def _finalize_suffix(self) -> None:
"""
Finalizes the output file suffix (`self.suffix`).
If not provided, a default suffix is generated based on query parameters and merge distance.
Ensures the suffix starts with an underscore if it's not empty.
"""
if not self.suffix: # If suffix is None or empty string
if self.sample_name_query and self.chromosome_query and \
self.start_query is not None and self.stop_query is not None:
# Construct suffix from query parameters
query_part = f"{self.sample_name_query}-{self.chromosome_query}-{self.start_query}-{self.stop_query}"
merge_part = f"d{self.merge}" if self.merge is not None else "d_auto" # Or d0 if merge is 0
self.suffix = f"_{query_part}-{merge_part}"
else:
# No query, or incomplete query, use a generic suffix or just based on merge if set
merge_part = f"_d{self.merge}" if self.merge is not None else ""
self.suffix = f"{merge_part}" # Or simply "" if no default needed
# Ensure suffix starts with an underscore if it's not already empty
if self.suffix and not self.suffix.startswith("_"):
self.suffix = f"_{self.suffix}"
# If suffix ended up being just "_" from an empty string initially, make it empty.
if self.suffix == "_":
self.suffix = ""
# Logging of the final suffix is handled in __post_init__ after this call.
def _run_with_shell_wrapper(self) -> None:
"""Helper method to run BEDtools using a shell, specifically for wrapper scripts."""
try:
bedtools_path = which("bedtools")
if not bedtools_path:
self.logger.critical("BEDtools wrapper found but path is not valid.")
raise EnvironmentError("BEDtools not functional.")
process = subprocess.run(
f"{bedtools_path} --version",
shell=True,
check=True,
capture_output=True,
text=True
)
self.logger.info(f"BEDtools found via shell wrapper: {process.stdout.strip()}")
except Exception as inner_e:
self.logger.critical(f"BEDtools test failed with shell=True: {inner_e}")
raise EnvironmentError("BEDtools not found or not functional.")
def _check_bedtools_availability(self) -> None:
"""
Checks if the BEDtools command-line interface is installed and accessible.
It first attempts a direct execution, and if that fails with an 'Exec format error',
it falls back to a shell-based execution, which is common for wrapper scripts in HPC environments.
"""
self.logger.debug("Checking for BEDtools availability.")
# Attempt 1: Standard, secure execution (shell=False)
try:
process = subprocess.run(
["bedtools","--version"],
shell=False,
check=True,
capture_output=True,
text=True
)
self.logger.info(f"BEDtools found: {process.stdout.strip()}")
except FileNotFoundError:
self.logger.critical("BEDtools command not found. Please install it and ensure it's in your PATH.")
except OSError as e:
if e.errno == 8: # 8 is the error code for 'Exec format error'
self.logger.warning(
"OSError: [Errno 8] Exec format error. Retrying with shell=True for wrapper script.")
self._run_with_shell_wrapper()
else:
self.logger.critical(f"An unexpected OSError occurred: {e}")
except subprocess.CalledProcessError as e:
self.logger.critical(
f"BEDtools command failed (exit code {e.returncode}): {e.stderr.strip()}"
)
except Exception as e:
self.logger.critical(f"Unexpected error while checking for BEDtools: {e}", exc_info=True)
def _ensure_gratools_import_files(self) -> None:
"""
Ensures that GraTools import files exist. If not, triggers GFA parsing to create them.
Also validates existing import files for completeness.
"""
self.logger.debug(f"Checking GraTools import directory: '{self.import_path}'")
# List of essential import files and their paths
# Using a dictionary for easier checking and error messaging
required_files = {
"BAM Segments File": self.bam_segments_file,
"GFA Header File": self.header_gfa_file,
"GFA Stats File": self.stats_gfa_file,
"Samples-Chromosomes Map": self.samples_chrom_path,
}
required_dirs = {
"BED Files Directory": self.bed_path,
"BAM Files Directory": self.bam_path, # Though bam_segments_file is primary
}
# Check if the main import directory exists. If not, parsing is needed.
if not self.import_path.exists():
self.logger.info(
f"GraTools import directory '{self.import_path.name}' not found. "
f"This seems to be the first run for GFA '{self.gfa_path.name}'. "
"Proceeding with GFA parsing to create import."
)
self._run_gfa_parsing_and_importing()
return # After parsing, assume files are created. Re-check can be added if needed.
# If import_path exists, validate its contents
missing_or_empty_files = []
for name, path in required_files.items():
if not path or not path.exists() or path.stat().st_size == 0:
missing_or_empty_files.append(f"{name} ('{path}')")
for name, path in required_dirs.items():
if not path or not path.exists() or not path.is_dir():
missing_or_empty_files.append(f"{name} ('{path}') - not a directory or missing")
if missing_or_empty_files:
files_txt = '\n- '.join(missing_or_empty_files)
self.logger.warning(
f"GraTools import directory '{self.import_path.name}' is incomplete. Missing/empty files/dirs:\n- "
f"{files_txt}. "
"\nAttempting to regenerate by re-parsing GFA."
)
try:
self.logger.info(f"Cleaning up potentially corrupted import directory: '{self.import_path}'")
rmtree(self.import_path) # Remove corrupted import to ensure clean parse
except OSError as e:
self.logger.error(f"Failed to clean up corrupted import directory '{self.import_path}': {e}. "
"Manual cleanup might be required.", exc_info=self.debug)
raise # Re-raise as this is a critical state
self._run_gfa_parsing_and_importing()
else:
self.logger.info(f"GraTools import files found and appear complete in '{self.import_path.name}'.")
def _run_gfa_parsing_and_importing(self) -> None:
    """
    Parse the GFA and build the import directory by instantiating `GFA`.

    Constructing the `GFA` object is what performs the work; the instance
    itself is only held until this method returns. On any failure the
    (possibly partial) import directory is removed, when it exists, and the
    original exception is re-raised.
    """
    gfa_options = dict(
        gfa_path=self.gfa_path,
        threads=self.threads,  # Effective thread count
        import_links=self.import_links,  # Controls link importing
        logger=self.logger,  # Share the configured logger
        disable_progress_flag=self.disable_progress_flag,  # Progress bar control
    )
    try:
        gfa_processor = GFA(**gfa_options)
        self.logger.info(f"GFA parsing and importing completed successfully. Import at: '{self.import_path}'")
    except Exception as e:
        self.logger.error(
            f"Error during GFA parsing and importing for '{self.gfa_path.name}': {e}\n"
            "Attempting to clean up any partially created GraTools import files...",
            exc_info=self.debug
        )
        leftovers = self.import_path
        if leftovers and leftovers.exists():
            try:
                rmtree(leftovers)
                self.logger.info(f"Successfully cleaned up import directory: '{self.import_path}'")
            except OSError as cleanup_err:
                self.logger.error(
                    f"Failed to auto-clean import directory '{self.import_path}' after error: {cleanup_err}. "
                    "Manual cleanup may be required.", exc_info=self.debug
                )
        raise  # Propagate the original parsing error
# Methods for checking sample/chromosome presence (helper methods)
def _check_sample_in_gfa(self, sample_name: str) -> None:
"""Checks if a sample name exists in the loaded GFA data (dict_samples_chrom)."""
if sample_name not in self.dict_samples_chrom:
available_samples_preview = list(self.dict_samples_chrom.keys())[:20]
preview_str = ", ".join(available_samples_preview) + ("..." if len(self.dict_samples_chrom) > 20 else "")
self.logger.error(
f"Sample '{sample_name}' not found in GFA data (derived from '{self.gfa_path.name}'). "
f"Available samples start with: {preview_str if preview_str else 'None found'}. "
"Ensure sample name is correct and GFA import is up-to-date."
)
raise ValueError(f"Sample '{sample_name}' not found in GFA data.")
def _check_chromosome_in_gfa(self, sample_name: str, chromosome_name: str) -> None:
"""Checks if a chromosome exists for a given sample in the loaded GFA data."""
# Assumes _check_sample_in_gfa was called or sample_name is known to exist.
if chromosome_name not in self.dict_samples_chrom.get(sample_name, {}):
available_chroms_preview = list(self.dict_samples_chrom.get(sample_name, {}).keys())[:20]
preview_str = ", ".join(available_chroms_preview) + (
"..." if len(self.dict_samples_chrom.get(sample_name, {})) > 20 else "")
self.logger.error(
f"Chromosome '{chromosome_name}' not found for sample '{sample_name}' in GFA data. "
f"Available chromosomes for this sample start with: {preview_str if preview_str else 'None found'}."
)
raise ValueError(f"Chromosome '{chromosome_name}' not found for sample '{sample_name}'.")
def _load_samples_file(self, samples_file_path: Path) -> List[str]:
"""
Reads a list of sample names from a file, one sample name per line.
Validates each sample against GFA data.
Returns:
List[str]: List of valid sample names from the file.
"""
if not samples_file_path.is_file():
self.logger.error(f"Samples file not found: {samples_file_path}")
raise FileNotFoundError(f"Samples file not found: {samples_file_path}")
valid_samples_from_file = []
try:
with open(samples_file_path, "r", encoding="utf-8") as f:
for line_num, line_content in enumerate(f, 1):
sample_name_in_file = line_content.strip()
if not sample_name_in_file: continue # Skip empty lines
try:
self._check_sample_in_gfa(sample_name_in_file) # Validates against GFA data
valid_samples_from_file.append(sample_name_in_file)
except ValueError: # Raised by _check_sample_in_gfa if not found
self.logger.warning(
f"Sample '{sample_name_in_file}' from file '{samples_file_path.name}' (line {line_num}) "
"not found in GFA data. It will be ignored.")
except IOError as e:
self.logger.error(f"Error reading samples file '{samples_file_path}': {e}", exc_info=self.debug)
if not valid_samples_from_file:
self.logger.warning(f"No valid samples loaded from file '{samples_file_path.name}'.")
return valid_samples_from_file
def _get_samples_for_processing(self, samples_list_path: Optional[Path], all_samples_flag: bool) -> List[str]:
"""
Determines the list of samples to process based on input flags.
- If `samples_list_path` is provided, loads samples from that file.
- Else if `all_samples_flag` is True, returns all samples from GFA (excluding query_sample_name).
- Else (neither provided), returns an empty list.
"""
samples_to_process: List[str] = []
if samples_list_path:
self.logger.info(f"Loading samples to process from file: '{samples_list_path.name}'")
samples_to_process = self._load_samples_file(samples_list_path)
elif all_samples_flag:
self.logger.info("Processing all available samples from the GFA (excluding query sample if set).")
samples_to_process = list(self.dict_samples_chrom.keys())
if self.sample_name_query and self.sample_name_query in samples_to_process:
# Typically, the query sample is handled separately or first.
# If "all_samples" means "all *other* samples", then remove query sample.
samples_to_process.remove(self.sample_name_query)
self.logger.debug(f"Query sample '{self.sample_name_query}' included in 'all_samples' list.")
else:
self.logger.info("No specific samples file provided and 'all_samples' flag is false. "
"Only the query sample (if set) will be processed in primary operations.")
# samples_to_process remains empty.
return samples_to_process
# Public interface / GFA data access methods
[docs]
def get_gfa_statistics_df(self) -> Optional[pd.DataFrame]:
    """
    Read the pre-computed GFA statistics file (tab-separated, with a
    header row) into a DataFrame.

    Returns:
        Optional[pd.DataFrame]: The statistics table, or None when the file
        is missing, empty, or cannot be parsed (the reason is logged).
    """
    stats_file = self.stats_gfa_file
    if not (stats_file and stats_file.exists()):
        self.logger.error(f"GFA statistics file '{self.stats_gfa_file}' not found. Run importing first.")
        return None

    try:
        stats_table = pd.read_csv(stats_file, sep="\t", header=0)
        self.logger.debug(f"Successfully loaded GFA statistics from '{self.stats_gfa_file.name}'.")
    except pd.errors.EmptyDataError:
        self.logger.warning(f"GFA statistics file '{self.stats_gfa_file.name}' is empty.")
        return None
    except Exception as e:  # Any other failure while reading/parsing
        self.logger.error(f"Error reading GFA statistics file '{self.stats_gfa_file}': {e}", exc_info=self.debug)
        return None
    return stats_table
[docs]
def save_gfa_statistics(self) -> None:  # Renamed
    """
    Save the GFA statistics table to `<gfa_name>_statistics.csv` in the working directory.

    The table is obtained from `get_gfa_statistics_df()` and written
    tab-separated (despite the `.csv` extension) without the index.
    Nothing is written, and a message is logged instead, when `works_path`
    is not set or when no statistics are available. Write failures are
    logged rather than raised.
    """
    if not self.works_path:
        # Message names the artefact actually handled here (GFA statistics).
        self.logger.error("Working directory (works_path) not set. Cannot save GFA statistics.")
        return
    df_stats = self.get_gfa_statistics_df()
    if df_stats is None or df_stats.empty:
        self.logger.info("No statistics to save.")
        return
    output_file = self.works_path / f"{self.gfa_name or 'gfa'}_statistics.csv"
    self.logger.info(f"Saving statistics to: {output_file}")
    try:
        df_stats.to_csv(output_file, index=False, sep="\t")
    except OSError as e:  # IOError is an alias of OSError in Python 3
        self.logger.error(f"Failed to save statistics CSV to '{output_file}': {e}", exc_info=self.debug)
[docs]
def display_gfa_statistics(self, by_category: bool = False) -> None:
    """
    Print the GFA statistics to the shared console as Rich tables.

    A short summary (GFA name and number of samples) is printed first, then
    either one table per distinct value of the "Category" column or a single
    table holding every row.

    Parameters:
        by_category (bool, optional): If True, print one table per category.
                                      If False, print one comprehensive table.
                                      Defaults to False.
    """
    stats = self.get_gfa_statistics_df()
    if stats is None or stats.empty:
        self.logger.warning("No GFA statistics data to display.")
        shared_console.print("[yellow]No GFA statistics data available to display.[/yellow]")
        return

    sample_count = len(self.available_sample_names)
    gfa_label = self.gfa_name or 'N/A'

    # Summary header
    shared_console.rule("[bold green]Summary[/bold green]")
    shared_console.print(f"[cyan]GFA File:[/cyan] {gfa_label}")
    shared_console.print(f"[cyan]Total samples:[/cyan] {sample_count:,}")
    shared_console.print(f"\n--- GFA Statistics for: {gfa_label} ---")

    def _new_table(title: str) -> Table:
        # Every table shares the same frame and header styling.
        return Table(title=title, box=box.ROUNDED, show_header=True, header_style="bold magenta")

    if not by_category:
        # One table containing every statistic
        full_table = _new_table("Comprehensive GFA Statistics")
        full_table.add_column("Category", justify="left", style="blue", no_wrap=True)
        full_table.add_column("Metric", justify="left", style="cyan", no_wrap=True)
        full_table.add_column("Value", justify="left", style="green")
        for _, record in stats.iterrows():
            full_table.add_row(str(record["Category"]), str(record["Metric"]), str(record["Value"]))
        shared_console.print(full_table)
    else:
        # One table per category, in order of first appearance
        for category_name in stats["Category"].unique():
            rows_for_category = stats[stats["Category"] == category_name]
            category_table = _new_table(f"{category_name} Statistics")
            category_table.add_column("Metric", justify="left", style="cyan", no_wrap=True)
            category_table.add_column("Value", justify="left", style="green")
            for _, record in rows_for_category.iterrows():
                category_table.add_row(str(record["Metric"]), str(record["Value"]))
            shared_console.print(category_table)

    shared_console.print("-----------------------------------------------------\n")
def _get_raw_chromosome_data_df(self) -> Optional[pd.DataFrame]:
    """
    Load the per-sample chromosome fragment table, with caching.

    The header-less, tab-separated file at `self.samples_chrom_path` is read,
    fragment coordinates are coerced to numbers (rows whose coordinates cannot
    be converted are dropped), duplicate rows are removed, and the rows are
    ordered by natural sort of sample and chromosome, then by fragment start.
    The result is stored in `self._cached_chromosome_data` and reused by later calls.

    Returns:
        Optional[pd.DataFrame]: DataFrame with columns
            ["SAMPLES", "CHROMOSOMES_LIST", "START_FRAG", "END_FRAG"],
            or None if the file is missing, empty, or unreadable (logged).
    """
    cached = self._cached_chromosome_data
    if cached is not None:
        return cached

    map_path = self.samples_chrom_path
    if not (map_path and map_path.exists()):
        self.logger.error(f"Samples-chromosomes map file '{map_path}' not found.")
        return None

    try:
        fragments = pd.read_csv(
            map_path,
            sep="\t",
            header=None,  # The file carries no header line
            names=["SAMPLES", "CHROMOSOMES_LIST", "START_FRAG", "END_FRAG"],
        )
        # Non-numeric coordinates become NaN and are discarded just below
        for coord_col in ("START_FRAG", "END_FRAG"):
            fragments[coord_col] = pd.to_numeric(fragments[coord_col], errors='coerce')
        fragments = fragments.dropna(subset=["START_FRAG", "END_FRAG"]).drop_duplicates()

        # Temporary natural-sort key columns drive the ordering and are removed afterwards
        natural_key = natsort_keygen()
        fragments = (
            fragments.assign(
                SAMPLES_SORT=fragments['SAMPLES'].apply(natural_key),
                CHROMOSOMES_SORT=fragments['CHROMOSOMES_LIST'].apply(natural_key),
            )
            .sort_values(by=['SAMPLES_SORT', 'CHROMOSOMES_SORT', 'START_FRAG'], kind='mergesort')
            .drop(columns=['SAMPLES_SORT', 'CHROMOSOMES_SORT'])
            .reset_index(drop=True)
        )
        self._cached_chromosome_data = fragments
        return fragments
    except pd.errors.EmptyDataError:
        self.logger.error(f"Samples-chromosomes map file '{self.samples_chrom_path.name}' is empty.")
        return None
    except Exception as exc:
        self.logger.error(f"Error reading samples-chromosomes map file '{self.samples_chrom_path}': {exc}",
                          exc_info=self.debug)
        return None
[docs]
def get_chromosome_size(self, sample_name: str, chromosome_name: str) -> Optional[int]:
    """
    Return the largest fragment end position recorded for a chromosome of a sample.

    The value is the maximum of the END_FRAG column over all rows of the raw
    chromosome table that match both the sample and the chromosome.

    Parameters:
        sample_name (str): The name of the sample.
        chromosome_name (str): The name of the chromosome.
    Returns:
        Optional[int]: The maximum end position, or None when no chromosome data
            is loaded, no row matches, or the maximum is not a valid number.
    """
    fragments = self._get_raw_chromosome_data_df()
    if fragments is None or fragments.empty:
        self.logger.warning(
            f"Cannot get chromosome size for '{sample_name}/{chromosome_name}': No chromosome data loaded.")
        return None

    # Row masks for the requested sample and chromosome
    is_sample = fragments["SAMPLES"] == sample_name
    is_chromosome = fragments["CHROMOSOMES_LIST"] == chromosome_name
    end_positions = fragments.loc[is_sample & is_chromosome, "END_FRAG"]

    if end_positions.empty:
        self.logger.debug(f"No fragments found for chromosome '{chromosome_name}' in sample '{sample_name}'.")
        return None

    largest_end = end_positions.max()
    if pd.notna(largest_end):
        return int(largest_end)
    return None
@property
def available_sample_names(self) -> List[str]:
    """
    Naturally sorted sample names taken from the keys of `self.dict_samples_chrom`.

    Returns:
        List[str]: The sorted sample names, or an empty list (with a logged
            warning) when `dict_samples_chrom` is empty or unset.
    """
    sample_map = self.dict_samples_chrom
    if sample_map:
        return natsorted(sample_map.keys())
    self.logger.warning("dict_samples_chrom is empty. Cannot retrieve sample names.")
    return []
[docs]
def display_available_sample_names(self) -> None:
    """Print the available sample names to the shared console as a one-column Rich table."""
    names = self.available_sample_names
    if names:
        # Summary line followed by one table row per sample
        shared_console.rule("[bold green]Summary[/bold green]")
        shared_console.print(f"[cyan]Total samples in GFA:[/cyan] {len(names):,}")
        listing = Table(title=f"Available Samples in GFA: {self.gfa_name or 'N/A'}", box=box.ROUNDED)
        listing.add_column("Sample Name", style="cyan", no_wrap=True)
        for name in names:
            listing.add_row(name)
        shared_console.print(listing)
    else:
        shared_console.print("[yellow]No sample names available to display.[/yellow]")
[docs]
def save_available_sample_names(self) -> None:
    """
    Write the available sample names, one per line, to `<gfa_name>_samples.txt`.

    The file is created in `self.works_path` with UTF-8 encoding. An error is
    logged and nothing is written when `works_path` is not set or when there
    are no sample names. Write failures are logged rather than raised.
    """
    target_dir = self.works_path
    if not target_dir:
        self.logger.error("Working directory (works_path) not set. Cannot save sample names.")
        return

    destination = target_dir / f"{self.gfa_name or 'gfa'}_samples.txt"  # Plain text list, one name per line
    names = self.available_sample_names
    if not names:
        self.logger.error("No sample names to save.")
        return

    self.logger.info(f"Saving list of {len(names)} available sample names to: {destination}")
    try:
        with open(destination, "w", encoding="utf-8") as handle:
            handle.writelines(f"{name}\n" for name in names)
    except IOError as err:
        self.logger.error(f"Failed to save sample names list to '{destination}': {err}", exc_info=self.debug)
[docs]
def get_chromosomes_summary_by_sample_df(self) -> Optional[pd.DataFrame]:  # Renamed
    """
    Build a one-row-per-sample summary of the chromosomes seen for each sample.

    For every sample the distinct chromosome names are naturally sorted and
    joined with ", "; the number of names in that joined string is stored
    alongside.

    Returns:
        Optional[pd.DataFrame]: DataFrame with columns
            ["SAMPLES", "CHROMOSOMES_LIST", "NUM_UNIQUE_CHROMOSOMES"],
            or None if no raw data is available or the aggregation fails (logged).
    """
    raw = self._get_raw_chromosome_data_df()
    if raw is None or raw.empty:
        self.logger.warning("Cannot generate chromosome summary: No raw chromosome data available.")
        return None

    def _join_unique(names) -> str:
        # Distinct chromosome names of one sample, naturally sorted
        return ", ".join(natsorted(list(set(names))))

    def _count_joined(joined) -> int:
        # Count is derived from the joined string; an empty string means zero
        if joined:
            return len(joined.split(", "))
        return 0

    try:
        summary = raw.groupby("SAMPLES")["CHROMOSOMES_LIST"].apply(_join_unique).reset_index()
        summary["NUM_UNIQUE_CHROMOSOMES"] = summary["CHROMOSOMES_LIST"].apply(_count_joined)
        return summary
    except Exception as exc:
        self.logger.error(f"Error generating chromosomes summary by sample: {exc}", exc_info=self.debug)
        return None
[docs]
def save_chromosomes_summary_by_sample(self) -> None:
    """
    Write the per-sample chromosome summary to
    `<gfa_name>_chromosomes_summary_by_sample.csv` in the working directory.

    The table comes from `get_chromosomes_summary_by_sample_df()` and is
    written tab-separated without the index. Nothing is written when
    `works_path` is not set or when there is no summary; write failures are
    logged rather than raised.
    """
    if not self.works_path:
        self.logger.error("Working directory (works_path) not set. Cannot save chromosome summary.")
        return

    summary = self.get_chromosomes_summary_by_sample_df()
    nothing_to_write = summary is None or summary.empty
    if nothing_to_write:
        self.logger.info("No chromosome summary data to save.")
        return

    destination = self.works_path / f"{self.gfa_name or 'gfa'}_chromosomes_summary_by_sample.csv"
    self.logger.info(f"Saving chromosome summary by sample to: {destination}")
    try:
        summary.to_csv(destination, sep="\t", index=False)
    except IOError as err:
        self.logger.error(f"Failed to save chromosome summary CSV to '{destination}': {err}", exc_info=self.debug)
[docs]
def save_full_chromosome_fragment_data(self) -> None:  # Renamed
    """
    Save the raw chromosome fragment table (sample, chromosome, start, end) to
    `<gfa_name>_full_chromosome_fragment_data.csv` in the working directory.

    The file is written tab-separated (despite the `.csv` extension) without
    the index. Nothing is written, and a message is logged instead, when
    `works_path` is not set or when no fragment data is available. Write
    failures are logged rather than raised.
    """
    # Same guard as the other save_* methods: without it, `None / str` below
    # would raise an uncaught TypeError.
    if not self.works_path:
        self.logger.error("Working directory (works_path) not set. Cannot save full chromosome fragment data.")
        return
    df_full_data = self._get_raw_chromosome_data_df()  # Uses cache if available
    if df_full_data is None or df_full_data.empty:
        self.logger.info("No full chromosome fragment data to save.")
        return
    output_file = self.works_path / f"{self.gfa_name or 'gfa'}_full_chromosome_fragment_data.csv"
    self.logger.info(f"Saving full chromosome fragment data to: {output_file}")
    try:
        # Header labels are the column names produced by _get_raw_chromosome_data_df
        df_full_data.to_csv(
            output_file,
            index=False,
            sep="\t",
            header=["SAMPLES", "CHROMOSOMES_LIST", "START_FRAG", "END_FRAG"],
        )
    except OSError as e:  # IOError is an alias of OSError in Python 3
        self.logger.error(f"Failed to save full chromosome data CSV to '{output_file}': {e}", exc_info=self.debug)
[docs]
def display_chromosomes_summary(self) -> None:
    """
    Print the per-sample chromosome summary to the shared console.

    A short summary (sample count, smallest and largest number of unique
    chromosomes per sample) is followed by a Rich table with one row per sample.
    """
    summary = self.get_chromosomes_summary_by_sample_df()
    if summary is None or summary.empty:
        shared_console.print("[yellow]No chromosome summary data available to display.[/yellow]")
        return

    sample_total = len(self.available_sample_names)
    # The summary is known to be non-empty here, so min/max are always defined
    chromosome_counts = summary["NUM_UNIQUE_CHROMOSOMES"]
    fewest = chromosome_counts.min()
    most = chromosome_counts.max()

    # Summary
    shared_console.rule("[bold green]Summary[/bold green]")
    shared_console.print(f"[cyan]Total samples:[/cyan] {sample_total:,}")
    shared_console.print(f"[cyan]Min chromosomes per sample:[/cyan] {fewest:,}")
    shared_console.print(f"[cyan]Max chromosomes per sample:[/cyan] {most:,}")

    heading = (f"{sample_total} Samples, {fewest}-{most} Unique Chromosomes/Sample "
               f"(GFA: {self.gfa_name or 'N/A'})")
    overview = Table(title=heading, show_lines=True, box=box.ROUNDED, header_style="bold magenta")
    overview.add_column("Sample Name", style="cyan", no_wrap=True)
    overview.add_column("Unique Chromosomes (Comma-separated)", style="blue")
    overview.add_column("Number of Unique Chromosomes", justify="left", style="green")

    displayed_columns = ("SAMPLES", "CHROMOSOMES_LIST", "NUM_UNIQUE_CHROMOSOMES")
    for _, record in summary.iterrows():
        overview.add_row(*(str(record[column]) for column in displayed_columns))
    shared_console.print(overview)
[docs]
def display_full_chromosome_fragment_data(self) -> None:
    """
    Display every chromosome fragment in a Rich table, grouped by sample.

    A summary is printed first (number of samples, unique chromosomes and
    fragments, plus the minimum, maximum and average fragment length, where a
    fragment's length is its END_FRAG minus its START_FRAG). The table then
    lists one row per fragment; the sample name is shown only on the first
    row of each sample and a section separator is inserted between samples.
    """
    df_full_data = self._get_raw_chromosome_data_df()
    if df_full_data is None or df_full_data.empty:
        shared_console.print("[yellow]No full chromosome fragment data available to display.[/yellow]")
        return
    num_samples = len(self.available_sample_names)

    # Min/max chromosomes per sample are only used in the table title
    df_summary = self.get_chromosomes_summary_by_sample_df()
    if df_summary is not None and not df_summary.empty:
        min_chroms_title = df_summary["NUM_UNIQUE_CHROMOSOMES"].min()
        max_chroms_title = df_summary["NUM_UNIQUE_CHROMOSOMES"].max()
    else:
        min_chroms_title = max_chroms_title = "N/A"

    # Summary statistics. Lengths are computed per fragment first: subtracting
    # the column-wise minima (or maxima) of START_FRAG and END_FRAG mixes
    # coordinates of unrelated fragments and does not yield a fragment length.
    total_fragments = len(df_full_data)
    total_samples = len(df_full_data['SAMPLES'].unique())
    total_chromosomes = len(df_full_data['CHROMOSOMES_LIST'].unique())
    fragment_lengths = df_full_data['END_FRAG'] - df_full_data['START_FRAG']
    min_fragment_length = fragment_lengths.min()
    max_fragment_length = fragment_lengths.max()
    avg_fragment_length = fragment_lengths.mean()

    # Summary
    shared_console.rule("[bold green]Summary[/bold green]")
    shared_console.print(f"[cyan]Total samples:[/cyan] {total_samples:,}")
    shared_console.print(f"[cyan]Total unique chromosomes:[/cyan] {total_chromosomes:,}")
    shared_console.print(f"[cyan]Total fragments:[/cyan] {total_fragments:,}")
    shared_console.print(f"[cyan]Min fragment length:[/cyan] {min_fragment_length:,} bp")
    shared_console.print(f"[cyan]Max fragment length:[/cyan] {max_fragment_length:,} bp")
    shared_console.print(f"[cyan]Average fragment length:[/cyan] {avg_fragment_length:,.2f} bp")

    table = Table(
        title=(f"{num_samples} Samples, {min_chroms_title}-{max_chroms_title} Chromosomes/Sample "
               f"(GFA: {self.gfa_name or 'N/A'}) - Full Fragment List"),
        show_lines=False, box=box.ROUNDED, header_style="bold magenta"
    )
    table.add_column("Sample Name", style="cyan", no_wrap=True)
    table.add_column("Chromosome Fragment Name", style="blue")
    table.add_column("Fragment Start", justify="right", style="green")
    table.add_column("Fragment End", justify="right", style="green")

    # Sort by sample, then by chromosome name, so each sample's rows are contiguous.
    # NOTE(review): this is a plain lexicographic sort and replaces the natural
    # ordering applied by _get_raw_chromosome_data_df — confirm this is intended.
    df_full_data = df_full_data.sort_values(by=['SAMPLES', 'CHROMOSOMES_LIST'])

    last_sample_name = None  # Remembers the sample of the previously emitted row
    rows = zip(
        df_full_data['SAMPLES'],
        df_full_data['CHROMOSOMES_LIST'],
        df_full_data['START_FRAG'],
        df_full_data['END_FRAG'],
    )
    for current_sample_name, chromosome_name, start_frag, end_frag in rows:
        sample_name_display = ""
        # A change of sample starts a new group
        if current_sample_name != last_sample_name:
            # Close the previous group with a separator (not before the very first group)
            if last_sample_name is not None:
                table.add_section()
            # The sample name is shown only on the first row of its group
            sample_name_display = str(current_sample_name)
            last_sample_name = current_sample_name
        # Coordinates are shown as integers with thousands separators
        table.add_row(
            sample_name_display,
            str(chromosome_name),
            f'{int(start_frag):,}',
            f'{int(end_frag):,}',
        )
    shared_console.print(table)
# --- SubGraph Extraction and Processing ---
[docs]
def get_subgraph(self, samples_list_path: Optional[Path] = None,
all_samples_flag: bool = False) -> None:
"""
Extract the subgraph for the configured query region, then for any additional samples.

The query sample (`self.sample_name_query`) is always processed first: a `SubGraph`
is created from its BED file and `compute_intersection`, `build_walks` and
`build_segments` (plus `build_fasta` when `self.build_fasta_flag` is set) are run on it.
Additional samples returned by `_get_samples_for_processing` are then handed to
`_parallelize_subgraph_processing_for_samples`. Processed `SubGraph` objects are
collected in `self.dict_gfa_graph_object`, which is reset at the start of every call.

The method logs an error and returns early (without raising) when the query sample,
chromosome, start or stop is not set, or when the query sample's BED file does not exist.

Parameters:
samples_list_path (Optional[Path]): Path to a file containing a list of additional
samples to process (one per line).
all_samples_flag (bool): If True and `samples_list_path` is not given, process all
samples found in the GFA (relative to the query region).
"""
# Determine which other samples (besides query sample) to process
# The query sample (self.sample_name_query) is processed first regardless.
# Note: this lookup runs (and may log) before the query parameters are validated below.
other_samples_to_process = self._get_samples_for_processing(samples_list_path, all_samples_flag)
self.dict_gfa_graph_object = {} # Reset for this extraction run
# start/stop are compared against None so that a coordinate of 0 is accepted
if not (self.sample_name_query and self.chromosome_query and \
self.start_query is not None and self.stop_query is not None):
self.logger.error("Cannot extract subgraph: Query sample, chromosome, start, or stop not defined.")
return
self.logger.info(
f"Extracting subgraph for query: Sample='{self.sample_name_query}', "
f"Region='{self.chromosome_query}:{self.start_query}-{self.stop_query}', MergeDist='{self.merge or 0}bp'."
)
# Process the primary query sample
query_bed_file = self.bed_path / f"{self.sample_name_query}.bed"
if not query_bed_file.exists():
self.logger.error(
f"BED file for query sample '{self.sample_name_query}' not found at '{query_bed_file}'. Cannot proceed.")
return
# Instantiate SubGraph for the query sample
self.sub_graph_query = SubGraph(
bam_path=self.bam_segments_file, # Path to the main segments BAM
bed_path=query_bed_file,
sample_name=self.sample_name_query, # This SubGraph instance is for sample_name_query
sample_name_query=self.sample_name_query, # Passed for context
chromosome_query=self.chromosome_query,
start_query=self.start_query,
stop_query=self.stop_query,
works_path=self.works_path, # For saving intermediate files like regions.csv
merge=self.merge,
logger=self.logger, # Pass down logger
build_fasta_flag=self.build_fasta_flag
)
# Register the query SubGraph under its sample name alongside the other samples' results
self.dict_gfa_graph_object[self.sample_name_query] = self.sub_graph_query
# Perform core SubGraph operations for the query sample
self.sub_graph_query.compute_intersection() # Intersects query region with sample's BED
self.sub_graph_query.build_walks() # Builds W and L lines from intersections
self.sub_graph_query.build_segments() # Fetches S lines for segments in walks
if self.build_fasta_flag:
self.sub_graph_query.build_fasta() # Generates FASTA sequences
self.logger.info(
f"Finished processing query sample '{self.sample_name_query}'. "
f"Now processing {len(other_samples_to_process)} additional sample(s) for the same region."
)
# Process other samples in parallel if any
if other_samples_to_process:
self._parallelize_subgraph_processing_for_samples(other_samples_to_process)
self.logger.debug(f"Finished processing {len(other_samples_to_process)} additional sample(s).")
else:
self.logger.info("No additional unique samples to process.")
# Consolidate results after all samples are processed
# `_is_subgraph_extraction_main_call()` returns True when "subgraph" appears in the
# lower-cased command line (sys.argv).
# This is a bit fragile; consider a more direct flag if this distinction is important.
# NOTE(review): indentation is not visible in this listing, so it is unclear whether the
# region merge and the combined FASTA below belong to this `if` branch — confirm against the repository.
if self._is_subgraph_extraction_main_call():
self.logger.debug("Consolidating results after all samples are processed.")
self.concatenate_and_generate_subgfa_file() # Writes combined GFA
self.logger.debug("_merge_and_save_region_dataframes")
self._merge_and_save_region_dataframes() # Merges per-sample region CSVs
if self.build_fasta_flag:
self.logger.debug("Consolidating FASTA sequences after all samples are processed.")
self.generate_combined_fasta_file() # Writes combined FASTA
def _is_subgraph_extraction_main_call(self) -> bool:
    """
    Tell whether the current command line mentions "subgraph".

    This is a heuristic: the arguments in `sys.argv` are joined with spaces,
    lower-cased, and searched for the substring "subgraph", so any argument
    containing that text (including a file path) makes it return True.
    A flag passed explicitly by the CLI dispatcher would be more robust.
    """
    command_line = " ".join(sys.argv)
    return command_line.lower().find("subgraph") != -1
def _merge_and_save_region_dataframes(self) -> None:
    """
    Merge the temporary per-sample region files found in `self.works_path`
    into two consolidated, tab-separated files, then delete the temporary files.

    Files matching `*_regions_all.csv` are merged into
    `<gfa_name>_regions<suffix>.csv`; files matching
    `*_regions_collapsed_d<merge>.csv` are merged into
    `<gfa_name>_regions_collapsed_d<merge><suffix>.csv`. Each merged table is
    sorted by its first two columns. Temporary files that could not be read
    are reported and left in place; every file that was read is deleted at the
    end, whether or not the corresponding merged file could be written.
    """
    log = self.logger
    log.info("Merging per-sample region DataFrames...")

    # Output locations (the suffix is expected to be finalised before this call)
    base_name = self.gfa_name or 'gfa'
    merge_distance = self.merge or 0
    merged_path = self.works_path / f"{base_name}_regions{self.suffix}.csv"
    merged_collapsed_path = self.works_path / f"{base_name}_regions_collapsed_d{merge_distance}{self.suffix}.csv"

    # Lazy iterators over the temporary files written by the SubGraph instances
    regions_candidates = self.works_path.glob("*_regions_all.csv")
    collapsed_candidates = self.works_path.glob(f"*_regions_collapsed_d{merge_distance}.csv")

    column_names = [
        "chromosome",
        "start",
        "stop",
        "segment_id_collapsed",
        "fragment_collapsed",
        "haplotype_index_distinct",
    ]
    consumed_files: List[Path] = []

    def _collect(candidates, kind: str) -> List[pd.DataFrame]:
        # Read each temporary file; successfully read files are queued for deletion.
        frames: List[pd.DataFrame] = []
        for csv_path in candidates:
            try:
                frame = pd.read_csv(csv_path, sep="\t", header=None, names=column_names)
                if not frame.empty:
                    frames.append(frame)
                consumed_files.append(csv_path)
            except Exception as exc:
                log.warning(f"Failed to read or process temp {kind}regions CSV '{csv_path.name}': {exc}",
                            exc_info=self.debug)
        return frames

    def _write(frames: List[pd.DataFrame], destination: Path, kind: str, empty_label: str) -> None:
        # Concatenate, sort by the first two columns, and save one merged table.
        if not frames:
            log.info(f"No '{empty_label}' DataFrames to merge.")
            return
        merged = pd.concat(frames, axis=0, ignore_index=True)
        if not merged.empty and len(merged.columns) >= 2:
            merged.sort_values(by=[merged.columns[0], merged.columns[1]], inplace=True)
        try:
            merged.to_csv(destination, index=False, sep="\t")
            log.info(f"Saved merged {kind}regions data to: '{destination.name}'")
        except IOError as exc:
            log.error(f"Failed to save merged {kind}regions CSV '{destination.name}': {exc}",
                      exc_info=self.debug)

    # Read everything first, then write, so that the outputs are never picked up as inputs
    regions_frames = _collect(regions_candidates, "")
    collapsed_frames = _collect(collapsed_candidates, "collapsed ")
    _write(regions_frames, merged_path, "", "all regions")
    _write(collapsed_frames, merged_collapsed_path, "collapsed ", "collapsed regions")

    # Remove the temporary files that were read
    for stale_file in consumed_files:
        try:
            stale_file.unlink(missing_ok=True)  # Tolerate files already moved or deleted
            log.debug(f"Deleted temporary regions CSV: '{stale_file.name}'")
        except OSError as exc:
            log.warning(f"Failed to delete temporary regions CSV '{stale_file.name}': {exc}",
                        exc_info=self.debug)
def _process_single_sample_for_subgraph(self, sample_name: str,
shared_progress_dict: Dict, # For multiprocessing progress
task_id_rich: Any) -> Tuple[
str, Optional[SubGraph]]: # task_id_rich from Rich Progress
"""
Build the SubGraph of one additional (non-query) sample for the current query region.

This method is submitted to a ProcessPoolExecutor by
`_parallelize_subgraph_processing_for_samples`. It creates a SubGraph for
`sample_name` from that sample's BED file, seeded with the segment set, offsets and
first/last segment ids of the already-processed query SubGraph (`self.sub_graph_query`),
runs `get_chr_pos` on it and, when `self.build_fasta_flag` is set, `build_fasta`.

Progress is reported by writing a dict with the keys "progress", "total" and
"description" into `shared_progress_dict[task_id_rich]`; reporting is skipped when
either argument is None. No exception is propagated: failures are logged and
reported as `(sample_name, None)`.

Parameters:
sample_name (str): The name of the sample to process.
shared_progress_dict (Dict): A manager.dict() for inter-process progress reporting.
task_id_rich (Any): The Rich Progress TaskID for this sample.
Returns:
Tuple[str, Optional[SubGraph]]: Sample name and the processed SubGraph object, or
None in its place when the BED file is missing, the query SubGraph is not
initialized, or processing raised an exception.
"""
self.logger.info(f"Starting subgraph processing for sample: '{sample_name}'")
# Initial progress update
if shared_progress_dict is not None and task_id_rich is not None:
shared_progress_dict[task_id_rich] = {
"progress": 1, "total": 7, # Assuming 7 steps as in original SubGraph.get_chr_pos
"description": f"[cyan]'{sample_name}': Starting...",
}
sample_bed_file = self.bed_path / f"{sample_name}.bed"
if not sample_bed_file.exists():
self.logger.error(f"BED file for sample '{sample_name}' not found at '{sample_bed_file}'. Cannot process.")
# Report the task as finished (7/7) so the monitor treats it as complete
if shared_progress_dict is not None and task_id_rich is not None:
shared_progress_dict[task_id_rich] = {"progress": 7, "total": 7,
"description": f"[red]'{sample_name}': BED missing"}
return sample_name, None
if not self.sub_graph_query: # Query SubGraph must exist and be processed
self.logger.error(
"Main query SubGraph (self.sub_graph_query) not initialized. Cannot process other samples.")
if shared_progress_dict is not None and task_id_rich is not None:
shared_progress_dict[task_id_rich] = {"progress": 7, "total": 7,
"description": f"[red]'{sample_name}': Query SubGraph missing"}
return sample_name, None
# Create SubGraph instance for this specific sample.
# It receives segment_id_set and other relevant attributes from the main query's SubGraph.
# This implies other samples' subgraphs are built relative to the segments found in the query sample's region.
current_sample_subgraph = SubGraph(
bam_path=self.bam_segments_file,
bed_path=sample_bed_file,
sample_name=sample_name, # Current sample being processed
# Parameters inherited from the main query context:
segment_id_set=frozenset(self.sub_graph_query.segment_id_set),
# The frozenset above is an immutable snapshot of the query's segment ids.
dict_segments_samples=self.sub_graph_query.dict_segments_samples, # NOTE(review): passed as-is, not copied here — confirm this is intended
sample_name_query=self.sample_name_query,
chromosome_query=self.chromosome_query,
start_query=self.start_query,
stop_query=self.stop_query,
offset_first=self.sub_graph_query.offset_first, # Offsets determined by query sample
offset_last=self.sub_graph_query.offset_last,
segment_id_first_query=self.sub_graph_query.segment_id_first_query,
segment_id_last_query=self.sub_graph_query.segment_id_last_query,
works_path=self.works_path,
merge=self.merge,
logger=self.logger, # Pass logger
build_fasta_flag=self.build_fasta_flag
)
try:
# SubGraph.get_chr_pos does the heavy lifting (per the original author's notes):
# 1. Finds regions in `current_sample_subgraph` corresponding to `segment_id_set` from query.
# 2. Calls compute_intersection, build_walks, build_segments for these found regions.
# It is also given the progress dict and task id, presumably to report its own steps — verify in SubGraph.
current_sample_subgraph.get_chr_pos(progress_dict=shared_progress_dict, task_id=task_id_rich)
if self.build_fasta_flag:
self.logger.debug(f"'{sample_name}': Building FASTA sequences after subgraph processing.")
current_sample_subgraph.build_fasta()
# NOTE(review): indentation is not visible in this listing; the "FASTA done" update below
# presumably belongs to the FASTA branch above — confirm against the repository.
if shared_progress_dict is not None and task_id_rich is not None: # Example of adding a step
shared_progress_dict[task_id_rich] = {"progress": 7, "total": 7,
"description": f"[green]'{sample_name}': FASTA done"}
self.logger.info(f"Successfully processed subgraph for sample: '{sample_name}'")
return sample_name, current_sample_subgraph
except Exception as e:
self.logger.error(f"Error processing subgraph for sample '{sample_name}': {e}", exc_info=self.debug)
# The description carries at most the first 30 characters of the error text
if shared_progress_dict is not None and task_id_rich is not None:
shared_progress_dict[task_id_rich] = {"progress": 7, "total": 7,
"description": f"[red]'{sample_name}': Error - {e!s:.30}"}
return sample_name, None
def _parallelize_subgraph_processing_for_samples(self, samples_to_process: List[str]) -> None:
    """
    Run `_process_single_sample_for_subgraph` for several samples in a process pool
    and show progress with Rich.

    One future is submitted per sample. While the futures run, the main process
    periodically copies the progress entries that workers write into a
    manager-backed dict onto the per-sample Rich tasks and updates the overall
    counter. Between two polls it blocks until a pending future finishes or a
    short timeout elapses, instead of spinning. Successfully built SubGraph
    objects are stored in `self.dict_gfa_graph_object` under their sample name;
    failures are logged and left visible in the progress display. Temporary
    pybedtools files are cleaned up at the end.

    Parameters:
        samples_to_process (List[str]): List of sample names to process.
    """
    # Local import: `wait` is not part of the module-level concurrent.futures imports.
    from concurrent.futures import FIRST_COMPLETED, wait

    if not samples_to_process:
        self.logger.info("No samples provided for parallel subgraph processing.")
        return

    total_steps = 7  # Number of steps reported per sample by the workers
    poll_interval_s = 0.5  # Longest pause between two progress refreshes

    num_workers = min(self.threads, len(samples_to_process))  # Don't use more workers than samples
    if num_workers <= 0:
        num_workers = 1  # Ensure at least one worker
    self.logger.info(
        f"Parallelizing subgraph processing for {len(samples_to_process)} samples using {num_workers} worker(s).")

    # Rich progress displays: one bar for all samples, one bar per sample
    overall_progress = Progress(
        TextColumn("[bold blue]{task.description}"), BarColumn(),
        "[progress.percentage]{task.percentage:>3.1f}%", "{task.completed}/{task.total} Samples",
        TimeElapsedColumn(), TimeRemainingColumn(), refresh_per_second=0.1,
        console=shared_console, transient=False  # Keep overall progress visible
    )
    task_specific_progress = Progress(  # For individual sample steps
        TextColumn("[cyan]{task.description}"), BarColumn(),
        "[progress.percentage]{task.percentage:>3.1f}%", "{task.completed}/{task.total} Steps",
        TimeElapsedColumn(), refresh_per_second=0.2,
        console=shared_console, transient=False  # Keep task details visible until all done
    )
    progress_group_layout = Group(
        Panel(Group(overall_progress), title="Overall Sample Processing"),
        Panel(Group(task_specific_progress), title="Per-Sample Progress")
    )

    # The manager dict lets worker processes report their progress to this process
    with multiprocessing.Manager() as manager:
        shared_progress_report_dict = manager.dict()
        futures_map = {}  # future -> (sample name, Rich task id)
        with ProcessPoolExecutor(max_workers=num_workers) as executor, Live(
                progress_group_layout, console=shared_console,
                refresh_per_second=0.2, screen=False, transient=False):
            overall_task = overall_progress.add_task("Processing Samples...", total=len(samples_to_process))
            for sample_name_iter in samples_to_process:
                # One Rich task per sample, updated from the shared dict below
                rich_task_id = task_specific_progress.add_task(f"'{sample_name_iter}': Queued",
                                                               total=total_steps, start=False)
                future = executor.submit(
                    self._process_single_sample_for_subgraph,
                    sample_name_iter,
                    shared_progress_report_dict,  # Manager dict for progress reporting
                    rich_task_id  # Key used by the worker inside the dict
                )
                futures_map[future] = (sample_name_iter, rich_task_id)

            # Monitor the futures until all of them are done
            pending_futures = set(futures_map)
            while True:
                num_completed_futures = sum(f.done() for f in futures_map)
                overall_progress.update(overall_task, completed=num_completed_futures)
                # Mirror the workers' reports; list() takes a snapshot of the proxy dict
                for rich_tid, progress_data in list(shared_progress_report_dict.items()):
                    reported_progress = progress_data.get("progress", 0)
                    reported_total = progress_data.get("total", total_steps)
                    task_specific_progress.update(
                        rich_tid,
                        completed=reported_progress,
                        total=reported_total,
                        description=str(progress_data.get("description", "Processing..."))
                    )
                    if reported_progress == reported_total:  # Task is complete
                        task_specific_progress.update(rich_tid, visible=False)  # Hide completed tasks
                if num_completed_futures >= len(samples_to_process):
                    break
                # Block until a pending future finishes or the interval elapses,
                # rather than busy-waiting on the futures and the manager dict.
                pending_futures = {f for f in pending_futures if not f.done()}
                if pending_futures:
                    wait(pending_futures, timeout=poll_interval_s, return_when=FIRST_COMPLETED)

            # Ensure overall progress is marked as fully complete
            overall_progress.update(overall_task, completed=len(samples_to_process),
                                    description="All samples processed.")

            # Retrieve results and store SubGraph objects
            for future, (sample_name_res, rich_task_id_res) in futures_map.items():
                try:
                    _s_name, sub_graph_obj = future.result()  # _s_name should match sample_name_res
                    if sub_graph_obj:
                        self.dict_gfa_graph_object[sample_name_res] = sub_graph_obj
                        task_specific_progress.update(rich_task_id_res,
                                                      description=f"[green]'{sample_name_res}': Completed & Stored",
                                                      completed=total_steps, total=total_steps, visible=False)
                    else:
                        self.logger.error(
                            f"Subgraph processing for sample '{sample_name_res}' returned no object (likely an error).")
                        task_specific_progress.update(rich_task_id_res,
                                                      description=f"[red]'{sample_name_res}': Failed (No object)",
                                                      completed=total_steps, total=total_steps,
                                                      visible=True)  # Keep error visible
                except Exception as e:
                    self.logger.error(f"Exception retrieving result for sample '{sample_name_res}': {e}",
                                      exc_info=self.debug)
                    task_specific_progress.update(rich_task_id_res,
                                                  description=f"[bold red]'{sample_name_res}': ERROR - {e!s:.30}",
                                                  completed=total_steps, total=total_steps, visible=True)
            self.logger.info("End so close progress")

    # Use pybedtools cleanup to remove temporary files
    cleanup()
    self.logger.info("Parallel subgraph processing for all samples finished.")
[docs]
def concatenate_and_generate_subgfa_file(self) -> None:
    """
    Merge the GFA records of every processed SubGraph into one gzipped GFA file.

    The output contains, in order: the non-empty lines of the original header
    file, one extra ``H`` line carrying the query suffix, then the
    de-duplicated segment (S), link (L) and walk (W) lines, each group sorted.
    Nothing is written when no SubGraph was processed or when the working
    directory or the header file is unavailable. A header file that cannot be
    read is replaced by a single fallback ``H`` line.
    """
    subgraphs = self.dict_gfa_graph_object
    if not subgraphs:
        self.logger.info("No SubGraph objects processed. Cannot generate combined GFA file.")
        return

    header_path = self.header_gfa_file
    if not (self.works_path and header_path and header_path.exists()):
        self.logger.error("Works_path or GFA header file not available. Cannot generate combined GFA.")
        return

    self.logger.info(f"Consolidating GFA data from {len(subgraphs)} processed SubGraphs.")

    # Original header: keep every non-blank line, stripped of surrounding whitespace.
    try:
        with open(header_path, "r", encoding="utf-8") as header_handle:
            stripped_lines = (raw.strip() for raw in header_handle)
            header_records = [text for text in stripped_lines if text]
    except IOError as e:
        self.logger.error(f"Failed to read GFA header file '{header_path}': {e}", exc_info=self.debug)
        header_records = [f"H\tVN:Z:1.0\tNOTE:Z:Original header missing - {e!s:.50}"]  # Fallback header

    # Sets de-duplicate records that appear in several subgraphs.
    # Insertion order (S, L, W) is also the order used when writing.
    unique_records = {"S": set(), "L": set(), "W": set()}
    for sample_name, subgraph in subgraphs.items():
        if not subgraph:
            self.logger.warning(f"SubGraph object for sample '{sample_name}' is None. Skipping its GFA data.")
            continue
        unique_records["S"].update(subgraph.gfa_segment_list)
        unique_records["L"].update(subgraph.gfa_link_list)
        unique_records["W"].update(subgraph.gfa_walk_list)
        self.logger.debug(
            f"Sample '{sample_name}' contributed {len(subgraph.gfa_walk_list)} walks to combined GFA.")

    output_path = self.works_path / f"{self.gfa_name or 'gfa'}_subgraph{self.suffix}.gfa.gz"
    self.logger.info(f"Writing combined subgraph GFA to: '{output_path.name}' "
                     f"({len(unique_records['S'])} S, {len(unique_records['L'])} L, "
                     f"{len(unique_records['W'])} W lines)")

    query_context = self.suffix.lstrip('_') if self.suffix else 'N/A'
    context_header = f"H\tRS:Z:Subgraph_QueryContext={query_context}"

    try:
        # mtime=0 keeps the gzip stream byte-identical for identical contents.
        with gzip.GzipFile(output_path, "wb", compresslevel=6, mtime=0) as gz_out:
            def emit(record: str) -> None:
                gz_out.write(f"{record}\n".encode("utf-8"))

            for record in header_records:
                emit(record)
            emit(context_header)
            # Each record type is sorted so the output is deterministic.
            for record_group in unique_records.values():
                for record in sorted(record_group):
                    emit(record)
        self.logger.info(f"Successfully generated combined subgraph GFA: '{output_path.name}'")
    except IOError as e:
        self.logger.error(f"Failed to write combined GFA file '{output_path}': {e}", exc_info=self.debug)
[docs]
def generate_combined_fasta_file(self) -> None:  # Renamed for clarity
    """
    Write the sequences collected by all processed SubGraph objects to one FASTA file.

    The method returns without writing anything when FASTA generation is
    disabled, when no SubGraph was processed, when the working directory is
    not set, or when no SubGraph holds any sequence record. Write failures
    are logged, not raised.
    """
    if not self.build_fasta_flag:  # FASTA generation was not requested for the subgraphs
        self.logger.info("FASTA generation was not enabled (build_fasta_flag=False). Skipping combined FASTA.")
        return
    if not self.dict_gfa_graph_object:
        self.logger.info("No SubGraph objects processed. Cannot generate combined FASTA file.")
        return
    if not self.works_path:
        self.logger.error("Works_path not set. Cannot determine output path for combined FASTA.")
        return

    fasta_path = self.works_path / f"{self.gfa_name or 'gfa'}_subgraph{self.suffix}.fasta"

    # Gather the records of every subgraph into one flat list.
    collected: List[SeqRecord] = []
    for sample_name, subgraph in self.dict_gfa_graph_object.items():
        if not subgraph:
            self.logger.warning(f"SubGraph object for sample '{sample_name}' is None. Skipping its FASTA data.")
            continue
        sample_records = subgraph.sequences_list
        if not sample_records:
            self.logger.debug(f"Sample '{sample_name}' had a SubGraph object but no sequences_list content.")
            continue
        collected.extend(sample_records)
        self.logger.debug(
            f"Collected {len(sample_records)} FASTA records from sample '{sample_name}'.")

    if not collected:
        self.logger.info("No FASTA sequences collected from SubGraphs. Combined FASTA file will not be generated.")
        return

    self.logger.info(
        f"Writing {len(collected)} combined FASTA records to: '{fasta_path.name}'")
    try:
        # `collected` is already flat (built with extend), so it is handed to SeqIO as is.
        with open(fasta_path, "w", encoding="utf-8") as fasta_handle:
            SeqIO.write(collected, fasta_handle, "fasta")
        self.logger.info(f"Successfully generated combined FASTA file: '{fasta_path.name}'")
    except IOError as e:
        self.logger.error(f"Failed to write combined FASTA file '{fasta_path}': {e}",
                          exc_info=self.debug)
    except Exception as e:  # Any other failure raised while SeqIO serialises the records
        self.logger.error(f"Error during SeqIO.write for combined FASTA: {e}", exc_info=self.debug)
# --- BAM Analysis Methods (delegating to GratoolsBam) ---
def _get_bam_manager(self) -> Optional[GratoolsBam]:
    """
    Build a GratoolsBam wrapper around the main BAM segments file.

    Returns
    -------
    Optional[GratoolsBam]
        A manager configured with this instance's threads, suffix, working
        directory, GFA name and logger, or None (after logging an error)
        when the BAM segments file is unset or missing on disk.
    """
    segments_bam = self.bam_segments_file
    if segments_bam and segments_bam.exists():
        # suffix, works_path and gfa_name are forwarded so the manager can
        # name its outputs consistently with this instance.
        manager_options = {
            "bam_path": segments_bam,
            "threads": self.threads,
            "suffix": self.suffix,
            "works_path": self.works_path,
            "gfa_name": self.gfa_name,
            "logger": self.logger,
        }
        return GratoolsBam(**manager_options)

    self.logger.error("Main BAM segments file not found. Cannot perform BAM-based analyses.")
    return None
[docs]
def run_pan_ratio_analysis(self, input_as_number: bool,
                           shared_min: int, specific_max: Optional[int],
                           filter_len: int) -> None:
    """
    Run the core/dispensable segment ratio analysis through GratoolsBam.

    Parameters
    ----------
    input_as_number : bool
        Forwarded unchanged to ``GratoolsBam.pan_ratio``.
    shared_min : int
        Forwarded as ``shared_min_cutoff``.
    specific_max : Optional[int]
        Forwarded as ``specific_max_cutoff``.
    filter_len : int
        Forwarded as ``filter_min_len``.

    The number of available GFA samples is passed as ``nb_samples_gfa``.
    Nothing is run when no BAM manager can be created.
    """
    manager = self._get_bam_manager()
    if not manager:
        return

    sample_total = len(self.available_sample_names)
    if sample_total == 0:
        # The analysis still runs; only a warning is issued here.
        self.logger.warning(
            "No GFA samples found. Cannot accurately run core/dispensable ratio analysis if percentages are used.")

    self.logger.info("Delegating core/dispensable ratio analysis to GratoolsBam.")
    # Keyword names follow the GratoolsBam.pan_ratio parameters ("_cutoff" suffix).
    pan_ratio_arguments = {
        "nb_samples_gfa": sample_total,
        "input_as_number": input_as_number,
        "shared_min_cutoff": shared_min,
        "specific_max_cutoff": specific_max,
        "filter_min_len": filter_len,
    }
    manager.pan_ratio(**pan_ratio_arguments)
[docs]
def run_depth_nodes_statistics(self, filter_len: int) -> None:
    """
    Run the node depth statistics analysis through GratoolsBam.

    ``filter_len`` is forwarded as ``filter_min_len`` and the number of
    available GFA samples as ``nb_samples_gfa``. Nothing is run when no BAM
    manager can be created.
    """
    manager = self._get_bam_manager()
    if not manager:
        return

    sample_total = len(self.available_sample_names)
    self.logger.info("Delegating node depth statistics analysis to GratoolsBam.")
    depth_arguments = {"nb_samples_gfa": sample_total, "filter_min_len": filter_len}
    manager.depth_nodes_stat(**depth_arguments)
[docs]
def run_get_specific_groups_sample_analysis(self, sample_list_a_path: Optional[Path],
                                            sample_list_b_path: Optional[Path],
                                            filter_len: Optional[int],
                                            output_csv: Optional[bool],
                                            ) -> None:
    """
    Find segments shared by sample group A (optionally specific versus group B)
    and, when requested, export their positions to CSV.

    Parameters
    ----------
    sample_list_a_path : Optional[Path]
        File listing the samples of group A. Mandatory: the analysis is
        aborted with an error when it is missing or yields no sample.
    sample_list_b_path : Optional[Path]
        File listing the samples of group B. When omitted (or when it yields
        no sample) no "specific" CSV is written.
    filter_len : Optional[int]
        Forwarded to GratoolsBam as ``filter_min_len``.
    output_csv : Optional[bool]
        When truthy, write ``<gfa_name>_segment_shared<suffix>.csv`` and, if
        group B is not empty, ``<gfa_name>_segment_specific<suffix>.csv`` into
        ``works_path``. Each file starts with ``#`` comment lines naming the
        groups, followed by one row per segment with one column per group-A
        sample holding its ';'-joined positions.
    """
    bam_manager = self._get_bam_manager()
    if not bam_manager:
        return

    # Group A is mandatory: it defines which segments count as "shared".
    if not sample_list_a_path:
        self.logger.error("Path to sample list A (for shared segments) is required but not provided.")
        return
    samples_A: List[str] = self._load_samples_file(sample_list_a_path)
    if not samples_A:
        self.logger.error(
            f"Sample list A from '{sample_list_a_path.name}' is empty or all invalid. Aborting analysis.")
        return

    # Group B is optional; an empty list only triggers a warning.
    samples_B: Optional[List[str]] = None
    if sample_list_b_path:
        samples_B = self._load_samples_file(sample_list_b_path)
        if not samples_B:
            self.logger.warning(
                f"Sample list B from '{sample_list_b_path.name}' is empty or all invalid. Specificity check will consider no exclusions.")

    segment_list_shared, segment_list_specific = bam_manager.get_specific_and_shared_segments(
        samples_list_A=samples_A,
        samples_list_B=samples_B,  # Can be None or empty list
        filter_min_len=filter_len,
        output_csv=output_csv
    )

    if not output_csv:
        return
    if not self.works_path:
        # Same handling as the other writers of this class: log and stop
        # instead of failing on `None / str`.
        self.logger.error("Works_path not set. Cannot save CSV.")
        return

    self.logger.info("Extracting segment positions for shared and specific groups.")
    dict_segments = self.find_specific_groups_sample_position(segment_list_shared, samples_A)
    sorted_samples_A = sorted(samples_A)  # column order, computed once

    def _write_positions_csv(csv_path, comment_lines, segments) -> None:
        """Write the comment lines, the header row and one row per segment."""
        with open(csv_path, "w", newline='', encoding="utf-8") as f_out:
            for comment in comment_lines:
                f_out.write(f"{comment}\n")
            writer = csv.writer(f_out)
            writer.writerow(['NODE_ID'] + sorted_samples_A)
            for segment in segments:
                segment_positions = dict_segments.get(segment, {})
                writer.writerow(
                    [segment] + [';'.join(segment_positions.get(sample, [])) for sample in sorted_samples_A]
                )

    list_a_comment = f'# Samples in list A: {",".join(sorted_samples_A)}'

    # --- "shared" file: every segment shared by group A ---
    csv_path_shared = self.works_path / f"{self.gfa_name}_segment_shared{self.suffix}.csv"
    self.logger.info(f"Generating shared segments CSV file: {csv_path_shared}")
    _write_positions_csv(csv_path_shared, [list_a_comment], sorted(segment_list_shared))

    # --- "specific" file: only when group B holds at least one sample ---
    if samples_B:
        csv_path_specific = self.works_path / f"{self.gfa_name}_segment_specific{self.suffix}.csv"
        self.logger.info(f"Generating specific segments CSV file: {csv_path_specific}")
        # Keep the shared segments that GratoolsBam also reported as specific.
        specific_only_segments = set(segment_list_shared).intersection(segment_list_specific)
        list_b_comment = f'# Samples in list B: {",".join(samples_B)}'
        _write_positions_csv(csv_path_specific, [list_a_comment, list_b_comment],
                             sorted(specific_only_segments))
[docs]
def find_specific_groups_sample_position(self, segment_list_shared, sample_list_a=None):
    """
    Collect, for each requested segment, its positions in every sample's BED file.

    The segment IDs are written once to a temporary file. Each sample's BED
    file (``<bed_path>/<sample>.bed``) is then filtered by an ``awk`` child
    process whose output is consumed line by line, so neither the filtered
    result nor an intermediate BED file is materialised.

    Parameters
    ----------
    segment_list_shared : iterable of str
        Segment IDs to look for (one per line in the temporary ID file).
    sample_list_a : iterable of str, optional
        Samples whose BED files are scanned. ``None`` or an empty collection
        yields an empty result.

    Returns
    -------
    defaultdict
        ``{segment_id: {sample_name: ["chrom:start-end", ...]}}``. Samples
        whose BED file is missing or whose awk process fails contribute
        nothing (the problem is logged).
    """
    dict_segments = defaultdict(lambda: defaultdict(list))

    samples = list(sample_list_a) if sample_list_a else []
    if not samples or not segment_list_shared:
        # Nothing to scan or nothing to look for: skip temp file, awk and progress bar.
        return dict_segments

    # A single temporary file holds the IDs for all awk invocations; it must
    # stay open (and therefore on disk) until every worker has finished.
    with NamedTemporaryFile(mode='w+', delete=True, suffix="_segment_ids.txt") as tmp_segment_file:
        tmp_segment_file.write('\n'.join(segment_list_shared))
        tmp_segment_file.flush()
        segment_id_filepath = tmp_segment_file.name

        def _process_sample_via_streaming(sample_name):
            """Filter one sample's BED file with awk and parse its output as a stream."""
            sample_bed_path = self.bed_path / f"{sample_name}.bed"
            if not sample_bed_path.exists():
                self.logger.warning(f"BED file for sample '{sample_name}' not found. Skipping.")
                return {}

            # First file: load the IDs. Second file: print BED lines whose 4th
            # column, minus one leading orientation sign, is a known ID.
            awk_script = 'NR==FNR{ids[$1]; next} {id=$4; sub(/^[+-]/,"",id); if(id in ids) print $0}'
            cmd = ['awk', awk_script, segment_id_filepath, str(sample_bed_path)]
            local_result = defaultdict(lambda: defaultdict(list))
            try:
                with subprocess.Popen(cmd, stdout=subprocess.PIPE, text=True, bufsize=1) as process:
                    # Lines are consumed as awk produces them; only one line
                    # is held at a time.
                    for line in process.stdout:
                        parts = line.strip().split('\t')
                        if len(parts) < 4:
                            continue  # malformed line
                        name = parts[3]
                        # Drop one leading sign only, exactly like the awk filter,
                        # so the key is the ID that awk matched.
                        seg_id = name[1:] if name[:1] in ("+", "-") else name
                        seg_pos = f'{parts[0]}:{parts[1]}-{parts[2]}'
                        local_result[seg_id][sample_name].append(seg_pos)
                # Leaving the `with` block waits for awk, so returncode is final here.
                if process.returncode != 0:
                    self.logger.error(f"Awk process for '{sample_name}' failed with code {process.returncode}.")
                    return {}
            except Exception as e:
                # Best effort: a failing sample is reported and skipped.
                self.logger.error(f"Failed to execute awk for '{sample_name}': {e}")
                return {}
            return local_result

        progress_bar = Progress(
            TextColumn("[blue]{task.description}"), BarColumn(),
            "[progress.percentage]{task.percentage:>3.1f}%", TimeElapsedColumn(),
            transient=True
        )
        with progress_bar as progress, ThreadPoolExecutor(max_workers=self.threads) as executor:
            task_id = progress.add_task("Filtering BED files (streaming)", total=len(samples))
            futures = {
                executor.submit(_process_sample_via_streaming, sample): sample
                for sample in samples
            }
            # Merge per-sample results as soon as each worker completes.
            for future in as_completed(futures):
                for seg_id, sample_map in future.result().items():
                    for sample_name, pos_list in sample_map.items():
                        dict_segments[seg_id][sample_name].extend(pos_list)
                progress.update(task_id, advance=1)
    return dict_segments
[docs]
def get_segments_by_depth(self, input_as_number: bool, lower_bound: int,
                          upper_bound: int, filter_len: int) -> Tuple[Dict[str, int], Dict[str, Any]]:
    """
    Retrieve the segments whose depth lies within a range, using GratoolsBam.

    Parameters
    ----------
    input_as_number : bool
        Forwarded to GratoolsBam. When False the bounds are percentages,
        which requires at least one GFA sample.
    lower_bound, upper_bound : int
        Forwarded as ``lower_bound_depth`` / ``upper_bound_depth``.
    filter_len : int
        Forwarded as ``filter_min_len``.

    Returns
    -------
    Tuple[Dict[str, int], Dict[str, Any]]
        The pair returned by
        ``GratoolsBam.get_segments_and_positions_by_depth``; callers in this
        class unpack it as ``(segments_with_depth, segment_locations)``,
        where the first maps segment ID to depth and the second maps segment
        ID to per-sample positions. ``({}, {})`` is returned when no BAM
        manager is available, or when percentage bounds are requested while
        no GFA sample is known.
    """
    bam_manager = self._get_bam_manager()
    if not bam_manager:
        return {}, {}

    total_gfa_samples = len(self.available_sample_names)
    if total_gfa_samples == 0 and not input_as_number:
        # Percentages of zero samples are meaningless.
        self.logger.error("Cannot use percentage depth bounds as no GFA samples were found (total_gfa_samples=0).")
        return {}, {}

    return bam_manager.get_segments_and_positions_by_depth(
        total_gfa_samples=total_gfa_samples,
        input_as_number=input_as_number,
        lower_bound_depth=lower_bound,
        upper_bound_depth=upper_bound,
        filter_min_len=filter_len,
        bed_path=self.bed_path
    )
[docs]
def display_or_save_segments_by_depth(self, input_as_number: bool, lower_bound: int,
                                      upper_bound: int, filter_len: int,
                                      output_to_file: bool) -> None:
    """
    Fetch the segments within a depth range and either print or save them.

    Parameters
    ----------
    input_as_number : bool
        Passed to ``get_segments_by_depth``; also selects the unit labels
        used in the table title, the file name and the log message.
    lower_bound, upper_bound, filter_len : int
        Passed to ``get_segments_by_depth``.
    output_to_file : bool
        If True, write a tab-separated file into ``works_path`` with one row
        per segment (ID, depth, depth as a rounded percentage of the
        available samples, per-sample positions). If False, print a Rich
        table of segment IDs and depths to the shared console.
    """
    depth_by_segment, positions_by_segment = self.get_segments_by_depth(
        input_as_number, lower_bound, upper_bound, filter_len
    )
    if not depth_by_segment:
        self.logger.info("No segments found matching the specified depth criteria.")
        if not output_to_file:
            shared_console.print("[yellow]No segments found.[/yellow]")
        return

    sample_count = len(self.available_sample_names)

    if not output_to_file:
        # --- Terminal summary: segment ID and depth only ---
        unit_label = 'samples' if input_as_number else '%'
        summary_table = Table(
            title=f"Segments by Depth ({lower_bound}-{upper_bound} {unit_label}, len>={filter_len}bp)",
            box=box.ROUNDED)
        summary_table.add_column("Segment ID", style="cyan", no_wrap=True)
        summary_table.add_column("Depth (Num Samples)", style="green", justify="right")
        for segment_id, segment_depth in depth_by_segment.items():
            summary_table.add_row(segment_id, str(segment_depth))
        shared_console.print(Align.center(summary_table))
    else:
        # --- Detailed tab-separated export ---
        if not self.works_path:
            self.logger.error("Works_path not set. Cannot save CSV.")
            return

        bound_kind = 'individuals' if input_as_number else 'percent'
        target_path = self.works_path / (
            f"{self.gfa_name or 'gfa'}_segments_by_depth_"
            f"{lower_bound}-{upper_bound}_{bound_kind}_len{filter_len}{self.suffix}.csv"
        )
        try:
            with open(target_path, "w", newline='', encoding="utf-8") as handle:
                # The csv module handles quoting; fields are tab-delimited.
                tsv_writer = csv.writer(handle, delimiter='\t')
                tsv_writer.writerow(["#NODE_ID", "SHARED_BY_COUNT", "SHARED_BY_PERCENT", "POSITIONS_BY_SAMPLE"])
                for segment_id in sorted(depth_by_segment):
                    segment_depth = depth_by_segment[segment_id]
                    if sample_count > 0:
                        shared_percent = round((segment_depth / sample_count) * 100)
                    else:
                        shared_percent = 0
                    per_sample = positions_by_segment.get(segment_id, {})
                    # One "sample:chrom,start,stop;chrom,start,stop;" chunk per
                    # sample (note the trailing semicolon), samples sorted.
                    positions_cell = "".join(
                        f"{sample}:"
                        + ";".join(f"{chrom},{start},{stop}" for chrom, start, stop in per_sample[sample])
                        + ";"
                        for sample in sorted(per_sample)
                    )
                    tsv_writer.writerow([segment_id, segment_depth, shared_percent, positions_cell])
            self.logger.info(f"Successfully saved detailed data to '{target_path.name}'.")
        except IOError as e:
            self.logger.error(f"Failed to save CSV to '{target_path}': {e}", exc_info=self.debug)

    self.logger.info(
        f"Processed {len(depth_by_segment)} segments with depth between {lower_bound} and {upper_bound} "
        f"({'individuals' if input_as_number else '%'}) and length >= {filter_len}bp."
    )
[docs]
def export_to_bandage_csv(self, output_csv_path: Optional[Path] = None) -> None:
    """
    Export node (segment) information to a CSV file for Bandage.

    The export is delegated to ``GratoolsBam.export_nodes_to_csv`` on a
    manager created with ``tagging=True`` from the main BAM segments file.

    Args:
        output_csv_path (Optional[Path]): Destination of the CSV file. When
            None, ``<works_path>/<gfa_name>_bandage.csv`` is used.

    The method logs an error and returns when the BAM file is unset or
    missing; any exception raised during the export is logged, not raised.
    """
    self.logger.info("Starting export of node data to a Bandage-compatible CSV.")

    bam_file = self.bam_segments_file
    if not (bam_file and bam_file.exists()):
        self.logger.error(
            f"BAM file not found at '{bam_file}'. "
            "The GFA import is required. Please run the 'import' command first."
        )
        return

    # Fall back to a default file name inside the working directory.
    destination = (
        output_csv_path
        if output_csv_path is not None
        else self.works_path / f"{self.gfa_name}_bandage.csv"
    )

    self.logger.info(f"Using BAM file: {bam_file}")
    self.logger.info(f"Output CSV will be saved to: {destination}")

    try:
        # tagging=True is requested explicitly for this export.
        exporter = GratoolsBam(
            bam_path=bam_file,
            threads=self.threads,
            tagging=True,
            logger=self.logger
        )
        exporter.export_nodes_to_csv(output_csv_path=destination)
        self.logger.info(f"Successfully exported Bandage CSV to {destination}")
    except Exception as e:
        self.logger.error(f"An error occurred during the CSV export for Bandage: {e}", exc_info=self.debug)
if __name__ == "__main__":
    # Manual example / smoke test of the Gratools API. It only runs when the
    # module is executed as a script, never on import.
    #
    # The GFA to load can be given as the first command-line argument; without
    # one, the historical development path below is used. A small file is
    # recommended, e.g. Path("test_data/small_test.gfa").
    default_test_gfa_path = Path("/shared/home/sravel/glp701/sandbox/toolbox/Og_cactus.gfa.gz")
    test_gfa_path = Path(sys.argv[1]) if len(sys.argv) > 1 else default_test_gfa_path
    import_links = False

    if not test_gfa_path.exists():
        print(f"Test GFA file not found at: {test_gfa_path}. Skipping example usage.", file=sys.stderr)
        # A dummy GFA could be created here for basic testing if needed, e.g.:
        # dummy_gfa_content = "H\tVN:Z:1.0\nS\ts1\tACGT\nS\ts2\tTTTT\nL\ts1\t+\ts2\t+\t0M\nW\tsampleA\t0\tchr1\t0\t100\t>s1>s2\n"
    else:
        print(f"--- Running GraTools Example Usage with GFA: {test_gfa_path} ---")

        # Metadata handed to Gratools (stands in for the CLI arguments).
        gratools_meta = {
            "verbosity": "INFO",  # "DEBUG" for more detail
            # "log_path": Path("gratools_test_logs"),  # Optional: custom log directory
            "threads": 2,  # Number of threads for GFA parsing and BAM ops
        }

        try:
            # Instantiating Gratools runs its own initialisation on the GFA.
            gratools_instance = Gratools(
                gfa_path=test_gfa_path,
                threads=gratools_meta.get("threads", 1),
                # outdir=Path("gratools_test_output"),  # Optional: custom main output dir
                meta=gratools_meta,
                import_links=import_links,  # True to test DB link importing
                # --- Query parameters for subgraph extraction (example) ---
                # sample_name_query="sampleA",
                # chromosome_query="chr1",
                # start_query=0,
                # stop_query=100,
                # merge=0,  # Merge distance for bedtools
            )

            # 1. GFA statistics
            print("\n--- Displaying GFA Statistics ---")
            gratools_instance.display_gfa_statistics(by_category=True)

            # 2. Sample and chromosome information
            print("\n--- Displaying Sample Names ---")
            gratools_instance.display_available_sample_names()
            # gratools_instance.save_available_sample_names()  # Optionally save
            print("\n--- Displaying Chromosome Summary per Sample ---")
            gratools_instance.display_chromosomes_summary()
            # gratools_instance.save_chromosomes_summary_by_sample()  # Optionally save

            # 3. Subgraph extraction, only when query parameters were set above
            if gratools_instance.sample_name_query:
                print(f"\n--- Extracting Subgraph for Query: {gratools_instance.sample_name_query} ---")
                gratools_instance.get_subgraph(
                    # samples_list_path=Path("other_samples.txt"),  # Optional: file with other samples
                    all_samples_flag=False,  # Process only query, or query + samples_list_path
                    build_fasta_flag=True
                )
            else:
                print("\n--- Skipping Subgraph Extraction (sample_name_query not set) ---")

            # 4. Core/dispensable analysis, only meaningful with several samples
            if len(gratools_instance.available_sample_names) > 1:
                print("\n--- Running Core/Dispensable Ratio Analysis ---")
                gratools_instance.run_pan_ratio_analysis(
                    input_as_number=False,  # Use percentages
                    shared_min=80,  # 80% shared for core
                    specific_max=20,  # 20% specific for dispensable
                    filter_len=50  # Min segment length 50bp
                )
            else:
                print("\n--- Skipping Core/Dispensable Analysis (needs >1 sample in GFA) ---")

            # 5. Segments by depth: once printed to the terminal, once saved to a file
            print("\n--- Getting Segments by Depth (e.g., in 10-50% of samples) ---")
            for save_to_file in (False, True):
                gratools_instance.display_or_save_segments_by_depth(
                    input_as_number=False, lower_bound=10, upper_bound=50, filter_len=0,
                    output_to_file=save_to_file
                )

            print("\n--- GraTools Example Usage Complete ---")
        except FileNotFoundError as e:
            print(f"ERROR in example: {e}. Ensure test GFA path is correct.", file=sys.stderr)
        except EnvironmentError as e:
            print(f"ENVIRONMENT ERROR in example: {e}. Check dependencies (e.g., bedtools).", file=sys.stderr)
        except Exception as e:
            # Last-resort boundary of the example script: report and show the traceback.
            print(f"UNEXPECTED ERROR in example usage: {e}", file=sys.stderr)
            import traceback
            traceback.print_exc()