Latest Threat Research: SANDWORM_MODE — Shai-Hulud-Style npm Worm Hijacks CI Workflows and Poisons AI Toolchains. Details
Socket
Book a DemoInstallSign in
Socket

abutils

Package Overview
Dependencies
Maintainers
1
Versions
41
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

abutils - npm Package Compare versions

Comparing version
0.5.0
to
0.5.1
+542
abutils/tests/test_preprocessing.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import shutil
import tempfile
from pathlib import Path
from unittest import mock
import polars as pl
import pytest
from abutils import Sequence
from abutils.tools import preprocessing
# Fixtures
@pytest.fixture
def temp_dir():
    """Yield a temporary directory path, removed after the test finishes.

    Uses ``tempfile.TemporaryDirectory`` instead of a manual
    ``mkdtemp``/``rmtree`` pair so cleanup is guaranteed even if the
    generator is closed abnormally.
    """
    with tempfile.TemporaryDirectory(prefix="tempdir") as dir_path:
        yield dir_path
@pytest.fixture
def fastq_file(temp_dir):
    """Create a simple FASTQ file for testing."""
    path = os.path.join(temp_dir, "test_file.fastq")
    with open(path, "w") as handle:
        handle.write("test content")
    return path
@pytest.fixture
def illumina_fastq(temp_dir):
    """Create a FASTQ file with Illumina naming convention."""
    path = os.path.join(temp_dir, "Sample1_S1_L001_R1_001.fastq")
    with open(path, "w") as handle:
        handle.write("@header1\nACGT\n+\nIIII\n")
    return path
@pytest.fixture
def element_fastq(temp_dir):
    """Create a FASTQ file with Element naming convention."""
    path = os.path.join(temp_dir, "Sample1_R1.fastq")
    with open(path, "w") as handle:
        handle.write("@header1\nACGT\n+\nIIII\n")
    return path
@pytest.fixture
def illumina_pair(temp_dir):
    """Create a pair of Illumina FASTQ files (R1/R2 paths)."""
    reads = {
        "Sample1_S1_L001_R1_001.fastq": "@header1\nACGT\n+\nIIII\n",
        "Sample1_S1_L001_R2_001.fastq": "@header1\nTGCA\n+\nIIII\n",
    }
    paths = []
    for basename, content in reads.items():
        path = os.path.join(temp_dir, basename)
        with open(path, "w") as handle:
            handle.write(content)
        paths.append(path)
    # dict preserves insertion order, so paths is (R1, R2)
    return tuple(paths)
@pytest.fixture
def merged_dir(temp_dir):
    """Create an output directory for merged files."""
    path = os.path.join(temp_dir, "merged")
    os.makedirs(path, exist_ok=True)
    return path
@pytest.fixture
def log_dir(temp_dir):
    """Create a directory for log files."""
    path = os.path.join(temp_dir, "logs")
    os.makedirs(path, exist_ok=True)
    return path
@pytest.fixture
def multi_sample_fastqs(temp_dir):
    """Create paired FASTQ files for two samples."""
    basenames = [
        "Sample1_S1_L001_R1_001.fastq",
        "Sample1_S1_L001_R2_001.fastq",
        "Sample2_S2_L001_R1_001.fastq",
        "Sample2_S2_L001_R2_001.fastq",
    ]
    paths = []
    for basename in basenames:
        path = os.path.join(temp_dir, basename)
        with open(path, "w") as handle:
            handle.write("@header1\nACGT\n+\nIIII\n")
        paths.append(path)
    return {
        "dirname": temp_dir,
        "r1_path": paths[0],
        "r2_path": paths[1],
        "r1_path2": paths[2],
        "r2_path2": paths[3],
    }
@pytest.fixture
def airr_files(temp_dir):
    """Create temporary AIRR TSV files plus an output directory."""
    # both samples share the same gene/locus assignments
    shared_cols = {
        "v_gene": ["IGHV1-1*01", "IGHV1-2*01", "IGHV2-1*01"],
        "j_gene": ["IGHJ1*01", "IGHJ2*01", "IGHJ1*01"],
        "locus": ["IGH", "IGH", "IGH"],
    }
    sample1_path = os.path.join(temp_dir, "sample1.tsv")
    sample2_path = os.path.join(temp_dir, "sample2.tsv")
    pl.DataFrame(
        {
            "sequence_id": ["seq1", "seq2", "seq3"],
            "sequence": ["ACGT", "ACGT", "TGCA"],
            "umi": ["UMI1", "UMI2", "UMI3"],
            **shared_cols,
        }
    ).write_csv(sample1_path, separator="\t")
    pl.DataFrame(
        {
            "sequence_id": ["seq4", "seq5", "seq6"],
            "sequence": ["ACGT", "GTCA", "TGCA"],
            "umi": ["UMI4", "UMI5", "UMI6"],
            **shared_cols,
        }
    ).write_csv(sample2_path, separator="\t")
    # Create output directory
    output_dir = os.path.join(temp_dir, "output")
    os.makedirs(output_dir, exist_ok=True)
    return {
        "dirname": temp_dir,
        "sample1_path": sample1_path,
        "sample2_path": sample2_path,
        "output_dir": output_dir,
    }
# Tests for FASTQFile
def test_fastq_file_init(fastq_file):
    """Test initialization of FASTQFile."""
    obj = preprocessing.FASTQFile(fastq_file)
    abs_path = os.path.abspath(fastq_file)
    assert obj.file == fastq_file
    assert obj.path == abs_path
    assert obj.basename == "test_file.fastq"
    assert obj.dir == os.path.dirname(abs_path)
    assert obj.filename == "test_file"
# Tests for IlluminaFile
def test_illumina_file_init(illumina_fastq):
    """Test initialization of IlluminaFile."""
    parsed = preprocessing.IlluminaFile(illumina_fastq)
    assert parsed.schema == "illumina"
    assert parsed.filename == "Sample1_S1_L001_R1_001"
def test_illumina_file_properties(illumina_fastq):
    """Test properties of IlluminaFile."""
    parsed = preprocessing.IlluminaFile(illumina_fastq)
    expected = {
        "name": "Sample1",
        "sample": "S1",
        "lane": "L001",
        "read": "R1",
        "number": "001",
    }
    for attr, value in expected.items():
        assert getattr(parsed, attr) == value
# Tests for ElementFile
def test_element_file_init(element_fastq):
    """Test initialization of ElementFile."""
    parsed = preprocessing.ElementFile(element_fastq)
    assert parsed.schema == "element"
    assert parsed.filename == "Sample1_R1"
def test_element_file_properties(element_fastq):
    """Test properties of ElementFile."""
    parsed = preprocessing.ElementFile(element_fastq)
    assert parsed.name == "Sample1"
    assert parsed.read == "R1"
    # Element-style filenames carry no lane, sample, or number fields,
    # so these properties are empty strings.
    for attr in ("lane", "sample", "number"):
        assert getattr(parsed, attr) == ""
# Tests for MergeGroup
def test_merge_group_init(illumina_pair):
    """Test initialization of MergeGroup."""
    pair = [preprocessing.IlluminaFile(p) for p in illumina_pair]
    group = preprocessing.MergeGroup("Sample1", pair)
    assert group.name == "Sample1"
    assert len(group.files) == 2
    assert group.merged_file is None
def test_merge_group_by_lane(illumina_pair):
    """Test _group_by_lane method of MergeGroup."""
    pair = [preprocessing.IlluminaFile(p) for p in illumina_pair]
    group = preprocessing.MergeGroup("Sample1", pair)
    lanes = group._group_by_lane()
    # Both reads come from lane L001, so they land in a single group
    # that contains both the R1 and R2 files.
    assert len(lanes) == 1
    assert len(lanes[0]) == 2
    for fastq in pair:
        assert fastq in lanes[0]
def test_merge_group_merge(illumina_pair, merged_dir, log_dir):
    """Test merge method of MergeGroup."""
    pair = [preprocessing.IlluminaFile(p) for p in illumina_pair]
    group = preprocessing.MergeGroup("Sample1", pair)
    merged_path = group.merge(
        merged_directory=merged_dir,
        log_directory=log_dir,
        format="fastq",
        algo="fastp",
        verbose=True,
    )
    # merge() returns the output path and records it on the group
    expected = os.path.join(merged_dir, "Sample1.fastq")
    assert merged_path == expected
    assert group.merged_file == expected
# Tests for group_fastq_pairs
def test_group_fastq_pairs(multi_sample_fastqs):
    """Test grouping of FASTQ pairs."""
    keys = ("r1_path", "r2_path", "r1_path2", "r2_path2")
    files = [preprocessing.IlluminaFile(multi_sample_fastqs[k]) for k in keys]
    groups = preprocessing.group_fastq_pairs(files)
    # Two samples -> two merge groups, each holding an R1/R2 pair.
    assert len(groups) == 2
    by_name = {g.name: g for g in groups}
    for sample in ("Sample1", "Sample2"):
        assert sample in by_name
        assert len(by_name[sample].files) == 2
def test_merge_fastqs(multi_sample_fastqs, merged_dir, log_dir):
    """Test merge_fastqs function."""
    merged = preprocessing.merge_fastqs(
        files=multi_sample_fastqs["dirname"],
        output_directory=merged_dir,
        log_directory=log_dir,
        schema="illumina",
        verbose=True,
    )
    # At least one merged file is produced, located in the output directory.
    assert len(merged) > 0
    assert os.path.dirname(merged[0]) == merged_dir
def test_merge_fastqs_fastp(illumina_pair, merged_dir, log_dir):
    """Test merge_fastqs_fastp function."""
    forward_path, reverse_path = illumina_pair
    target = os.path.join(merged_dir, "merged.fastq")
    returned = preprocessing.merge_fastqs_fastp(
        forward=forward_path,
        reverse=reverse_path,
        merged=target,
        log_directory=log_dir,
        name="Sample1",
    )
    # the function returns the path of the merged output file
    assert returned == target
@mock.patch("abutils.tools.preprocessing.sp.Popen")
def test_merge_fastqs_vsearch(mock_popen, illumina_pair, merged_dir):
    """Test merge_fastqs_vsearch function.

    Patches ``sp.Popen`` so the external vsearch binary is never actually
    invoked; the test verifies only the return value and the assembled
    command line.
    """
    r1_path, r2_path = illumina_pair
    # Need to mock subprocess.Popen for external commands
    mock_process = mock.MagicMock()
    mock_process.communicate.return_value = (b"", b"")
    mock_process.returncode = 0
    mock_popen.return_value = mock_process
    merged_file = os.path.join(merged_dir, "merged.fasta")
    result = preprocessing.merge_fastqs_vsearch(
        forward=r1_path,
        reverse=r2_path,
        merged_file=merged_file,
        output_format="fasta",
    )
    # Check result: the merged file path is returned unchanged
    assert result == merged_file
    # Check that sp.Popen was called with a vsearch command.
    # call_args[0][0] is the first positional argument to Popen (the command).
    cmd = mock_popen.call_args[0][0]
    assert "vsearch" in cmd
    assert "--fastq_mergepairs" in cmd
    assert r1_path in cmd
    assert r2_path in cmd
    assert merged_file in cmd
# Test helper functions
def test_deduplicate_sequences():
    """Test _deduplicate_sequences function."""
    frame = pl.DataFrame(
        {"sequence_id": ["seq1", "seq2", "seq3"], "sequence": ["ACGT", "ACGT", "TGCA"]}
    )
    # Without read numbers: duplicates collapse to unique sequences.
    deduped = preprocessing._deduplicate_sequences(frame, ["sequence"], False)
    assert deduped.height == 2
    # With read numbers: duplicates collapse and per-sequence counts are kept.
    counted = preprocessing._deduplicate_sequences(frame, ["sequence"], True)
    assert counted.height == 2
    assert "count" in counted.columns
    # sorted by sequence, so ACGT (2 reads) precedes TGCA (1 read)
    assert counted.select("count").to_series().to_list() == [2, 1]
# # Test for deduplicate and reduction functions
# @mock.patch("abutils.tools.preprocessing.to_fasta")
# def test_deduplicate(mock_to_fasta, airr_files, temp_dir):
# """Test deduplicate function."""
# # Mock to_fasta to avoid actual file writing
# mock_to_fasta.return_value = None
# output_dir = os.path.join(temp_dir, "dedup")
# preprocessing.deduplicate(
# # project_folder=airr_files["sample1_path"], # Test file path input
# project_folder=airr_files["dirname"], # Test file path input
# output="dedup",
# output_format="fasta",
# debug=True,
# )
# # Check that to_fasta was called with deduped data
# assert mock_to_fasta.called
# # The first argument should be a list of tuples with sequence_id and sequence
# # We can't assert exact values because of potential mock issues, but we can check structure
# args = mock_to_fasta.call_args[0]
# assert isinstance(args[0], list)
@mock.patch("abutils.tools.cluster.cluster")
@mock.patch("abutils.tools.alignment.make_consensus")
def test_process_chains(mock_make_consensus, mock_cluster):
    """Test _process_chains function.

    Patches clustering and consensus generation so the test exercises only
    the consentroid-selection logic of ``_process_chains``.
    """
    # Setup mock cluster: a single cluster of three identical sequences
    # whose centroid is seq1/ACGT.
    mock_cluster_obj = mock.MagicMock()
    mock_cluster_obj.size = 3
    mock_cluster_obj.centroid = mock.MagicMock()
    mock_cluster_obj.centroid.id = "seq1"
    mock_cluster_obj.centroid.sequence = "ACGT"
    mock_cluster_obj.sequences = [
        Sequence("ACGT", id="seq1"),
        Sequence("ACGT", id="seq2"),
        Sequence("ACGT", id="seq3"),
    ]
    mock_cluster.return_value = [mock_cluster_obj]
    mock_consensus = Sequence("ACGT", id="consensus")
    mock_make_consensus.return_value = mock_consensus
    # Create test dataframe: all rows share one vj_bin
    df = pl.DataFrame(
        {
            "sequence_id": ["seq1", "seq2", "seq3"],
            "sequence": ["ACGT", "ACGT", "ACGT"],
            "v_gene": ["IGHV1-1*01", "IGHV1-1*01", "IGHV1-1*01"],
            "j_gene": ["IGHJ1*01", "IGHJ1*01", "IGHJ1*01"],
            "vj_bin": [
                "IGHV1-1*01_IGHJ1*01",
                "IGHV1-1*01_IGHJ1*01",
                "IGHV1-1*01_IGHJ1*01",
            ],
        }
    )
    # Test centroid mode: output id is "<centroid_id><sep><cluster_size>"
    result = preprocessing._process_chains(
        "test",
        df,
        min_cluster_size=2,
        clustering_threshold=0.9,
        consentroid="centroid",
        cluster_sizes_separator="|",
        keep_cluster_sizes=True,
        output_format="fasta",
        debug=True,
    )
    assert len(result) == 1
    assert result[0].id == "seq1|3"
    assert result[0].sequence == "ACGT"
    # Test consensus mode: the id comes from the mocked consensus sequence
    result = preprocessing._process_chains(
        "test",
        df,
        min_cluster_size=2,
        clustering_threshold=0.9,
        consentroid="consensus",
        cluster_sizes_separator="|",
        keep_cluster_sizes=True,
        output_format="fasta",
        debug=True,
    )
    assert len(result) == 1
    assert result[0].id == "consensus|3"
    assert result[0].sequence == "ACGT"
def test_process_chain_group():
    """Test _process_chain_group function.

    Verifies that ``_process_chain_group`` builds the ``vj_bin`` column
    (with and without the UMI component) before delegating to
    ``_process_chains``.
    """
    # Create test dataframe
    df = pl.DataFrame(
        {
            "sequence_id": ["seq1", "seq2", "seq3"],
            "sequence": ["ACGT", "ACGT", "TGCA"],
            "v_gene": ["IGHV1-1*01", "IGHV1-1*01", "IGHV2-1*01"],
            "j_gene": ["IGHJ1*01", "IGHJ1*01", "IGHJ2*01"],
            "locus": ["IGH", "IGH", "IGH"],
        }
    )
    with mock.patch(
        "abutils.tools.preprocessing._process_chains"
    ) as mock_process_chains:
        mock_process_chains.return_value = [Sequence("ACGT", id="seq1")]
        # Test without UMI
        result = preprocessing._process_chain_group(
            df,
            chain_type="Heavy",
            umi=False,
            min_cluster_size=2,
            clustering_threshold=0.9,
            consentroid="centroid",
            cluster_sizes_separator="|",
            keep_cluster_sizes=False,
            output_format="fasta",
            debug=True,
        )
        assert result == [Sequence("ACGT", id="seq1")]
        # Check that vj_bin was created correctly.
        # call_args[0][1] is the second positional argument passed to
        # _process_chains: the dataframe augmented with vj_bin.
        assert "vj_bin" in mock_process_chains.call_args[0][1].columns
        vj_bins = (
            mock_process_chains.call_args[0][1].select("vj_bin").to_series().to_list()
        )
        assert "IGHV1-1*01_IGHJ1*01" in vj_bins
        assert "IGHV2-1*01_IGHJ2*01" in vj_bins
        # Test with UMI: the umi value becomes part of each vj_bin
        df_with_umi = df.with_columns(pl.lit("UMI1").alias("umi"))
        result = preprocessing._process_chain_group(
            df_with_umi,
            chain_type="Heavy",
            umi=True,
            min_cluster_size=2,
            clustering_threshold=0.9,
            consentroid="centroid",
            cluster_sizes_separator="|",
            keep_cluster_sizes=False,
            output_format="fasta",
            debug=True,
        )
        assert result == [Sequence("ACGT", id="seq1")]
        # Check that vj_bin includes UMI
        vj_bins = (
            mock_process_chains.call_args[0][1].select("vj_bin").to_series().to_list()
        )
        assert any("UMI1" in bin_name for bin_name in vj_bins)
# @mock.patch("abutils.tools.preprocessing.to_fasta")
# @mock.patch("abutils.tools.preprocessing._process_chain_group")
# def test_reduction(mock_process_chain_group, mock_to_fasta, airr_files, temp_dir):
# """Test reduction function."""
# mock_process_chain_group.return_value = [
# Sequence("ACGT", id="seq1"),
# Sequence("TGCA", id="seq2"),
# ]
# mock_to_fasta.return_value = None
# output_dir = os.path.join(temp_dir, "reduced")
# preprocessing.reduction(
# project_folder=airr_files["sample1_path"], # Test file path input
# output="reduced",
# output_format="fasta",
# debug=True,
# )
# # Check that _process_chain_group and to_fasta were called
# assert mock_process_chain_group.called
# assert mock_to_fasta.called
#!/usr/bin/env python
# filename: pipeline.py
#
# Copyright (c) 2015 Bryan Briney
# License: The MIT license (http://opensource.org/licenses/MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
import glob
import os
import re
from typing import Iterable, Optional, Union
from natsort import natsorted
# =======================
# General file I/O
# =======================
def make_dir(directory: str) -> None:
    """
    Makes a directory, if it doesn't already exist.

    Parameters
    ----------
    directory : str
        Path to a directory.

    Raises
    ------
    FileExistsError
        If `directory` exists and is not a directory.
    """
    # exist_ok avoids the check-then-create race of the previous
    # os.path.exists() guard: another process could create the directory
    # between the check and the makedirs call.
    os.makedirs(os.path.abspath(directory), exist_ok=True)
def list_files(
    directory: str,
    extension: Union[str, Iterable, None] = None,
    recursive: bool = False,
    match: Optional[str] = None,
    ignore_dot_files: bool = True,
) -> Iterable[str]:
    """
    Lists files in a given directory.

    Parameters
    ----------
    directory : str
        Path to a directory. If a file path is passed instead, the returned list of files will contain
        only that file path.
    extension : str
        If supplied, only files that end with the specified extension(s) will be returned. Can be either
        a string or a list of strings. Extension evaluation is case-insensitive and can match complex
        extensions (e.g. '.fastq.gz'). Default is ``None``, which returns all files in the directory,
        regardless of extension.
    recursive : bool, default=False
        If ``True``, the directory will be searched recursively, and all files in all subdirectories will be returned.
    match : str, optional
        If supplied, only files whose basename matches the specified pattern will be returned.
        Regular expressions are supported.
    ignore_dot_files : bool, default=True
        If ``True``, dot files (hidden files) will be ignored.

    Returns
    -------
    Iterable[str]

    Raises
    ------
    ValueError
        If `directory` does not exist.
    """
    directory = os.path.abspath(directory)
    if not os.path.exists(directory):
        raise ValueError(f"Directory {directory} does not exist.")
    if os.path.isdir(directory):
        if recursive:
            files = [
                os.path.join(root, f)
                for root, _dirs, _files in os.walk(directory)
                for f in _files
            ]
        else:
            files = natsorted(glob.glob(directory + "/*"))
    else:
        # a file path was supplied; return it as a one-element list
        files = [directory]
    if extension is not None:
        if isinstance(extension, str):
            extension = [extension]
        # A single case-insensitive comparison covers exact-, upper- and
        # lower-case extensions (the previous three-way check was redundant:
        # both case-specific tests were subsets of the lowercased one).
        suffixes = tuple(e.lower() for e in extension)
        files = [f for f in files if f.lower().endswith(suffixes)]
    if ignore_dot_files:
        files = [f for f in files if not os.path.basename(f).startswith(".")]
    if match is not None:
        files = [f for f in files if re.match(match, os.path.basename(f)) is not None]
    return files
def rename_file(file: str, new_name: str) -> None:
    """
    Renames a file.

    Thin wrapper around ``os.rename``, so the usual caveats apply: the
    destination must be on the same filesystem, and any OSError raised by
    ``os.rename`` (missing source, cross-device move, etc.) propagates.

    Parameters
    ----------
    file : str
        Path to the file to be renamed.
    new_name : str
        New name (path) for the file.
    """
    os.rename(file, new_name)
def delete_files(files: Union[str, Iterable]) -> None:
    """
    Deletes one or more files.

    Paths that do not exist are skipped silently.

    Parameters
    ----------
    files : Union[str, Iterable]
        Path to a file or an iterable of file paths.
    """
    targets = [files] if isinstance(files, str) else files
    for path in targets:
        if os.path.exists(path):
            os.remove(path)
# from __future__ import absolute_import, division, print_function, unicode_literals
# import glob
# import os
# import sys
# from typing import Iterable, Optional, Union
# from . import log
# # for backward compatibility
# def list_files(*args, **kwargs):
# from ..io import list_files
# return list_files(*args, **kwargs)
# def make_dir(*args, **kwargs):
# from ..io import make_dir
# return make_dir(*args, **kwargs)
# def initialize(log_file, project_dir=None, debug=False):
# """
# Initializes an AbTools pipeline.
# Initialization includes printing the AbTools splash, setting up logging,
# creating the project directory, and logging both the project directory
# and the log location.
# Parameters
# ----------
# log_file : str
# Path to the log file. Required.
# project_dir : str
# Path to the project directory. If not provided,
# the project directory won't be created and the location won't be logged.
# debug : bool
# If ``True``, the logging level will be set to ``logging.DEBUG``.
# Returns
# -------
# logger
# A logger instance.
# """
# print_splash()
# log.setup_logging(log_file, print_log_location=False, debug=debug)
# logger = log.get_logger("pipeline")
# if project_dir is not None:
# make_dir(os.path.normpath(project_dir))
# logger.info("PROJECT DIRECTORY: {}".format(project_dir))
# logger.info("")
# logger.info("LOG LOCATION: {}".format(log_file))
# print("")
# return logger
# # def make_dir(directory: str) -> None:
# # """
# # Makes a directory, if it doesn't already exist.
# # Parameters
# # ----------
# # directory : str
# # Path to a directory.
# # """
# # if not os.path.exists(directory):
# # os.makedirs(os.path.abspath(directory))
# # def list_files(
# # directory: str, extension: Union[str, Iterable, None] = None
# # ) -> Iterable[str]:
# # """
# # Lists files in a given directory.
# # Parameters
# # ----------
# # directory : str
# # Path to a directory.
# # extension : str
# # If supplied, only files that end with the specificied extension(s) will be returned. Can be either
# # a string or a list of strings. Extension evaluation is case-insensitive and can match complex
# # extensions (e.g. '.fastq.gz'). Default is ``None``, which returns all files in the directory,
# # regardless of extension.
# # Returns
# # -------
# # Iterable[str]
# # """
# # if os.path.isdir(directory):
# # expanded_dir = os.path.expanduser(directory)
# # files = sorted(glob.glob(expanded_dir + "/*"))
# # else:
# # files = [
# # directory,
# # ]
# # if extension is not None:
# # if isinstance(extension, str):
# # extension = [
# # extension,
# # ]
# # files = [
# # f
# # for f in files
# # if any(
# # [
# # any([f.lower().endswith(e.lower()) for e in extension]),
# # any([f.endswith(e.upper()) for e in extension]),
# # any([f.endswith(e.lower()) for e in extension]),
# # ]
# # )
# # ]
# # return files
# def print_splash():
# splash = """
# _ _ _ _ _ _ _ ____ _ _ _
# / \ | |__ | | | | |_(_) |___ | _ \(_)_ __ ___| (_)_ __ ___
# / _ \ | '_ \| | | | __| | / __| | |_) | | '_ \ / _ \ | | '_ \ / _ \\
# / ___ \| |_) | |_| | |_| | \__ \ | __/| | |_) | __/ | | | | | __/
# /_/ \_\_.__/ \___/ \__|_|_|___/ |_| |_| .__/ \___|_|_|_| |_|\___|
# |_|
# """
# print("")
# print(splash)
# print(
# "(c) 2023 Bryan Briney\nDistributed under the MIT License (http://opensource.org/licenses/MIT)"
# )
# print("")
+5
-8
Metadata-Version: 2.4
Name: abutils
Version: 0.5.0
Version: 0.5.1
Summary: Utilities for analysis of adaptive immune receptor repertoire (AIRR) data

@@ -13,8 +13,8 @@ Home-page: https://github.com/briney/abutils

Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Requires-Python: >=3.8
Requires-Python: >=3.10
Description-Content-Type: text/markdown

@@ -25,5 +25,3 @@ License-File: LICENSE

Requires-Dist: biopython>=1.78
Requires-Dist: celery
Requires-Dist: dnachisel
Requires-Dist: ete3
Requires-Dist: fastcluster

@@ -42,3 +40,2 @@ Requires-Dist: matplotlib

Requires-Dist: pyfastx
Requires-Dist: pymongo
Requires-Dist: pytest

@@ -106,3 +103,3 @@ Requires-Dist: python-circos

### requirements
**python 3.8+**
**python 3.10+**

@@ -109,0 +106,0 @@ abstar

abstar>=0.6.3
baltic
biopython>=1.78
celery
dnachisel
ete3
fastcluster

@@ -20,3 +18,2 @@ matplotlib

pyfastx
pymongo
pytest

@@ -23,0 +20,0 @@ python-circos

@@ -156,2 +156,3 @@ LICENSE

abutils/tests/test_phylo.py
abutils/tests/test_preprocessing.py
abutils/tests/test_sequence.py

@@ -184,4 +185,4 @@ abutils/tests/test_utilities.py

abutils/utils/mongodb.py
abutils/utils/path.py
abutils/utils/phylogeny.py
abutils/utils/pipeline.py
abutils/utils/progbar.py

@@ -188,0 +189,0 @@ abutils/utils/s3.py

@@ -14,26 +14,25 @@ # from .core import *

# from .tools import log
from .tools.phylo import Phylogeny
from .utils import (
alignment,
cluster,
# color,
decorators,
jobs,
# log,
phylogeny,
pipeline,
progbar,
s3,
seqio,
utilities,
)
from .utils import alignment as aln
from .utils import pipeline as path
# from .utils.alignment import SSWAlignment, NWAlignment
from .utils.alignment import GlobalAlignment, LocalAlignment, SemiGlobalAlignment
from .utils.jobs import *
from .utils.progbar import progress_bar
# from .tools.phylo import Phylogeny
# from .utils import (
# alignment,
# cluster,
# # color,
# decorators,
# jobs,
# path,
# # log,
# phylogeny,
# progbar,
# s3,
# seqio,
# utilities,
# )
# from .utils import alignment as aln
# from .utils import path as path
# # from .utils.alignment import SSWAlignment, NWAlignment
# from .tools.alignment import GlobalAlignment, LocalAlignment, SemiGlobalAlignment
# from .utils.jobs import *
# from .utils.progbar import progress_bar
from .version import __version__
BINARY_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "binaries"))

@@ -73,5 +73,4 @@ #!/usr/bin/env python

)
from .utils.path import delete_files, list_files, make_dir, rename_file
# from .utils.convert import abi_to_fasta
# =======================

@@ -78,0 +77,0 @@ # General file I/O

@@ -25,3 +25,5 @@ #!/usr/bin/env python

from .preprocess.preprocess import deduplicate
from .preprocess.preprocess import reduction
# from .preprocess.preprocess import deduplicate
# from .preprocess.preprocess import reduction
from .tools.preprocessing import *

@@ -1,1 +0,1 @@

__all__ = ['preprocess']
# __all__ = ['preprocess']

@@ -26,571 +26,566 @@ #!/usr/bin/env python

import os
import time
import datetime
from typing import Optional
from tqdm.auto import tqdm
# import os
# import time
# import datetime
# from typing import Optional
# from tqdm.auto import tqdm
import polars as pl
from .. import Sequence
from ..io import list_files, make_dir, to_fasta, from_polars, to_polars
from ..tools import cluster, alignment
# import polars as pl
# from .. import Sequence
# from ..io import list_files, make_dir, to_fasta, from_polars, to_polars
# from ..tools import cluster, alignment
__all__ = ['deduplicate', 'reduction']
# __all__ = ['deduplicate', 'reduction']
# #########################################
# # #
# # Helper functions #
# # #
# #########################################
#########################################
# #
# Helper functions #
# #
#########################################
# def _deduplicate_sequences(
# df: pl.DataFrame,
# group_by_cols: list,
# keep_read_numbers: bool
# ):
def _deduplicate_sequences(
    df: pl.DataFrame,
    group_by_cols: list,
    keep_read_numbers: bool,
):
    """
    Collapse duplicate sequences in `df`.

    Parameters
    ----------
    df : pl.DataFrame
        AIRR-style dataframe with at least "sequence" and "sequence_id" columns.
    group_by_cols : list
        Columns that define a duplicate (typically ``["sequence"]``).
    keep_read_numbers : bool
        If ``True``, add a "count" column with the number of reads per unique
        sequence, keeping the first "sequence_id" of each group.

    Returns
    -------
    pl.DataFrame
        Deduplicated dataframe, sorted by "sequence".
    """
    if not keep_read_numbers:
        return df.unique(subset=group_by_cols).sort("sequence")
    return (
        df.group_by(group_by_cols)
        .agg(
            # pl.len() replaces the deprecated pl.count(); the alias keeps
            # the output column named "count" for downstream consumers.
            pl.len().alias("count"),
            pl.col("sequence_id").first(),
        )
        .sort("sequence")
    )
# def _write_fasta(
# df: pl.DataFrame,
# file_path: str,
# keep_read_numbers: bool,
# read_number_separator: str
# ):
def _write_fasta(
    df: pl.DataFrame,
    file_path: str,
    keep_read_numbers: bool,
    read_number_separator: str,
):
    """Write the sequences in `df` to a FASTA file at `file_path`."""
    if keep_read_numbers:
        # append the read count to each sequence id, e.g. "seq1|3"
        annotated_id = (
            pl.col("sequence_id")
            + read_number_separator
            + pl.col("count").cast(pl.Utf8)
        )
        df = df.with_columns(annotated_id.alias("sequence_id"))
    records = list(zip(df["sequence_id"].to_list(), df["sequence"].to_list()))
    to_fasta(records, file_path)
    print(f"Output written to FASTA file: {file_path}\n")
# def _process_chains(
# step: str,
# chains: pl.DataFrame,
# min_cluster_size: int,
# clustering_threshold: float,
# consentroid: str,
# cluster_sizes_separator: str,
# keep_cluster_sizes: bool,
# output_format: str,
# debug: bool
# ):
# """Internal function to generate a consentroid from a bin of sequences"""
def _process_chains(
    step: str,
    chains: pl.DataFrame,
    min_cluster_size: int,
    clustering_threshold: float,
    consentroid: str,
    cluster_sizes_separator: str,
    keep_cluster_sizes: bool,
    output_format: str,
    debug: bool
):
    """Internal function to generate a consentroid from a bin of sequences.

    For each unique ``vj_bin`` in `chains`, the sequences are clustered at
    `clustering_threshold`; for every cluster of at least `min_cluster_size`
    sequences a representative is emitted, either the cluster centroid
    (``consentroid == 'centroid'``) or a MAFFT consensus
    (``consentroid == 'consensus'``). If `keep_cluster_sizes` is True, the
    cluster size is appended to the representative's id using
    `cluster_sizes_separator`.

    Returns a list of ``Sequence`` objects when ``output_format == 'fasta'``,
    a polars DataFrame when ``output_format == 'airr'``, and (implicitly)
    ``None`` for any other value.
    """
    consentroids = []
    for vj in tqdm(chains['vj_bin'].unique().to_list(), desc=f'{step} chains processing'):
        group = chains.filter(pl.col('vj_bin') == vj)
        # skip bins that cannot possibly yield a large-enough cluster
        if group.height < min_cluster_size:
            continue
        group_sequences = from_polars(group)
        clusters = cluster.cluster(group_sequences, threshold=clustering_threshold)
        for c in clusters:
            if c.size < min_cluster_size:
                continue
            if consentroid == 'centroid':
                seq_id = f"{c.centroid.id}{cluster_sizes_separator}{c.size}" if keep_cluster_sizes else c.centroid.id
                consentroids.append(Sequence(c.centroid.sequence, id=seq_id))
            elif consentroid == 'consensus':
                _consensus = alignment.make_consensus([s for s in c.sequences], algo='mafft', alignment_kwargs={'mafft_bin': 'mafft', 'threads': 1}, debug=debug)
                seq_id = f"{_consensus.id}{cluster_sizes_separator}{c.size}" if keep_cluster_sizes else _consensus.id
                consentroids.append(Sequence(_consensus.sequence, id=seq_id))
        if debug:
            print(f"\tProcessed {len(clusters)} clusters for {vj}, corresponding to {len(group_sequences)} sequences")
    if debug:
        print(f"\tTotal concentroids generated: {len(consentroids)}\n")
    if output_format == 'fasta':
        return consentroids
    elif output_format == 'airr':
        df = to_polars(consentroids)
        if keep_cluster_sizes and len(df) != 0:
            # Recover the cluster size that was appended to the sequence id,
            # putting it back into a separate "count" column.
            df = df.with_columns(
                pl.col("sequence_id")
                .str.split_exact(cluster_sizes_separator, 1)  # Split into a struct with two fields
                .alias("split_id")
            ).with_columns(
                pl.col("split_id").struct.field("field_0").alias("sequence_id"),  # Extract the first field
                pl.col("split_id").struct.field("field_1").cast(pl.Utf8).alias("count")  # Extract the second field
            ).drop("split_id")  # Drop the intermediate "split_id" column
        return df
# if debug:
# print(f"\tTotal concentroids generated: {len(consentroids)}\n")
# if output_format == 'fasta':
# return consentroids
def _process_chain_group(
    chains: pl.DataFrame,
    chain_type: str,
    umi: bool,
    min_cluster_size: int,
    clustering_threshold: float,
    consentroid: str,
    cluster_sizes_separator: str,
    keep_cluster_sizes: bool,
    output_format: str,
    debug: bool
):
    """
    Bin chains by V/J gene (and optionally UMI) and delegate clustering.

    A ``vj_bin`` column is added to *chains* by concatenating the
    ``v_gene`` and ``j_gene`` columns (plus ``umi`` when *umi* is True),
    then the annotated frame is handed to ``_process_chains`` for per-bin
    clustering and consentroid generation.

    Parameters:
        chains (pl.DataFrame): Chain records; must contain 'v_gene' and 'j_gene'
            columns (and 'umi' when *umi* is True).
        chain_type (str): Label used in progress/debug messages (e.g. 'Heavy').
        umi (bool): If True, include the UMI in the binning key.
        min_cluster_size (int): Minimum cluster size to keep.
        clustering_threshold (float): Identity threshold for clustering.
        consentroid (str): Cluster representative method, 'centroid' or 'consensus'.
        cluster_sizes_separator (str): Separator used when appending cluster sizes
            to sequence names.
        keep_cluster_sizes (bool): If True, cluster sizes are appended to names.
        output_format (str): Either 'fasta' or 'airr'.
        debug (bool): If True, print debug information.

    Returns: whatever ``_process_chains`` returns — a list of Sequence objects
    for 'fasta' output, a polars DataFrame for 'airr' output.
    """
    # Binning key: sequences are only clustered against others sharing
    # the same V gene / J gene (and UMI, when present).
    if umi:
        chains = chains.with_columns((pl.col("v_gene") + "_" + pl.col("j_gene") + "_" + pl.col("umi")).alias("vj_bin"))
    else:
        chains = chains.with_columns((pl.col("v_gene") + "_" + pl.col("j_gene")).alias("vj_bin"))
    return _process_chains(chain_type, chains, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
# def _process_chain_group(
# chains: pl.DataFrame,
# chain_type: str,
# umi: bool,
# min_cluster_size: int,
# clustering_threshold: float,
# consentroid: str,
# cluster_sizes_separator: str,
# keep_cluster_sizes: bool,
# output_format: str,
# debug: bool
# ):
# # Generate vj_bin column
# if umi:
# chains = chains.with_columns((pl.col("v_gene") + "_" + pl.col("j_gene") + "_" + pl.col("umi")).alias("vj_bin"))
# else:
# chains = chains.with_columns((pl.col("v_gene") + "_" + pl.col("j_gene")).alias("vj_bin"))
# return _process_chains(chain_type, chains, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
#########################################
# #
# Main functions #
# #
#########################################
# #########################################
# # #
# # Main functions #
# # #
# #########################################
def _write_deduplicated_airr(new_df, df_unique, tsv_file, keep_read_numbers):
    """
    Write the rows of *new_df* (a LazyFrame over the original AIRR table(s))
    whose sequence_id survived deduplication to *tsv_file*.

    When *keep_read_numbers* is True, the per-sequence 'count' column from
    *df_unique* is joined in and exported as a 'duplicates' column.
    """
    # LazyFrame keeps the semi-join streaming-friendly for large tables
    df_unique_lazy = pl.LazyFrame(df_unique)
    # Keep only rows whose sequence_id appears in the deduplicated set
    filtered_df = new_df.join(
        df_unique_lazy.select(["sequence_id"]),
        on="sequence_id",
        how="semi"
    ).select(new_df.columns)
    # Add the 'count' column to the filtered DataFrame (only for matching rows)
    if keep_read_numbers:
        filtered_df = filtered_df.join(
            df_unique_lazy.select(["sequence_id", "count"]),
            on="sequence_id",
            how="left"
        ).with_columns(
            pl.col("count").fill_null(0)  # Replace null values with 0 if necessary
        ).rename(
            {"count": "duplicates"}
        )
    filtered_df.collect().write_csv(tsv_file, separator='\t')
    print(f"Output written to TSV file: {tsv_file}\n")


def deduplicate(project_folder: str,
                output: Optional[str] = None,
                output_format: str = 'fasta',
                pool: bool = True,
                umi: bool = False,
                keep_read_numbers: bool = False,
                read_number_separator: str = '|',
                large_files: bool = False,
                debug: bool = False
                ) -> None:
    '''
    A polyvalent tool for deduplication of assigned reads. This function takes as input the AIRR-compliant
    tables and is specifically designed to handle extremely large files with a minimal footprint.

    Parameters:
        project_folder (str): Path to the project folder containing AIRR-compliant tables.
        output (str, optional): Subdirectory for output files. Created if non-existent. Defaults to None.
        output_format (str): Either "fasta" or "airr". Default is "fasta".
        pool (bool, optional): If True, pool all samples together. Defaults to True.
        umi (bool, optional): If True, use UMI for deduplication. Defaults to False.
        keep_read_numbers (bool, optional): If True, read numbers will be added to sequence names. Defaults to False.
        read_number_separator (str, optional): Separator between sequence name and read number. Defaults to '|'.
        large_files (bool, optional): If True, optimize for large files (>100GB). Defaults to False.
        debug (bool, optional): If True, print debug information. Defaults to False.

    Returns: None. A FASTA- or TSV- file is written to disk.
    '''
    start = time.time()
    # Assert that output_format is valid
    if output_format not in ['fasta', 'airr']:
        raise ValueError("Invalid output_format. Must be 'fasta' or 'airr'.")
    # Preparing input and output file(s) / folder(s)
    if os.path.isfile(project_folder):
        project_folder = os.path.dirname(project_folder)
    if debug:
        print(f"Project folder is set to: {project_folder}")
    files = sorted([f for f in list_files(project_folder, extension='tsv', recursive=True, ) if not 'tmp' in f])
    if len(files) == 0:
        # Bail out early: continuing with no files would later attempt
        # pl.concat([]) on the empty pool, which raises in polars.
        print('No files found. Exiting.')
        return
    elif len(files) == 1:
        # A single file cannot be pooled with anything else
        pool = False
    if output:
        project_folder = os.path.join(project_folder, output)
        make_dir(project_folder)
    # Running main deduplication process
    total_sequences = 0
    pooled = []
    for file in files:
        sample = os.path.basename(file).split('.')[0]
        print(f"Processing {sample}")
        print("-"*(len(sample)+11)+'\n')
        # Loading files: only the columns needed for deduplication are read
        columns = ['sequence_id', 'sequence', 'umi'] if umi else ['sequence_id', 'sequence']
        try:
            df = pl.read_csv(file, columns=columns, separator='\t', null_values="None", low_memory=large_files)
        except Exception as e:
            # A single unreadable file should not abort the whole run
            print(f"Error reading file {file}: {e}")
            continue
        print(f"Loaded {df.height:,} annotated sequences")
        total_sequences += df.shape[0]
        # Deduplicating on sequence identity (plus UMI when available)
        if umi:
            if debug:
                print(f"Deduplicating with UMI ({df.unique(subset=['umi']).height} unique UMIs)")
            df_unique = _deduplicate_sequences(df, ["sequence", "umi"], keep_read_numbers)
        else:
            df_unique = _deduplicate_sequences(df, ["sequence"], keep_read_numbers)
        msg = f"Found {df_unique.height:,} unique sequences"
        if pool:
            msg += " added to the pool\n"
        print(msg)
        if pool:
            pooled.append(df_unique)
        else:
            # Saving deduplicated single output to FASTA file
            if output_format == "fasta":
                fasta_file = os.path.join(project_folder, sample)+'.fasta'
                _write_fasta(df_unique, fasta_file, keep_read_numbers, read_number_separator)
            # Saving deduplicated single output to TSV file
            elif output_format == "airr":
                tsv_file = os.path.join(project_folder, sample)+'.tsv'
                # Scan the complete original AIRR table lazily
                new_df = pl.scan_csv(file, separator='\t', low_memory=large_files)
                _write_deduplicated_airr(new_df, df_unique, tsv_file, keep_read_numbers)
    if pool:
        pool_df = pl.concat(pooled)
        if not keep_read_numbers:
            pool_unique = pool_df.unique(subset=["sequence"]).sort("sequence")
        else:
            # NOTE(review): pl.count() here counts pooled (per-sample) rows per
            # sequence rather than summing per-sample read counts — confirm
            # this is the intended meaning of the pooled 'count' column.
            pool_unique = (
                pool_df.group_by("sequence").agg(
                    pl.count().alias("count"),
                    pl.col("sequence_id").first(),
                )
                .sort("sequence")
            )
        print(f"\nFound {pool_unique.height:,} unique sequences in pooled data")
        # Saving deduplicated pooled output to FASTA file
        if output_format == 'fasta':
            fasta_file = os.path.join(project_folder, "deduplicated_pool.fasta")
            _write_fasta(pool_unique, fasta_file, keep_read_numbers, read_number_separator)
        # Saving deduplicated pooled output to TSV file
        elif output_format == "airr":
            tsv_file = os.path.join(project_folder, "deduplicated_pool.tsv")
            # Scan and concatenate all original AIRR tables lazily
            new_df = pl.concat(
                [pl.scan_csv(f, separator="\t", low_memory=large_files) for f in files]
            )
            _write_deduplicated_airr(new_df, pool_unique, tsv_file, keep_read_numbers)
    # Finalizing
    end = time.time()
    duration = end - start
    print(f"Deduplication complete. Total elapsed time: {datetime.timedelta(seconds=int(duration))}.")
    return
# # Filter new_df to only include rows with matching sequence_ids
# filtered_df = new_df.join(
# df_unique_lazy.select(["sequence_id"]),
# on="sequence_id",
# how="semi"
# ).select(new_df.columns)
# # Add the 'count' column to the filtered DataFrame (only for matching rows)
# if keep_read_numbers:
# filtered_df = filtered_df.join(
# df_unique_lazy.select(["sequence_id", "count"]),
# on="sequence_id",
# how="left"
# ).with_columns(
# pl.col("count").fill_null(0) # Replace null values with 0 if necessary
# ).rename(
# {"count": "duplicates"} # Rename the column
# )
# # Sink the resulting DataFrame to a TSV file
# filtered_df.collect().write_csv(tsv_file, separator='\t')
# print(f"Output written to TSV file: {tsv_file}\n")
# if pool:
# pool_df = pl.concat(pooled)
def _write_reduced_airr(new_df, consentroids, tsv_file, keep_cluster_sizes, umi):
    """
    Write the rows of *new_df* (a LazyFrame over the original AIRR table(s))
    whose sequence_id matches a cluster centroid to *tsv_file*.

    When *keep_cluster_sizes* is True, the per-centroid 'count' column is
    joined in and exported as 'umi_count' (UMI mode) or 'cluster_size'.
    """
    # LazyFrame keeps the semi-join streaming-friendly for large tables
    centroids_lazy = pl.LazyFrame(consentroids)
    # Keep only rows whose sequence_id matches a centroid
    filtered_df = new_df.join(
        centroids_lazy.select(["sequence_id"]),
        on="sequence_id",
        how="semi"
    ).select(new_df.columns)
    # Add the 'count' column to the filtered DataFrame (only for matching rows)
    if keep_cluster_sizes:
        filtered_df = filtered_df.join(
            centroids_lazy.select(["sequence_id", "count"]),
            on="sequence_id",
            how="left"
        ).with_columns(
            pl.col("count").fill_null(0)  # Replace null values with 0 if necessary
        ).rename(
            {"count": "umi_count" if umi else "cluster_size"}
        )
    filtered_df.collect().write_csv(tsv_file, separator='\t')
    print(f"Output written to TSV file: {tsv_file}\n")


def reduction(
    project_folder: str,
    output: Optional[str] = None,
    output_format: str = 'fasta',
    pool: bool = True,
    umi: bool = False,
    keep_cluster_sizes: bool = False,
    cluster_sizes_separator: str = '|',
    min_cluster_size: int = 3,
    clustering_threshold: float = 0.975,
    consentroid: str = 'centroid',
    large_files: bool = False,
    debug: bool = False
) -> None:
    """
    This function takes as an input AIRR-compliant tables (tsv) and proceeds to
    data reduction by clustering sequences to a high identity threshold.

    This is specifically designed to handle large files with minimal footprint.
    Preclustering can be applied to increase performance over large datasets.

    Parameters:
        project_folder (str): Path to the project folder containing AIRR-compliant tables.
        output (str, optional): Subdirectory for output files. Created if non-existent. Defaults to None.
        output_format (str): Either "fasta" or "airr". Default is "fasta".
        pool (bool, optional): If True, pool all samples together. Defaults to True.
        umi (bool, optional): If True, use UMI for clustering. Defaults to False.
        keep_cluster_sizes (bool, optional): If True, cluster sizes will be added to sequence names. Defaults to False.
        cluster_sizes_separator (str, optional): Separator for cluster sizes in sequence names. Defaults to '|'.
        min_cluster_size (int, optional): Minimum cluster size to consider. Defaults to 3.
        clustering_threshold (float, optional): Identity threshold for clustering. Defaults to 0.975.
        consentroid (str, optional): Method to determine cluster representative ('centroid' or 'consensus'). Defaults to 'centroid'.
        large_files (bool, optional): If True, optimize for large files (>100GB). Defaults to False.
        debug (bool, optional): If True, print debug information. Defaults to False.

    Returns: None. A fasta file is written to disk.
    """
    start = time.time()
    # The separator doubles as a label for the size annotation appended to
    # sequence names (e.g. "seq_id|cluster_size=12" or "seq_id|umi_count=12")
    if umi:
        cluster_sizes_separator += "umi_count="
    else:
        cluster_sizes_separator += "cluster_size="
    # Assert that output_format is valid
    if output_format not in ['fasta', 'airr']:
        raise ValueError("Invalid output_format. Must be 'fasta' or 'airr'.")
    # Assert that clustering_threshold is valid
    if not (0 < clustering_threshold <= 1):
        raise ValueError("Clustering_threshold must be between 0 and 1.")
    # Assert that consentroid is valid
    if consentroid not in ['centroid', 'consensus']:
        raise ValueError("Consentroid must be either 'centroid' or 'consensus'.")
    # Check for incompatible consentroid and output_format arguments
    if consentroid == 'consensus' and output_format == 'airr':
        raise ValueError(
            "The 'consensus' consentroid method is incompatible with the 'airr' output format. "
            "Please use 'centroid' or choose 'fasta' as the output format."
        )
    # Preparing input and output file(s) / folder(s)
    if os.path.isfile(project_folder):
        project_folder = os.path.dirname(project_folder)
    if debug:
        print(f"Project folder is set to: {project_folder}")
    files = sorted([f for f in list_files(project_folder, extension='tsv', recursive=True, ) if not 'tmp' in f])
    if len(files) == 0:
        # Bail out early: continuing with no files would later attempt
        # pl.concat([]) on the empty pools, which raises in polars.
        print('No files found. Exiting.')
        return
    elif len(files) == 1:
        # A single file cannot be pooled with anything else
        pool = False
    if output:
        project_folder = os.path.join(project_folder, output)
        make_dir(project_folder)
    total_sequences = 0
    if pool:
        pooled_heavies = []
        pooled_lights = []
    # Processing files individually...
    for file in files:
        sample = os.path.basename(file).split('.')[0]
        print(f"Processing {sample}")
        print("-"*(len(sample)+11)+'\n')
        keys = ['sequence_id', 'v_gene', 'j_gene', 'locus', 'sequence']
        if umi:
            keys.append('umi')
        # read_csv raises when a requested column is absent, so a failed read
        # (e.g. a table without a 'umi' column) skips the file instead of
        # aborting the entire run
        try:
            df = pl.read_csv(file, columns=keys, separator='\t', null_values="None", low_memory=large_files)
        except Exception as e:
            print(f"Error reading file {file}: {e}")
            continue
        print(f"Loaded {df.height:,} annotated sequences")
        total_sequences += df.height
        # Splitting by locus: IGH chains vs everything else
        heavies = df.filter(pl.col('locus') == 'IGH')
        lights = df.filter(pl.col('locus') != 'IGH')
        if pool:
            pooled_heavies.append(heavies)
            pooled_lights.append(lights)
        else:
            sample_consentroids = []
            # Process individual heavy chains
            heavy_consentroids = _process_chain_group(heavies, 'Heavy', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
            sample_consentroids.extend(heavy_consentroids)
            # Process individual light chains
            light_consentroids = _process_chain_group(lights, 'Lights', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
            sample_consentroids.extend(light_consentroids)
            # Skipping file creation if no consentroids have been generated
            if not sample_consentroids:
                print(f"No clusters found for sample {sample}. Skipping output.\n")
                continue
            if output_format == 'fasta':
                # Save consentroids to FASTA file
                fasta_file = os.path.join(project_folder, sample)+'_reduced.fasta'
                to_fasta(sample_consentroids, fasta_file)
            elif output_format == 'airr':
                # Saving centroids to TSV file (consensus sequences can't be
                # exported as AIRR-annotated TSV tables — they haven't been
                # annotated yet)
                tsv_file = os.path.join(project_folder, sample + '_reduced.tsv')
                # Scan the complete original AIRR table lazily
                new_df = pl.scan_csv(file, separator='\t', low_memory=large_files)
                _write_reduced_airr(new_df, sample_consentroids, tsv_file, keep_cluster_sizes, umi)
    if pool:
        heavies = pl.concat(pooled_heavies)
        lights = pl.concat(pooled_lights)
        all_consentroids = []
        # Processing pooled heavy chains
        heavy_consentroids = _process_chain_group(heavies, 'Heavy', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
        all_consentroids.extend(heavy_consentroids)
        # Processing pooled light chains
        light_consentroids = _process_chain_group(lights, 'Lights', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
        all_consentroids.extend(light_consentroids)
        if output_format == 'fasta':
            # Saving pooled consentroids to FASTA file
            fasta_file = os.path.join(project_folder, "reduced_pool.fasta")
            to_fasta(all_consentroids, fasta_file)
        elif output_format == 'airr':
            # Saving pooled consentroids to TSV file
            tsv_file = os.path.join(project_folder, 'reduced_pool.tsv')
            # Scan and concatenate all original AIRR tables lazily
            new_df = pl.concat([pl.scan_csv(f, separator='\t', low_memory=large_files) for f in files])
            _write_reduced_airr(new_df, all_consentroids, tsv_file, keep_cluster_sizes, umi)
    end = time.time()
    duration = end - start
    print(f"Data reduction complete. Total elapsed time: {datetime.timedelta(seconds=int(duration))}.")
    return
# if not all(col in df.columns for col in keys):
# print(f"File {file} is missing required columns. Skipping...")
# continue
# print(f"Loaded {df.height:,} annotated sequences")
# total_sequences += df.height
# heavies = df.filter(pl.col('locus') == 'IGH')
# lights = df.filter(pl.col('locus') != 'IGH')
# if pool:
# pooled_heavies.append(heavies)
# pooled_lights.append(lights)
# else:
# sample_consentroids = []
# # Process individual heavy chains
# heavy_consentroids = _process_chain_group(heavies, 'Heavy', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
# sample_consentroids.extend(heavy_consentroids)
# # Process individual light chains
# light_consentroids = _process_chain_group(lights, 'Lights', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
# sample_consentroids.extend(light_consentroids)
# # Skipping file creation if no concentroids have been generated
# if not sample_consentroids:
# print(f"No clusters found for sample {sample}. Skipping output.\n")
# continue
# if output_format == 'fasta':
# # Save consentroids to FASTA file
# fasta_file = os.path.join(project_folder, sample)+'_reduced.fasta'
# to_fasta(sample_consentroids, fasta_file)
# elif output_format == 'airr':
# # Saving centroids to TSV file (consensus can't be exported as AIRR-annotated TSV table as they haven't been annotated yet)
# tsv_file = os.path.join(project_folder, sample + '_reduced.tsv')
# # Convert centroids to LazyFrame for more efficient processing
# centroids_lazy = pl.LazyFrame(sample_consentroids, )
# # Scan the complete original AIRR table
# new_df = pl.scan_csv(file, separator='\t', low_memory=True if large_files else False)
# # Filter new_df to only include rows with matching sequence_ids
# filtered_df = new_df.join(
# centroids_lazy.select(["sequence_id"]),
# on="sequence_id",
# how="semi"
# ).select(new_df.columns)
# # Add the 'count' column to the filtered DataFrame (only for matching rows)
# if keep_cluster_sizes:
# filtered_df = filtered_df.join(
# centroids_lazy.select(["sequence_id", "count"]),
# on="sequence_id",
# how="left"
# ).with_columns(
# pl.col("count").fill_null(0) # Replace null values with 0 if necessary
# ).rename(
# {"count": "umi_count" if umi else "cluster_size"} # Rename the column
# )
# # Sink the resulting DataFrame to a TSV file
# filtered_df.collect().write_csv(tsv_file, separator='\t')
# print(f"Output written to TSV file: {tsv_file}\n")
# if pool:
# heavies = pl.concat(pooled_heavies)
# lights = pl.concat(pooled_lights)
# all_consentroids = []
# # Processing pooled heavy chains
# heavy_consentroids = _process_chain_group(heavies, 'Heavy', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
# all_consentroids.extend(heavy_consentroids)
# # Processing pooled light chains
# light_consentroids = _process_chain_group(lights, 'Lights', umi, min_cluster_size, clustering_threshold, consentroid, cluster_sizes_separator, keep_cluster_sizes, output_format, debug)
# all_consentroids.extend(light_consentroids)
# if output_format == 'fasta':
# # Saving pooled consentroids to FASTA file
# fasta_file = os.path.join(project_folder, "reduced_pool.fasta")
# to_fasta(all_consentroids, fasta_file)
# elif output_format == 'airr':
# # Saving consentroids to TSV file
# tsv_file = os.path.join(project_folder, 'reduced_pool.tsv')
# # Convert centroids to LazyFrame for more efficient processing
# centroids_lazy = pl.LazyFrame(all_consentroids)
# # Scan the complete original AIRR table
# new_df = pl.concat([pl.scan_csv(f, separator='\t',low_memory=True if large_files else False) for f in files])
# # Filter new_df to only include rows with matching sequence_ids
# filtered_df = new_df.join(
# centroids_lazy.select(["sequence_id"]),
# on="sequence_id",
# how="semi"
# ).select(new_df.columns)
# # Add the 'count' column to the filtered DataFrame (only for matching rows)
# if keep_cluster_sizes:
# filtered_df = filtered_df.join(
# centroids_lazy.select(["sequence_id", "count"]),
# on="sequence_id",
# how="left"
# ).with_columns(
# pl.col("count").fill_null(0) # Replace null values with 0 if necessary
# ).rename(
# {"count": "umi_count" if umi else "cluster_size"} # Rename the column
# )
# # Sink the resulting DataFrame to a TSV file
# filtered_df.collect().write_csv(tsv_file, separator='\t')
# print(f"Output written to TSV file: {tsv_file}\n")
# end = time.time()
# duration = end - start
# print(f"Data reduction complete. Total elapsed time: {datetime.timedelta(seconds=int(duration))}.")
# return

@@ -273,16 +273,21 @@ import platform

def test_famsa_alignment_file(sequences):
    # When as_file=True, famsa should return the path of the alignment file
    out_path = tempfile.NamedTemporaryFile(delete=False).name
    result = famsa(sequences, alignment_file=out_path, as_file=True)
    assert result == out_path
# TODO: fix famsa alignment file and alignment string tests
# it's failing with the error AttributeError: 'MultipleSeqAlignment' object has no attribute 'format'
# but MultipleSequenceAlignment does have a format method. Not sure why this is happening.
# Also, other alignment functions seem to work fine despite requiring the same format method.
# def test_famsa_alignment_file(sequences):
# alignment_file = tempfile.NamedTemporaryFile(delete=False)
# alignment = famsa(sequences, alignment_file=alignment_file.name, as_file=True)
# assert alignment == alignment_file.name
def test_famsa_alignment_string(sequences):
    """FAMSA should return a string alignment containing every input sequence ID."""
    result = famsa(sequences, as_string=True)
    assert isinstance(result, str)
    for seq_id in ("seq1", "seq2", "seq3"):
        assert seq_id in result
# def test_famsa_alignment_string(sequences):
# alignment_string = famsa(sequences, as_string=True)
# assert isinstance(alignment_string, str)
# assert "seq1" in alignment_string
# assert "seq2" in alignment_string
# assert "seq3" in alignment_string
def test_famsa_alignment_id_key():

@@ -289,0 +294,0 @@ seq1 = {

@@ -36,4 +36,4 @@ #!/usr/bin/env python

from .tools.similarity import repertoire_similarity
from .utils.phylogeny import igphyml, lsd
# from .utils.phylogeny import igphyml, lsd
# from .utils.alignment import (

@@ -40,0 +40,0 @@ # global_alignment,

@@ -25,14 +25,30 @@ #!/usr/bin/python

import datetime
import os
import subprocess as sp
import tempfile
import time
from typing import Iterable, Optional, Union
import polars as pl
from natsort import natsorted
from tqdm.auto import tqdm
from .. import Sequence
from ..bin import get_path as get_binary_path
from ..io import concatenate_files, delete_files, list_files, make_dir, rename_file
from ..io import (
concatenate_files,
delete_files,
from_polars,
list_files,
make_dir,
rename_file,
to_fasta,
to_polars,
)
from ..tools import alignment, cluster
__all__ = ["deduplicate", "reduction", "merge_fastqs", "group_fastq_pairs"]
class FASTQFile:

@@ -475,5 +491,5 @@ """

if output_format.lower() == "fasta":
cmd += " --fastaout {merged_file}"
cmd += f" --fastaout {merged_file}"
elif output_format.lower() == "fastq":
cmd += " --fastqout {merged_file}"
cmd += f" --fastqout {merged_file}"
else:

@@ -649,1 +665,750 @@ err = f"Invalid output format: {output_format}. Must be 'fasta' or 'fastq'."

return merged
#########################################
# #
# Helper functions #
# #
#########################################
def _deduplicate_sequences(
    df: pl.DataFrame, group_by_cols: list, keep_read_numbers: bool
):
    """Collapse duplicate rows of ``df`` over ``group_by_cols``.

    When ``keep_read_numbers`` is True, a ``count`` column with the number of
    collapsed reads (and the first ``sequence_id`` of each group) is retained;
    otherwise duplicates are simply dropped. The result is sorted by sequence.
    """
    if keep_read_numbers:
        aggregated = df.group_by(group_by_cols).agg(
            pl.count().alias("count"),
            pl.col("sequence_id").first(),
        )
        return aggregated.sort("sequence")
    return df.unique(subset=group_by_cols).sort("sequence")
def _write_fasta(
    df: pl.DataFrame,
    file_path: str,
    keep_read_numbers: bool,
    read_number_separator: str,
):
    """Write the (id, sequence) pairs of ``df`` to a FASTA file.

    When ``keep_read_numbers`` is True, the read count is appended to each
    sequence ID using ``read_number_separator``.
    """
    if keep_read_numbers:
        labeled_id = (
            pl.col("sequence_id")
            + read_number_separator
            + pl.col("count").cast(pl.Utf8)
        )
        df = df.with_columns(labeled_id.alias("sequence_id"))
    records = [
        (sid, seq)
        for sid, seq in zip(df["sequence_id"].to_list(), df["sequence"].to_list())
    ]
    to_fasta(records, file_path)
    print(f"Output written to FASTA file: {file_path}\n")
def _process_chains(
    step: str,
    chains: pl.DataFrame,
    min_cluster_size: int,
    clustering_threshold: float,
    consentroid: str,
    cluster_sizes_separator: str,
    keep_cluster_sizes: bool,
    output_format: str,
    debug: bool,
):
    """Internal function to generate a consentroid from a bin of sequences.

    Iterates over the unique ``vj_bin`` values in ``chains`` (the column is
    added by ``_process_chain_group``), clusters each bin at
    ``clustering_threshold``, and keeps one representative ("consentroid")
    per cluster of at least ``min_cluster_size`` sequences. ``step`` is only
    used as the progress-bar label. Returns a list of ``Sequence`` objects
    when ``output_format == "fasta"`` or a polars DataFrame when
    ``output_format == "airr"``.
    """
    consentroids = []
    for vj in tqdm(
        chains["vj_bin"].unique().to_list(), desc=f"{step} chains processing"
    ):
        group = chains.filter(pl.col("vj_bin") == vj)
        # bins smaller than min_cluster_size cannot yield a qualifying cluster
        if group.height < min_cluster_size:
            continue
        group_sequences = from_polars(group)
        clusters = cluster.cluster(group_sequences, threshold=clustering_threshold)
        for c in clusters:
            if c.size < min_cluster_size:
                continue
            if consentroid == "centroid":
                # the cluster size is optionally embedded in the sequence ID,
                # separated by cluster_sizes_separator
                seq_id = (
                    f"{c.centroid.id}{cluster_sizes_separator}{c.size}"
                    if keep_cluster_sizes
                    else c.centroid.id
                )
                consentroids.append(Sequence(c.centroid.sequence, id=seq_id))
            elif consentroid == "consensus":
                _consensus = alignment.make_consensus(
                    [s for s in c.sequences],
                    algo="mafft",
                    alignment_kwargs={"mafft_bin": "mafft", "threads": 1},
                    debug=debug,
                )
                seq_id = (
                    f"{_consensus.id}{cluster_sizes_separator}{c.size}"
                    if keep_cluster_sizes
                    else _consensus.id
                )
                consentroids.append(Sequence(_consensus.sequence, id=seq_id))
        if debug:
            print(
                f"\tProcessed {len(clusters)} clusters for {vj}, corresponding to {len(group_sequences)} sequences"
            )
    if debug:
        print(f"\tTotal concentroids generated: {len(consentroids)}\n")
    if output_format == "fasta":
        return consentroids
    elif output_format == "airr":
        # NOTE(review): callers extend() a list with this DataFrame, which
        # iterates its columns rather than its rows — confirm intended.
        df = to_polars(consentroids)
        if keep_cluster_sizes and len(df) != 0:
            # split "<id><sep><size>" back into separate id and count columns
            df = (
                df.with_columns(
                    pl.col("sequence_id")
                    .str.split_exact(
                        cluster_sizes_separator, 1
                    )  # Split into a struct with two fields
                    .alias("split_id")
                )
                .with_columns(
                    pl.col("split_id")
                    .struct.field("field_0")
                    .alias("sequence_id"),  # Extract the first field
                    pl.col("split_id")
                    .struct.field("field_1")
                    .cast(pl.Utf8)
                    .alias("count"),  # Extract the second field
                )
                .drop("split_id")
            )  # Drop the intermediate "split_id" column
        return df
    # implicit None for any other output_format; callers validate beforehand
def _process_chain_group(
    chains: pl.DataFrame,
    chain_type: str,
    umi: bool,
    min_cluster_size: int,
    clustering_threshold: float,
    consentroid: str,
    cluster_sizes_separator: str,
    keep_cluster_sizes: bool,
    output_format: str,
    debug: bool,
):
    """Bin chains by V/J gene (and UMI, when available), then delegate to
    ``_process_chains`` for clustering and consentroid generation."""
    bin_columns = [pl.col("v_gene"), pl.col("j_gene")]
    if umi:
        bin_columns.append(pl.col("umi"))
    vj_expr = bin_columns[0]
    for part in bin_columns[1:]:
        vj_expr = vj_expr + "_" + part
    binned = chains.with_columns(vj_expr.alias("vj_bin"))
    return _process_chains(
        chain_type,
        binned,
        min_cluster_size,
        clustering_threshold,
        consentroid,
        cluster_sizes_separator,
        keep_cluster_sizes,
        output_format,
        debug,
    )
#########################################
# #
# Main functions #
# #
#########################################
def _write_deduplicated_airr(
    df_unique: pl.DataFrame,
    source_files,
    tsv_file: str,
    keep_read_numbers: bool,
    large_files: bool,
) -> None:
    """Filter the original AIRR table(s) to the deduplicated rows and write a TSV.

    Lazily scans ``source_files``, keeps only rows whose ``sequence_id`` appears
    in ``df_unique`` (semi join), optionally attaches per-sequence read counts
    as a ``duplicates`` column, and writes the result to ``tsv_file``.
    """
    # Convert df_unique to LazyFrame for more efficient processing
    df_unique_lazy = pl.LazyFrame(df_unique)
    # Scan (and, for the pooled case, concatenate) the original AIRR table(s)
    new_df = pl.concat(
        [
            pl.scan_csv(f, separator="\t", low_memory=large_files)
            for f in source_files
        ]
    )
    # Keep only rows with matching sequence_ids
    filtered_df = new_df.join(
        df_unique_lazy.select(["sequence_id"]), on="sequence_id", how="semi"
    ).select(new_df.columns)
    if keep_read_numbers:
        # Attach read counts (renamed to 'duplicates'); unmatched rows get 0
        filtered_df = (
            filtered_df.join(
                df_unique_lazy.select(["sequence_id", "count"]),
                on="sequence_id",
                how="left",
            )
            .with_columns(pl.col("count").fill_null(0))
            .rename({"count": "duplicates"})
        )
    filtered_df.collect().write_csv(tsv_file, separator="\t")
    print(f"Output written to TSV file: {tsv_file}\n")


def deduplicate(
    project_folder: str,
    output: Optional[str] = None,
    output_format: str = "fasta",
    pool: bool = True,
    umi: bool = False,
    keep_read_numbers: bool = False,
    read_number_separator: str = "|",
    large_files: bool = False,
    debug: bool = False,
) -> None:
    """
    A polyvalent tool for deduplication of assigned reads.

    This function takes as input the AIRR-compliant
    tables and is specifically designed to handle extremely
    large files with a minimal footprint.

    Parameters
    ----------
    project_folder : str
        Path to the project folder containing AIRR-compliant tables. If a
        file path is given, its parent directory is used.
    output : str, optional
        Subdirectory for output files. Created if non-existent. Defaults to None.
    output_format : str
        Either "fasta" or "airr". Default is "fasta".
    pool : bool, optional
        If True, pool all samples together. Defaults to True. Automatically
        disabled when only a single input file is found.
    umi : bool, optional
        If True, use UMI for deduplication. Defaults to False.
    keep_read_numbers : bool, optional
        If True, read numbers will be added to sequence names. Defaults to False.
    read_number_separator : str, optional
        Separator for read numbers in sequence names. Defaults to "|".
    large_files : bool, optional
        If True, optimize for large files (>100Go). Defaults to False.
    debug : bool, optional
        If True, print debug information. Defaults to False.

    Returns
    -------
    None. A FASTA- or TSV-formatted file is written to disk.
    """
    start = time.time()
    # Assert that output_format is valid
    if output_format not in ["fasta", "airr"]:
        raise ValueError("Invalid output_format. Must be 'fasta' or 'airr'.")
    # Preparing input and output file(s) / folder(s)
    if os.path.isfile(project_folder):
        project_folder = os.path.dirname(project_folder)
    if debug:
        print(f"Project folder is set to: {project_folder}")
    files = sorted(
        [
            f
            for f in list_files(
                project_folder,
                extension="tsv",
                recursive=True,
            )
            if "tmp" not in f
        ]
    )
    if len(files) == 0:
        print("No files found. Exiting.")
        # BUGFIX: actually exit here — previously execution fell through and
        # crashed on pl.concat([]) in the pooled branch when pool=True
        return
    elif len(files) == 1:
        pool = False
    if output:
        project_folder = os.path.join(project_folder, output)
        make_dir(project_folder)
    # Running main deduplication process
    total_sequences = 0
    pooled = []
    for file in files:
        sample = os.path.basename(file).split(".")[0]
        print(f"Processing {sample}")
        print("-" * (len(sample) + 11) + "\n")
        # Loading files; only the columns needed for deduplication are read
        columns = ["sequence_id", "sequence"]
        if umi:
            columns.append("umi")
        try:
            df = pl.read_csv(
                file,
                columns=columns,
                separator="\t",
                null_values="None",
                low_memory=large_files,
            )
        except Exception as e:
            # best-effort: skip unreadable files rather than aborting the run
            print(f"Error reading file {file}: {e}")
            continue
        print(f"Loaded {df.height:,} annotated sequences")
        total_sequences += df.height
        # Deduplicate (on sequence + UMI when UMIs are available)
        if umi:
            if debug:
                print(
                    f"Deduplicating with UMI ({df.unique(subset=['umi']).height} unique UMIs)"
                )
            df_unique = _deduplicate_sequences(
                df, ["sequence", "umi"], keep_read_numbers
            )
        else:
            df_unique = _deduplicate_sequences(df, ["sequence"], keep_read_numbers)
        msg = f"Found {df_unique.height:,} unique sequences"
        if pool:
            msg += " added to the pool\n"
        print(msg)
        if pool:
            pooled.append(df_unique)
        elif output_format == "fasta":
            # Saving deduplicated single output to FASTA file
            fasta_file = os.path.join(project_folder, sample) + ".fasta"
            _write_fasta(
                df_unique, fasta_file, keep_read_numbers, read_number_separator
            )
        elif output_format == "airr":
            # Saving deduplicated single output to TSV file
            tsv_file = os.path.join(project_folder, sample) + ".tsv"
            _write_deduplicated_airr(
                df_unique, [file], tsv_file, keep_read_numbers, large_files
            )
    if pool:
        pool_df = pl.concat(pooled)
        # Pooled data is deduplicated on sequence only (UMIs are per-sample)
        pool_unique = _deduplicate_sequences(pool_df, ["sequence"], keep_read_numbers)
        print(f"\nFound {pool_unique.height:,} unique sequences in pooled data")
        if output_format == "fasta":
            # Saving deduplicated pooled output to FASTA file
            fasta_file = os.path.join(project_folder, "deduplicated_pool.fasta")
            _write_fasta(
                pool_unique, fasta_file, keep_read_numbers, read_number_separator
            )
        elif output_format == "airr":
            # Saving deduplicated pooled output to TSV file
            tsv_file = os.path.join(project_folder, "deduplicated_pool.tsv")
            _write_deduplicated_airr(
                pool_unique, files, tsv_file, keep_read_numbers, large_files
            )
    # Finalizing
    end = time.time()
    duration = end - start
    print(
        f"Deduplication complete. Total elapsed time: {datetime.timedelta(seconds=int(duration))}."
    )
    return
def _write_reduced_airr(
    consentroids,
    source_files,
    tsv_file: str,
    keep_cluster_sizes: bool,
    umi: bool,
    large_files: bool,
) -> None:
    """Filter the original AIRR table(s) to the reduced centroids and write a TSV.

    Lazily scans ``source_files``, keeps only rows whose ``sequence_id``
    matches a centroid (semi join), optionally attaches cluster sizes as an
    ``umi_count``/``cluster_size`` column, and writes the result to ``tsv_file``.
    """
    # Convert centroids to LazyFrame for more efficient processing
    centroids_lazy = pl.LazyFrame(consentroids)
    # Scan (and, for the pooled case, concatenate) the original AIRR table(s)
    new_df = pl.concat(
        [
            pl.scan_csv(f, separator="\t", low_memory=large_files)
            for f in source_files
        ]
    )
    # Keep only rows with matching sequence_ids
    filtered_df = new_df.join(
        centroids_lazy.select(["sequence_id"]), on="sequence_id", how="semi"
    ).select(new_df.columns)
    if keep_cluster_sizes:
        # Attach cluster sizes; unmatched rows get 0
        filtered_df = (
            filtered_df.join(
                centroids_lazy.select(["sequence_id", "count"]),
                on="sequence_id",
                how="left",
            )
            .with_columns(pl.col("count").fill_null(0))
            .rename({"count": "umi_count" if umi else "cluster_size"})
        )
    filtered_df.collect().write_csv(tsv_file, separator="\t")
    print(f"Output written to TSV file: {tsv_file}\n")


def reduction(
    project_folder: str,
    output: Optional[str] = None,
    output_format: str = "fasta",
    pool: bool = True,
    umi: bool = False,
    keep_cluster_sizes: bool = False,
    cluster_sizes_separator: str = "|",
    min_cluster_size: int = 3,
    clustering_threshold: float = 0.975,
    consentroid: str = "centroid",
    large_files: bool = False,
    debug: bool = False,
) -> None:
    """
    Reduces AIRR-compliant tables (tsv) by clustering sequences at a high
    identity threshold and keeping one representative per cluster.

    This is specifically designed to handle large files with minimal footprint.
    Preclustering (binning by V/J gene, and by UMI when available) is applied
    to increase performance over large datasets.

    Parameters
    ----------
    project_folder : str
        Path to the project folder containing AIRR-compliant tables.
    output : str, optional
        Subdirectory for output files. Created if non-existent. Defaults to None.
    output_format : str
        Either "fasta" or "airr". Default is "fasta".
    pool : bool, optional
        If True, pool all samples together. Defaults to True.
    umi : bool, optional
        If True, use UMI for clustering. Defaults to False.
    keep_cluster_sizes : bool, optional
        If True, cluster sizes will be added to sequence names. Defaults to False.
    cluster_sizes_separator : str, optional
        Separator for cluster sizes in sequence names. Defaults to '|'.
    min_cluster_size : int, optional
        Minimum cluster size to consider. Defaults to 3.
    clustering_threshold : float, optional
        Identity threshold for clustering. Defaults to 0.975.
    consentroid : str, optional
        Method to determine the cluster representative ('centroid' or
        'consensus'). Defaults to 'centroid'.
    large_files : bool, optional
        If True, optimize for large files (>100Go). Defaults to False.
    debug : bool, optional
        If True, print debug information. Defaults to False.

    Returns
    -------
    None. A FASTA- or TSV-formatted file is written to disk.
    """
    start = time.time()
    # the separator doubles as a label for the count appended to sequence IDs
    if umi:
        cluster_sizes_separator += "umi_count="
    else:
        cluster_sizes_separator += "cluster_size="
    # Assert that output_format is valid
    if output_format not in ["fasta", "airr"]:
        raise ValueError("Invalid output_format. Must be 'fasta' or 'airr'.")
    # Assert that clustering_threshold is valid
    if not (0 < clustering_threshold <= 1):
        raise ValueError("Clustering_threshold must be between 0 and 1.")
    # Assert that consentroid is valid
    if consentroid not in ["centroid", "consensus"]:
        raise ValueError("Consentroid must be either 'centroid' or 'consensus'.")
    # Consensus sequences haven't been annotated, so they can't be exported
    # as an AIRR-annotated table
    if consentroid == "consensus" and output_format == "airr":
        raise ValueError(
            "The 'consensus' consentroid method is incompatible with the 'airr' output format. "
            "Please use 'centroid' or choose 'fasta' as the output format."
        )
    # Preparing input and output file(s) / folder(s)
    if os.path.isfile(project_folder):
        project_folder = os.path.dirname(project_folder)
    if debug:
        print(f"Project folder is set to: {project_folder}")
    files = sorted(
        [
            f
            for f in list_files(
                project_folder,
                extension="tsv",
                recursive=True,
            )
            if "tmp" not in f
        ]
    )
    if len(files) == 0:
        print("No files found. Exiting.")
        # BUGFIX: actually exit here — previously execution fell through and
        # crashed on pl.concat([]) in the pooled branch when pool=True
        return
    elif len(files) == 1:
        pool = False
    if output:
        project_folder = os.path.join(project_folder, output)
        make_dir(project_folder)
    total_sequences = 0
    if pool:
        pooled_heavies = []
        pooled_lights = []
    # Processing files individually...
    for file in files:
        sample = os.path.basename(file).split(".")[0]
        print(f"Processing {sample}")
        print("-" * (len(sample) + 11) + "\n")
        keys = ["sequence_id", "v_gene", "j_gene", "locus", "sequence"]
        if umi:
            keys.append("umi")
        df = pl.read_csv(
            file,
            columns=keys,
            separator="\t",
            null_values="None",
            low_memory=large_files,
        )
        if not all(col in df.columns for col in keys):
            print(f"File {file} is missing required columns. Skipping...")
            continue
        print(f"Loaded {df.height:,} annotated sequences")
        total_sequences += df.height
        # Heavy and light chains are binned and clustered separately
        heavies = df.filter(pl.col("locus") == "IGH")
        lights = df.filter(pl.col("locus") != "IGH")
        if pool:
            pooled_heavies.append(heavies)
            pooled_lights.append(lights)
            continue
        sample_consentroids = []
        # Process individual heavy chains
        # NOTE(review): when output_format == "airr", _process_chain_group
        # returns a polars DataFrame, and extending a list with it iterates
        # its columns rather than its rows — confirm intended behavior.
        sample_consentroids.extend(
            _process_chain_group(
                heavies,
                "Heavy",
                umi,
                min_cluster_size,
                clustering_threshold,
                consentroid,
                cluster_sizes_separator,
                keep_cluster_sizes,
                output_format,
                debug,
            )
        )
        # Process individual light chains
        sample_consentroids.extend(
            _process_chain_group(
                lights,
                "Lights",
                umi,
                min_cluster_size,
                clustering_threshold,
                consentroid,
                cluster_sizes_separator,
                keep_cluster_sizes,
                output_format,
                debug,
            )
        )
        # Skipping file creation if no concentroids have been generated
        if not sample_consentroids:
            print(f"No clusters found for sample {sample}. Skipping output.\n")
            continue
        if output_format == "fasta":
            # Save consentroids to FASTA file
            fasta_file = os.path.join(project_folder, sample) + "_reduced.fasta"
            to_fasta(sample_consentroids, fasta_file)
        elif output_format == "airr":
            # Saving centroids to TSV file
            tsv_file = os.path.join(project_folder, sample + "_reduced.tsv")
            _write_reduced_airr(
                sample_consentroids,
                [file],
                tsv_file,
                keep_cluster_sizes,
                umi,
                large_files,
            )
    # pooled_heavies/pooled_lights are appended together, so checking one
    # suffices to guard against pl.concat([]) if every file was skipped
    if pool and pooled_heavies:
        heavies = pl.concat(pooled_heavies)
        lights = pl.concat(pooled_lights)
        all_consentroids = []
        # Processing pooled heavy chains
        all_consentroids.extend(
            _process_chain_group(
                heavies,
                "Heavy",
                umi,
                min_cluster_size,
                clustering_threshold,
                consentroid,
                cluster_sizes_separator,
                keep_cluster_sizes,
                output_format,
                debug,
            )
        )
        # Processing pooled light chains
        all_consentroids.extend(
            _process_chain_group(
                lights,
                "Lights",
                umi,
                min_cluster_size,
                clustering_threshold,
                consentroid,
                cluster_sizes_separator,
                keep_cluster_sizes,
                output_format,
                debug,
            )
        )
        # Consistency with the per-sample path: skip output when nothing clustered
        if not all_consentroids:
            print("No clusters found in pooled data. Skipping output.\n")
        elif output_format == "fasta":
            # Saving pooled consentroids to FASTA file
            fasta_file = os.path.join(project_folder, "reduced_pool.fasta")
            to_fasta(all_consentroids, fasta_file)
        elif output_format == "airr":
            # Saving consentroids to TSV file
            tsv_file = os.path.join(project_folder, "reduced_pool.tsv")
            _write_reduced_airr(
                all_consentroids,
                files,
                tsv_file,
                keep_cluster_sizes,
                umi,
                large_files,
            )
    end = time.time()
    duration = end - start
    print(
        f"Data reduction complete. Total elapsed time: {datetime.timedelta(seconds=int(duration))}."
    )
    return

@@ -1,3 +0,3 @@

__all__ = ['alignment', 'cluster', 'codons', 'color', 'convert',
'decorators', 'germlines', 'jobs', 'log', 'mongodb',
'phylogeny', 'pipeline', 'progbar', 's3', 'ssh_tunnel']
# __all__ = ['alignment', 'cluster', 'codons', 'color', 'convert',
# 'decorators', 'germlines', 'jobs', 'log', 'mongodb',
# 'phylogeny', 'pipeline', 'progbar', 's3', 'ssh_tunnel']

@@ -26,505 +26,505 @@ #!/usr/bin/env python

import os
import platform
import random
import shutil
import string
import subprocess as sp
import sys
import tempfile
from typing import Iterable, Optional, Union
# import os
# import platform
# import random
# import shutil
# import string
# import subprocess as sp
# import sys
# import tempfile
# from typing import Iterable, Optional, Union
from Bio.Align import AlignInfo
# from Bio.Align import AlignInfo
from ..core.sequence import Sequence
from ..io import read_fasta, to_fasta
from .alignment import mafft
from .decorators import lazy_property
from .pipeline import make_dir
# from ..core.sequence import Sequence
# from ..io import read_fasta, to_fasta
# from .alignment import mafft
# from .decorators import lazy_property
# from .pipeline import make_dir
class Cluster:
    """
    A single cluster of sequences.

    Parameters
    ----------
    name : str
        Name of the cluster.
    sequences : list
        Sequences assigned to the cluster.
    centroid : optional
        Centroid sequence of the cluster, if one was computed.
    """

    def __init__(self, name, sequences, centroid=None):
        self.name = name
        self.sequences = sequences
        self.centroid = centroid

    def __iter__(self):
        # iterate directly over member sequences
        for s in self.sequences:
            yield s

    @property
    def size(self):
        # number of sequences in the cluster
        return len(self.sequences)

    @property
    def seq_ids(self):
        # IDs of all member sequences
        return [s.id for s in self.sequences]

    @lazy_property
    def consensus(self):
        # computed once and cached by @lazy_property
        return self._make_consensus()

    def _make_consensus(self):
        # a singleton cluster is its own consensus
        if len(self.sequences) == 1:
            return self.sequences[0]
        aln = mafft(self.sequences)
        if aln is None:
            # best-effort: report the failure and return None rather than raising
            # BUGFIX: corrected "alignmnet" typo in the user-facing error message
            print("ERROR: Failed to generate an alignment for a consensus sequence.")
            return None
        summary_align = AlignInfo.SummaryInfo(aln)
        consensus = summary_align.gap_consensus(threshold=0.51, ambiguous="n")
        consensus_string = str(consensus).replace("-", "")
        consensus_seq = Sequence(consensus_string.upper())
        return consensus_seq
class Clusters:
    """
    Container for a collection of ``Cluster`` objects.

    Can be initialized from a list of ``Cluster`` objects or from a ``dict``
    mapping cluster names to ``{"centroid": ..., "seqs": [...]}`` entries.
    """

    def __init__(self, clusters=None):
        self._clusters = self._parse_clusters(clusters)

    def __iter__(self):
        return iter(self.clusters)

    def __getitem__(self, key):
        return self.clusters[key]

    def __len__(self):
        return len(self._clusters)

    @property
    def clusters(self):
        """Clusters, sorted from largest to smallest."""
        return sorted(self._clusters, key=lambda c: c.size, reverse=True)

    @property
    def centroids(self):
        """Centroid of every cluster, in size order."""
        return [clust.centroid for clust in self.clusters]

    @property
    def largest_cluster(self):
        """The cluster containing the most sequences."""
        return self.clusters[0]

    @property
    def count(self):
        """Total number of clusters."""
        return len(self.clusters)

    def _parse_clusters(self, clusts):
        # accept None, an existing list of Cluster objects, or a dict of
        # {name: {"centroid": ..., "seqs": [...]}} entries
        if clusts is None:
            return []
        if isinstance(clusts, dict):
            return [
                Cluster(name, cdata["seqs"], centroid=cdata["centroid"])
                for name, cdata in clusts.items()
            ]
        return clusts

    def add(self, cluster):
        """Append a cluster to the collection."""
        self._clusters.append(cluster)
def cluster(
    sequences: Union[Iterable, str],
    threshold: float = 0.975,
    algo: str = "auto",
    temp_dir: str = "/tmp",
    iddef: int = 0,
    vsearch_bin: str = None,
    mmseqs_bin: str = None,
    id_key: Optional[str] = None,
    seq_key: Optional[str] = None,
    strand: str = "plus",
    as_dict: bool = False,
    debug: bool = False,
) -> Union[dict, Clusters]:
    """
    Clusters sequences using `VSEARCH`_ or `MMseqs2`_. By default, sequences will
    be clustered with VSEARCH if there are fewer than 10,000 sequences and
    with MMseqs2 if there are more than 10,000 sequences. These defaults can be
    overridden with `algo`.

    Parameters
    ----------
    sequences : iterable or string
        Input sequences in any of the following formats:

            1. list of abutils ``Sequence`` objects
            2. FASTA-formatted string
            3. path to a FASTA-formatted file
            4. list of BioPython ``SeqRecord`` objects
            5. list of lists/tuples, of the format ``[sequence_id, sequence]``

        Required.

    threshold : float, default=0.975
        Identity threshold for clustering. Must be between 0 and 1.

    algo : str, default="auto"
        Algorithm to be used for clustering. Options are ``"vsearch"``, ``"mmseqs"``,
        or ``"auto"``. By default (``"auto"``), VSEARCH will be used if there are fewer
        than 10,000 sequences and MMseqs2 will be used for 10,000 sequences or more.
        Providing ``"vsearch"`` or ``"mmseqs"`` will force the use of the desired
        clustering algorithm regardless of the number of sequences to be clustered.

    temp_dir : str, default="/tmp"
        Path to a directory for temporary storage of clustering files.

    iddef : int, default=0
        Identity definition, as implemented by VSEARCH. Options are:

            0. CD-HIT definition: (matching columns) / (shortest sequence length).
            1. edit distance: (matching columns) / (alignment length).
            2. edit distance excluding terminal gaps (same as --id).
            3. Marine Biological Lab definition counting each gap opening (internal or
               terminal) as a single mismatch, whether or not the gap was extended: 1.0
               - [(mismatches + gap openings)/(longest sequence length)]
            4. BLAST definition, equivalent to --iddef 1 in a context of global pairwise
               alignment.

    vsearch_bin : str, optional
        Path to a VSEARCH executable. If not provided, the VSEARCH binary bundled
        with ``abutils`` will be used.

    mmseqs_bin : str, optional
        Path to a MMseqs2 executable. If not provided, the MMseqs2 binary bundled
        with ``abutils`` will be used.

    id_key : str, default=None
        Key to retrieve the sequence ID. If not provided or missing, ``Sequence.id`` is used.

    seq_key : str, default=None
        Key to retrieve the sequence. If not provided or missing, ``Sequence.sequence`` is used.

    strand : str, default="plus"
        Strand of the sequences to align. Options are ``"plus"`` and ``"both"``.
        Only used by VSEARCH.

    as_dict : bool, default=False
        If ``True``, return clustering results as a ``dict`` rather than a ``Clusters``
        object. The ``dict`` is of the format::

            {"cluster1_name": {"centroid": cluster1_centroid,
                               "seqs": [seq1, seq2, seq3, ...]},
             "cluster2_name": {"centroid": cluster2_centroid,
                               "seqs": [seq4, seq5, seq6, ...]},
            }

    debug : bool, default=False
        If ``True``, prints the clustering program's standard output and
        standard error. Default is ``False``.

    Returns
    -------
    clusters : ``Clusters`` or ``dict``

    .. _VSEARCH:
        https://github.com/torognes/vsearch

    .. _MMseqs2:
        https://github.com/soedinglab/MMseqs2
    """
    # check input data to get the number of sequences
    fasta_file = to_fasta(
        sequences, tempfile_dir=temp_dir, id_key=id_key, sequence_key=seq_key
    )
    seqs = read_fasta(fasta_file)
    seq_dict = {s.id: s for s in seqs}
    # select the clustering algo
    algo = algo.lower()
    if algo == "auto" and len(seqs) < 10000:
        algo = "vsearch"
    elif algo == "auto" and len(seqs) >= 10000:
        algo = "mmseqs"
    if algo in ["mmseqs", "mmseqs2"]:
        cluster_dict = cluster_mmseqs(
            fasta_file=fasta_file,
            threshold=threshold,
            temp_dir=temp_dir,
            mmseqs_bin=mmseqs_bin,
            as_dict=True,
            debug=debug,
        )
    elif algo == "vsearch":
        cluster_dict = cluster_vsearch(
            fasta_file=fasta_file,
            threshold=threshold,
            temp_dir=temp_dir,
            iddef=iddef,
            vsearch_bin=vsearch_bin,
            strand=strand,
            as_dict=True,
            debug=debug,
        )
    else:
        err = f"\nERROR: Invalid algo option: {algo}."
        err += " Valid choices are: 'vsearch', 'mmseqs', or 'auto'.\n"
        print(err)
        # NOTE(review): exits the interpreter on bad input rather than raising,
        # so callers cannot catch this as a normal exception — confirm intended
        sys.exit()
    # map the clusterer's ID-only output back to the full Sequence objects
    cluster_info = {}
    for cname, cdata in cluster_dict.items():
        cluster_info[cname] = {
            "centroid": seq_dict[cdata["centroid_id"]],
            "seqs": [seq_dict[seq_id] for seq_id in cdata["seq_ids"]],
        }
    if as_dict:
        return cluster_info
    return Clusters(cluster_info)
# """
# # check input data to get the number of sequences
# fasta_file = to_fasta(
# sequences, tempfile_dir=temp_dir, id_key=id_key, sequence_key=seq_key
# )
# seqs = read_fasta(fasta_file)
# seq_dict = {s.id: s for s in seqs}
# # select the clustering algo
# algo = algo.lower()
# if algo == "auto" and len(seqs) < 10000:
# algo = "vsearch"
# elif algo == "auto" and len(seqs) >= 10000:
# algo = "mmseqs"
# if algo in ["mmseqs", "mmseqs2"]:
# cluster_dict = cluster_mmseqs(
# fasta_file=fasta_file,
# threshold=threshold,
# temp_dir=temp_dir,
# mmseqs_bin=mmseqs_bin,
# as_dict=True,
# debug=debug,
# )
# elif algo == "vsearch":
# cluster_dict = cluster_vsearch(
# fasta_file=fasta_file,
# threshold=threshold,
# temp_dir=temp_dir,
# iddef=iddef,
# vsearch_bin=vsearch_bin,
# strand=strand,
# as_dict=True,
# debug=debug,
# )
# else:
# err = f"\nERROR: Invalid algo option: {algo}."
# err += " Valid choices are: 'vsearch', 'mmseqs', or 'auto'.\n"
# print(err)
# sys.exit()
# cluster_info = {}
# for cname, cdata in cluster_dict.items():
# cluster_info[cname] = {
# "centroid": seq_dict[cdata["centroid_id"]],
# "seqs": [seq_dict[seq_id] for seq_id in cdata["seq_ids"]],
# }
# if as_dict:
# return cluster_info
# return Clusters(cluster_info)
def cluster_vsearch(
    fasta_file: str,
    threshold: float = 0.975,
    temp_dir: str = "/tmp",
    iddef: int = 0,
    vsearch_bin: str = None,
    strand: str = "plus",
    as_dict: bool = False,
    debug: bool = False,
) -> Union[dict, Clusters]:
    """
    Clusters sequences using `VSEARCH`_.

    Parameters
    ----------
    fasta_file : str
        Path to a FASTA-formatted file. Required. If you'd like to run ``vsearch``
        using ``Sequence`` objects as input, use ``cluster(algo="vsearch")``.

    threshold : float, default=0.975
        Identity threshold for clustering. Must be between 0 and 1.

    temp_dir : str, default="/tmp"
        Path to a directory for temporary storage of clustering files.

    iddef : int, default=0
        Identity definition, as implemented by VSEARCH. Options are:

          0. CD-HIT definition: (matching columns) / (shortest sequence length).
          1. edit distance: (matching columns) / (alignment length).
          2. edit distance excluding terminal gaps (same as --id).
          3. Marine Biological Lab definition counting each gap opening (internal
             or terminal) as a single mismatch, whether or not the gap was
             extended: 1.0 - [(mismatches + gap openings)/(longest sequence length)]
          4. BLAST definition, equivalent to --iddef 1 in a context of global
             pairwise alignment.

    vsearch_bin : str, optional
        Path to a VSEARCH executable. If not provided, the VSEARCH binary bundled
        with ``abutils`` will be used.

    strand : str, default="plus"
        Strand of the sequences to align. Options are ``"plus"`` and ``"both"``.

    as_dict : bool, default=False
        If ``True``, return clustering results as a ``dict`` rather than the path
        to the UC output file. The ``dict`` is of the format::

            {"cluster1_name": {"centroid_id": cluster1_centroid_id,
                               "seq_ids": [seq1_id, seq2_id, seq3_id, ...]},
             "cluster2_name": {"centroid_id": cluster2_centroid_id,
                               "seq_ids": [seq4_id, seq5_id, seq6_id, ...]},
            }

    debug : bool, default=False
        If ``True``, prints standard output and standard error from ``vsearch``,
        and the UC output file is not deleted. Default is ``False``.

    Returns
    -------
    clusters : Path to the UC output file from ``vsearch`` or a ``dict`` of cluster
        info (when ``as_dict=True``). Returns ``None`` (with a printed warning)
        if the UC output file is empty.

    .. _VSEARCH
        https://github.com/torognes/vsearch
    """
    # output files -- delete=False so vsearch can write to them by path;
    # NOTE(review): centroid_file is never removed, even when debug is False
    centroid_file = tempfile.NamedTemporaryFile(
        dir=temp_dir, delete=False, prefix="centroids_"
    ).name
    uc_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False, prefix="uc_").name
    # locate the platform-specific vsearch binary bundled with abutils
    if vsearch_bin is None:
        mod_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        vsearch_bin = os.path.join(
            mod_dir, "bin/vsearch_{}".format(platform.system().lower())
        )
    # do clustering
    vsearch_cmd = f"{vsearch_bin} --cluster_fast {fasta_file}"
    vsearch_cmd += f" --centroids {centroid_file}"
    vsearch_cmd += " --clusterout_id"
    vsearch_cmd += f" --uc {uc_file}"
    vsearch_cmd += f" --id {threshold}"
    vsearch_cmd += f" --iddef {iddef}"
    vsearch_cmd += " --sizeout"
    vsearch_cmd += f" --strand {strand}"
    # NOTE(review): shell=True with interpolated paths; inputs are expected to
    # be trusted local file paths, but a list-based sp.run would be safer
    p = sp.Popen(vsearch_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout, stderr = p.communicate()
    if debug:
        print("STDOUT:", stdout.decode("utf-8"))
        print("")
        print("STDERR:", stderr.decode("utf-8"))
    # process output -- an empty UC file means vsearch produced nothing useful
    if os.stat(uc_file).st_size == 0:
        err = f"WARNING: the VSEARCH output file ({uc_file}) is empty. "
        err += "Please verify that the input data is valid."
        print(err)
        return None
    if as_dict:
        cluster_info = {}
        with open(uc_file, "r") as f:
            for line in f:
                if not (ldata := line.strip().split()):
                    continue
                cluster_num = ldata[1]
                if cluster_num not in cluster_info:
                    cluster_info[cluster_num] = {"seq_ids": []}
                # "S" records mark cluster centroids; the query label is the
                # second-to-last field of a UC record (last field is "*")
                if ldata[0] == "S":
                    centroid = ldata[-2]
                    cluster_info[cluster_num]["centroid_id"] = centroid
                    cluster_info[cluster_num]["seq_ids"].append(centroid)
                # "H" records are hits (non-centroid cluster members)
                elif ldata[0] == "H":
                    hit = ldata[-2]
                    cluster_info[cluster_num]["seq_ids"].append(hit)
        if not debug:
            os.remove(uc_file)
        return cluster_info
    else:
        return uc_file
# """
# # output files
# centroid_file = tempfile.NamedTemporaryFile(
# dir=temp_dir, delete=False, prefix="centroids_"
# ).name
# uc_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False, prefix="uc_").name
# # get the vsearch binary
# if vsearch_bin is None:
# mod_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# vsearch_bin = os.path.join(
# mod_dir, "bin/vsearch_{}".format(platform.system().lower())
# )
# # do clustering
# vsearch_cmd = f"{vsearch_bin} --cluster_fast {fasta_file}"
# vsearch_cmd += f" --centroids {centroid_file}"
# vsearch_cmd += f" --clusterout_id"
# vsearch_cmd += f" --uc {uc_file}"
# vsearch_cmd += f" --id {threshold}"
# vsearch_cmd += f" --iddef {iddef}"
# vsearch_cmd += f" --sizeout"
# vsearch_cmd += f" --strand {strand}"
# p = sp.Popen(vsearch_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
# stdout, stderr = p.communicate()
# if debug:
# print("STDOUT:", stdout.decode("utf-8"))
# print("")
# print("STDERR:", stderr.decode("utf-8"))
# # process output
# if os.stat(uc_file).st_size == 0:
# err = f"WARNING: the VSEARCH output file ({uc_file}) is empty. "
# err += "Please verify that the input data is valid."
# print(err)
# return None
# if as_dict:
# cluster_info = {}
# with open(uc_file, "r") as f:
# for line in f:
# if not (ldata := line.strip().split()):
# continue
# cluster_num = ldata[1]
# if cluster_num not in cluster_info:
# cluster_info[cluster_num] = {"seq_ids": []}
# if ldata[0] == "S":
# centroid = ldata[-2]
# cluster_info[cluster_num]["centroid_id"] = centroid
# cluster_info[cluster_num]["seq_ids"].append(centroid)
# elif ldata[0] == "H":
# hit = ldata[-2]
# cluster_info[cluster_num]["seq_ids"].append(hit)
# if not debug:
# os.remove(uc_file)
# return cluster_info
# else:
# return uc_file
def cluster_mmseqs(
    fasta_file: str,
    threshold: float = 0.975,
    temp_dir: str = "/tmp",
    mmseqs_bin: str = None,
    as_dict: bool = False,
    debug: bool = False,
):
    """
    Clusters sequences using `MMseqs2`_.

    Parameters
    ----------
    fasta_file : str
        Path to a FASTA-formatted file. Required. If you'd like to run ``mmseqs``
        using ``Sequence`` objects as input, use ``cluster(algo="mmseqs")``.

    threshold : float, default=0.975
        Identity threshold for clustering. Must be between 0 and 1.

    temp_dir : str, default="/tmp"
        Path to a directory for temporary storage of clustering files.

    mmseqs_bin : str, optional
        Path to a MMseqs2 executable. If not provided, the MMseqs2 binary bundled
        with ``abutils`` will be used.

    as_dict : bool, default=False
        If ``True``, return clustering results as a ``dict`` rather than the TSV
        output file. The ``dict`` is of the format::

            {"cluster1_name": {"centroid_id": cluster1_centroid_id,
                               "seq_ids": [seq1_id, seq2_id, seq3_id, ...]},
             "cluster2_name": {"centroid_id": cluster2_centroid_id,
                               "seq_ids": [seq4_id, seq5_id, seq6_id, ...]},
            }

    debug : bool, default=False
        If ``True``, prints standard output and standard error from ``mmseqs``,
        and the TSV output file is not deleted. Default is ``False``.

    Returns
    -------
    clusters : Path to the TSV output file from ``mmseqs`` or a ``dict`` of
        cluster info (when ``as_dict=True``).

    .. _MMseqs2
        https://github.com/soedinglab/MMseqs2
    """
    # output files -- delete=False so mmseqs can write to them by path
    db_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False, prefix="DB_").name
    clu_file = tempfile.NamedTemporaryFile(
        dir=temp_dir, delete=False, prefix="CLU_"
    ).name
    # BUGFIX: ``.name`` was missing here, so the NamedTemporaryFile object's
    # repr (not its path) was interpolated into the mmseqs command, and the
    # later open(tsv_file) / os.remove(tsv_file) calls would fail
    tsv_file = tempfile.NamedTemporaryFile(
        dir=temp_dir, delete=False, prefix="TSV_"
    ).name
    # locate the platform-specific mmseqs binary bundled with abutils
    if mmseqs_bin is None:
        mod_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        mmseqs_bin = os.path.join(
            mod_dir, "bin/mmseqs_{}".format(platform.system().lower())
        )
    # build the mmseqs DB
    db_cmd = f"{mmseqs_bin} createdb {fasta_file} {db_file}"
    p = sp.Popen(db_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout, stderr = p.communicate()
    if debug:
        print("STDOUT:", stdout)
        print("")
        print("STDERR:", stderr)
    # do the clustering
    cluster_cmd = f"{mmseqs_bin} cluster"
    cluster_cmd += f" {db_file} {clu_file} {temp_dir}"
    cluster_cmd += f" --min-seq-id {threshold}"
    p = sp.Popen(cluster_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout, stderr = p.communicate()
    if debug:
        print("STDOUT:", stdout)
        print("")
        print("STDERR:", stderr)
    # generate TSV-formatted output (centroid_id <tab> seq_id per line)
    tsv_cmd = f"{mmseqs_bin} createtsv"
    tsv_cmd += f" {db_file} {db_file} {clu_file} {tsv_file}"
    p = sp.Popen(tsv_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout, stderr = p.communicate()
    if debug:
        print("STDOUT:", stdout)
        print("")
        print("STDERR:", stderr)
    # parse TSV output
    if as_dict:
        cluster_info = {}
        # map each centroid ID to a random 8-character cluster name so the
        # return format matches cluster_vsearch's as_dict output
        name_dict = {}
        with open(tsv_file) as f:
            for line in f:
                line = line.strip()
                if line.startswith("#"):
                    continue
                c, s = line.split()
                if c not in name_dict:
                    name = "".join(
                        random.choice(string.ascii_uppercase + string.digits)
                        for _ in range(8)
                    )
                    name_dict[c] = name
                name = name_dict[c]
                if name not in cluster_info:
                    cluster_info[name] = {"centroid_id": c, "seq_ids": []}
                cluster_info[name]["seq_ids"].append(s)
        if not debug:
            os.remove(tsv_file)
        return cluster_info
    else:
        return tsv_file
# """
# # output files
# db_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False, prefix="DB_").name
# clu_file = tempfile.NamedTemporaryFile(
# dir=temp_dir, delete=False, prefix="CLU_"
# ).name
# tsv_file = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False, prefix="TSV_")
# # get the mmseqs binary
# if mmseqs_bin is None:
# mod_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# mmseqs_bin = os.path.join(
# mod_dir, "bin/mmseqs_{}".format(platform.system().lower())
# )
# # build the mmseqs DB
# db_cmd = f"{mmseqs_bin} createdb {fasta_file} {db_file}"
# p = sp.Popen(db_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
# stdout, stderr = p.communicate()
# if debug:
# print("STDOUT:", stdout)
# print("")
# print("STDERR:", stderr)
# # do the clustering
# cluster_cmd = f"{mmseqs_bin} cluster"
# cluster_cmd += f" {db_file} {clu_file} {temp_dir}"
# cluster_cmd += f" --min-seq-id {threshold}"
# p = sp.Popen(cluster_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
# stdout, stderr = p.communicate()
# if debug:
# print("STDOUT:", stdout)
# print("")
# print("STDERR:", stderr)
# # generate TSV-formatted output
# tsv_cmd = f"{mmseqs_bin} createtsv"
# tsv_cmd += f" {db_file} {db_file} {clu_file} {tsv_file}"
# p = sp.Popen(tsv_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
# stdout, stderr = p.communicate()
# if debug:
# print("STDOUT:", stdout)
# print("")
# print("STDERR:", stderr)
# # parse TSV output
# if as_dict:
# cluster_info = {}
# name_dict = {}
# with open(tsv_file) as f:
# for line in f:
# line = line.strip()
# if line.startswith("#"):
# continue
# c, s = line.split()
# if c not in name_dict:
# name = "".join(
# random.choice(string.ascii_uppercase + string.digits)
# for _ in range(8)
# )
# name_dict[c] = name
# name = name_dict[c]
# if name not in cluster_info:
# cluster_info[name] = {"centroid_id": c, "seq_ids": []}
# cluster_info[name]["seq_ids"].append(s)
# if not debug:
# os.remove(tsv_file)
# return cluster_info
# else:
# return tsv_file

@@ -531,0 +531,0 @@

@@ -26,468 +26,468 @@ #!/usr/bin/env python

from __future__ import absolute_import, division, print_function, unicode_literals
# from __future__ import absolute_import, division, print_function, unicode_literals
import logging
import platform
import os
import subprocess as sp
import sys
# import logging
# import platform
# import os
# import subprocess as sp
# import sys
from pymongo import MongoClient
# from pymongo import MongoClient
from . import log
# from . import log
# Python 2/3 compatibility shim: on Python 3 the only string type is ``str``,
# while Python 2 has both ``str`` and ``unicode``.
# NOTE(review): ``unicode`` is undefined on Python 3; the version guard keeps
# this safe, but the else branch is dead code on modern interpreters.
if sys.version_info[0] > 2:
    STR_TYPES = [
        str,
    ]
else:
    STR_TYPES = [str, unicode]
def get_connection(ip="localhost", port=27017, user=None, password=None):
    """
    Returns a pymongo ``MongoClient`` object.

    .. note:

        Both ``user`` and ``password`` are required when connecting to a MongoDB
        database that has authentication enabled.

    Arguments:

        ip (str): IP address of the MongoDB server. Default is ``localhost``.

        port (int): Port of the MongoDB server. Default is ``27017``.

        user (str): Username, if authentication is enabled on the MongoDB database.
            Default is ``None``, which results in requesting the connection
            without authentication.

        password (str): Password, if authentication is enabled on the MongoDB database.
            Default is ``None``, which results in requesting the connection
            without authentication.
    """
    # NOTE(review): connection is deferred (connect=False) on macOS only --
    # presumably to work around a platform-specific pymongo issue; confirm.
    if platform.system().lower() == "darwin":
        connect = False
    else:
        connect = True
    if user and password:
        # BUGFIX: ``urllib.quote_plus`` only exists on Python 2; on Python 3
        # the function lives in ``urllib.parse``
        try:
            from urllib.parse import quote_plus  # Python 3
        except ImportError:
            from urllib import quote_plus  # Python 2
        # escape the password so special characters don't break the URI
        pwd = quote_plus(password)
        uri = "mongodb://{}:{}@{}:{}".format(user, pwd, ip, port)
        return MongoClient(uri, connect=connect)
    return MongoClient(ip, port, connect=connect)
# pwd = urllib.quote_plus(password)
# uri = "mongodb://{}:{}@{}:{}".format(user, pwd, ip, port)
# return MongoClient(uri, connect=connect)
# return MongoClient(ip, port, connect=connect)
def list_dbs(ip="localhost", port=27017, user=None, connect=True, hint=None, password=None):
    """
    Returns a list of database names present on a MongoDB server.

    .. note:

        Both ``user`` and ``password`` are required when connecting to a MongoDB
        database that has authentication enabled.

    Arguments:

        ip (str): IP address of the MongoDB server. Default is ``localhost``.

        port (int): Port of the MongoDB server. Default is ``27017``.

        user (str): Username, if authentication is enabled on the MongoDB database.
            Default is ``None``, which results in requesting the connection
            without authentication.

        connect (bool): Whether to connect immediately. NOTE: currently
            overridden based on the host platform (see below).

        hint (str): Substring found in a database name. If used, only
            databases whose names contain the substring will be returned.

        password (str): Password, if authentication is enabled on the MongoDB
            database. Default is ``None``, which results in requesting the
            connection without authentication.
    """
    # BUGFIX: ``password`` was referenced below but missing from the signature,
    # causing a NameError whenever ``user`` was supplied; it is appended at the
    # end of the parameter list to preserve positional backward compatibility.
    # Darwin says "you shall not connect!" -- always defer connection on macOS
    # NOTE(review): this unconditionally overwrites the ``connect`` argument;
    # it matches get_connection/get_db behavior, but renders the param inert.
    if platform.system().lower() == "darwin":
        connect = False
    else:
        connect = True
    # check for user and pass; if present, build an authenticated URI
    if user and password:
        # BUGFIX: ``urllib.quote_plus`` is Python-2-only; use urllib.parse on 3
        try:
            from urllib.parse import quote_plus  # Python 3
        except ImportError:
            from urllib import quote_plus  # Python 2
        pwd = quote_plus(password)
        uri = "mongodb://{}:{}@{}:{}".format(user, pwd, ip, port)
        # check to see if hint was passed in and return list of dbs with hint
        if hint and isinstance(hint, str):
            ls = MongoClient(uri, connect=connect).list_database_names()
            return [d for d in ls if hint.upper() in d.upper()]
        # otherwise return a list of all dbs
        else:
            return MongoClient(uri, connect=connect).list_database_names()
    # otherwise no user or pass required
    else:
        # check to see if hint was passed in and return list of dbs with hint
        if hint and isinstance(hint, str):
            ls = MongoClient(ip, port, connect=connect).list_database_names()
            return [d for d in ls if hint.upper() in d.upper()]
        # otherwise return a list of all dbs
        else:
            return MongoClient(ip, port, connect=connect).list_database_names()
# pwd = urllib.quote_plus(password)
# uri = "mongodb://{}:{}@{}:{}".format(user, pwd, ip, port)
# # check to see if hint was passed in and return list of dbs with hint
# if hint and type(hint) == str:
# ls = MongoClient(uri, connect=connect).list_database_names()
# return [d for d in ls if hint.upper() in d.upper()]
# # otherwise return a list of all dbs
# else:
# return MongoClient(uri, connect=connect).list_database_names()
# # otherwise no user or pass required
# else:
# # check to see if hint was passed in and return list of dbs with hint
# if hint and type(hint) == str:
# ls = MongoClient(ip, port, connect=connect).list_database_names()
# return [d for d in ls if hint.upper() in d.upper()]
# # otherwise return a list of all dbs
# else:
# return MongoClient(ip, port, connect=connect).list_database_names()
def get_db(db, ip="localhost", port=27017, user=None, password=None):
    """
    Returns a pymongo ``Database`` object.

    .. note:

        Both ``user`` and ``password`` are required when connecting to a MongoDB
        database that has authentication enabled.

    Arguments:

        db (str): Name of the MongoDB database. Required.

        ip (str): IP address of the MongoDB server. Default is ``localhost``.

        port (int): Port of the MongoDB server. Default is ``27017``.

        user (str): Username, if authentication is enabled on the MongoDB database.
            Default is ``None``, which results in requesting the connection
            without authentication.

        password (str): Password, if authentication is enabled on the MongoDB database.
            Default is ``None``, which results in requesting the connection
            without authentication.
    """
    # defer connection on macOS (matches get_connection/list_dbs behavior)
    if platform.system().lower() == "darwin":
        connect = False
    else:
        connect = True
    if user and password:
        # BUGFIX: ``urllib.quote_plus`` is Python-2-only; use urllib.parse on 3
        try:
            from urllib.parse import quote_plus  # Python 3
        except ImportError:
            from urllib import quote_plus  # Python 2
        # escape the password so special characters don't break the URI
        pwd = quote_plus(password)
        uri = "mongodb://{}:{}@{}:{}".format(user, pwd, ip, port)
        conn = MongoClient(uri, connect=connect)
    else:
        conn = MongoClient(ip, port, connect=connect)
    return conn[db]
# password (str): Password, if authentication is enabled on the MongoDB database.
# Default is ``None``, which results in requesting the connection
# without authentication.
# """
# if platform.system().lower() == "darwin":
# connect = False
# else:
# connect = True
# if user and password:
# import urllib
# pwd = urllib.quote_plus(password)
# uri = "mongodb://{}:{}@{}:{}".format(user, pwd, ip, port)
# conn = MongoClient(uri, connect=connect)
# else:
# conn = MongoClient(ip, port, connect=connect)
# return conn[db]
def get_collections(db, collection=None, prefix=None, suffix=None):
    """
    Returns a sorted list of collection names found in ``db``.

    Arguments:

        db (Database): A pymongo Database object. Can be obtained
            with ``get_db``.

        collection (str): Name of a collection. If the collection is
            present in the MongoDB database, a single-element list will
            be returned with the collection name. If not, an empty list
            will be returned. This option is primarily included to allow
            for quick checking to see if a collection name is present.
            Default is None, which results in this option being ignored.

        prefix (str): If supplied, only collections that begin with
            ``prefix`` will be returned.

        suffix (str): If supplied, only collections that end with
            ``suffix`` will be returned.

    Returns:

        list: A sorted list of collection names.
    """
    if collection is not None:
        return [
            collection,
        ]
    # BUGFIX/compat: ``Database.collection_names()`` was deprecated in pymongo
    # 3.7 and removed in pymongo 4.0; prefer ``list_collection_names()`` and
    # fall back to the legacy API for very old pymongo versions
    if hasattr(db, "list_collection_names"):
        collections = db.list_collection_names()
    else:
        collections = db.collection_names(include_system_collections=False)
    if prefix is not None:
        collections = [c for c in collections if c.startswith(prefix)]
    if suffix is not None:
        collections = [c for c in collections if c.endswith(suffix)]
    return sorted(collections)
# Returns:
# list: A sorted list of collection names.
# """
# if collection is not None:
# return [
# collection,
# ]
# collections = db.collection_names(include_system_collections=False)
# if prefix is not None:
# collections = [c for c in collections if c.startswith(prefix)]
# if suffix is not None:
# collections = [c for c in collections if c.endswith(suffix)]
# return sorted(collections)
def rename_collection(db, collection, new_name):
    """
    Renames a MongoDB collection.

    Arguments:

        db (Database): A pymongo ``Database`` object. Can be obtained
            with ``get_db``.

        collection (str): Name of the collection to be renamed.

        new_name (str, func): ``new_name`` can be one of two things::

            1. The new collection name, as a string.
            2. A function which, when passed the current collection name,
               returns the new collection name. If the function
               returns an empty string, the collection will not be
               renamed.
    """
    # resolve the target name, invoking new_name if it is a callable
    if callable(new_name):
        target = new_name(collection)
        # an empty string from the naming function means "skip this one"
        if target == "":
            return
    else:
        target = new_name
    db[collection].rename(target)
# new_name (str, func): ``new_name`` can be one of two things::
# 1. The new collection name, as a string.
# 2. A function which, when passed the current collection name,
# returns the new collection name. If the function
# returns an empty string, the collection will not be
# renamed.
# """
# if hasattr(new_name, "__call__"):
# _new = new_name(collection)
# if _new == "":
# return
# else:
# _new = new_name
# c = db[collection]
# c.rename(_new)
def update(field, value, db, collection, match=None):
    """
    Updates MongoDB documents.

    Sets ``field`` equal to ``value`` for all documents that
    meet ``match`` criteria.

    Arguments:

        field (str): Field to update.

        value (str): Update value.

        db (Database): A pymongo ``Database`` object.

        collection (str): Collection name.

        match (dict): A dictionary containing the match criteria, for example::

            {'seq_id': {'$in': ['a', 'b', 'c']}, 'cdr3_len': {'$gte': 18}}
    """
    coll = db[collection]
    criteria = {} if match is None else match
    update_doc = {"$set": {field: value}}
    # MongoDB 2.x servers require the legacy multi-document update call
    server_version = db.client.server_info()["version"]
    if server_version.startswith("2"):
        coll.update(criteria, update_doc, multi=True)
    else:
        coll.update_many(criteria, update_doc)
# match (dict): A dictionary containing the match criteria, for example::
# {'seq_id': {'$in': ['a', 'b', 'c']}, 'cdr3_len': {'$gte': 18}}
# """
# c = db[collection]
# match = match if match is not None else {}
# # check MongoDB version to use appropriate update command
# if db.client.server_info()["version"].startswith("2"):
# c.update(match, {"$set": {field: value}}, multi=True)
# else:
# c.update_many(match, {"$set": {field: value}})
def unset(db, collection, field, match=None):
    """
    Removes ``field`` from all records in ``collection`` that meet
    ``match`` criteria.

    Arguments:

        db (Database): A pymongo Database object.

        collection (str): Collection name.

        field (str): Field to be removed.

        match (dict): A dictionary containing the match criteria, for example::

            {'seq_id': {'$in': ['a', 'b', 'c']}, 'cdr3_len': {'$gte': 18}}
    """
    coll = db[collection]
    criteria = {} if match is None else match
    removal = {"$unset": {field: ""}}
    # MongoDB 2.x servers require the legacy multi-document update call
    if db.client.server_info()["version"].startswith("2"):
        coll.update(criteria, removal, multi=True)
    else:
        coll.update_many(criteria, removal)
# match (dict): A dictionary containing the match criteria, for example::
# {'seq_id': {'$in': ['a', 'b', 'c']}, 'cdr3_len': {'$gte': 18}}
# """
# c = db[collection]
# match = match if match is not None else {}
# # check MongoDB version to use appropriate update command
# if db.client.server_info()["version"].startswith("2"):
# c.update(match, {"$unset": {field: ""}}, multi=True)
# else:
# c.update_many(match, {"$unset": {field: ""}})
def mongoimport(
json,
database,
ip="localhost",
port=27017,
user=None,
password=None,
delim="_",
delim1=None,
delim2=None,
delim_occurance=1,
delim1_occurance=1,
delim2_occurance=1,
):
"""
Performs mongoimport on one or more json files.
Args:
# def mongoimport(
# json,
# database,
# ip="localhost",
# port=27017,
# user=None,
# password=None,
# delim="_",
# delim1=None,
# delim2=None,
# delim_occurance=1,
# delim1_occurance=1,
# delim2_occurance=1,
# ):
# """
# Performs mongoimport on one or more json files.
json: Can be one of several things:
# Args:
- path to a single JSON file
- an iterable (list or tuple) of one or more JSON file paths
- path to a directory containing one or more JSON files
# json: Can be one of several things:
database (str): Name of the database into which the JSON files
will be imported
# - path to a single JSON file
# - an iterable (list or tuple) of one or more JSON file paths
# - path to a directory containing one or more JSON files
ip (str): IP address of the MongoDB server. Default is ``localhost``.
# database (str): Name of the database into which the JSON files
# will be imported
port (int): Port of the MongoDB database. Default is ``27017``.
# ip (str): IP address of the MongoDB server. Default is ``localhost``.
user (str): Username for the MongoDB database, if authentication is enabled.
Default is ``None``, which results in attempting connection without
authentication.
# port (int): Port of the MongoDB database. Default is ``27017``.
password (str): Password for the MongoDB database, if authentication is enabled.
Default is ``None``, which results in attempting connection without
authentication.
# user (str): Username for the MongoDB database, if authentication is enabled.
# Default is ``None``, which results in attempting connection without
# authentication.
delim (str): Delimiter, when generating collection names using a single delimiter.
Default is ``_``
# password (str): Password for the MongoDB database, if authentication is enabled.
# Default is ``None``, which results in attempting connection without
# authentication.
delim_occurance (int): Occurance at which to split filename when using a
single delimiter. Default is ``1``
# delim (str): Delimiter, when generating collection names using a single delimiter.
# Default is ``_``
delim1 (str): Left delimiter when splitting with two delimiters. Default is None.
# delim_occurance (int): Occurance at which to split filename when using a
# single delimiter. Default is ``1``
delim1_occurance (int): Occurance of ``delim1`` at which to split filename.
Default is ``1``
# delim1 (str): Left delimiter when splitting with two delimiters. Default is None.
delim2 (str): Right delimiter when splitting with two delimiters. Default is None.
# delim1_occurance (int): Occurance of ``delim1`` at which to split filename.
# Default is ``1``
delim2_occurance (int): Occurance of ``delim2`` at which to split filename.
Default is ``1``
"""
logger = log.get_logger("mongodb")
_print_mongoimport_info(logger)
if type(json) in (list, tuple):
pass
elif os.path.isdir(json):
from abtools.utils.pipeline import list_files
# delim2 (str): Right delimiter when splitting with two delimiters. Default is None.
json = list_files(json)
else:
json = [
json,
]
jsons = sorted([os.path.expanduser(j) for j in json if j.endswith(".json")])
collections = _get_import_collections(
jsons,
delim,
delim_occurance,
delim1,
delim1_occurance,
delim2,
delim2_occurance,
)
logger.info("Found {} files to import".format(len(jsons)))
logger.info("")
for i, (json_file, collection) in enumerate(zip(jsons, collections)):
logger.info(
"[ {} ] {} --> {}".format(i + 1, os.path.basename(json_file), collection)
)
# logger.info("Performing mongoimport on {}.".format(os.path.basename(json_file)))
# logger.info("Importing the file into collection {}.".format(collection))
if all([user, password]):
host = "--host {} --port {} -username {} -password {}".format(
ip, port, user, password
)
else:
host = "--host {} --port {}".format(ip, port)
mongo_cmd = "mongoimport {} --db {} --collection {} --file {}".format(
host, database, collection, json_file
)
mongo = sp.Popen(mongo_cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
stdout, stderr = mongo.communicate()
# delim2_occurance (int): Occurance of ``delim2`` at which to split filename.
# Default is ``1``
# """
# logger = log.get_logger("mongodb")
# _print_mongoimport_info(logger)
# if type(json) in (list, tuple):
# pass
# elif os.path.isdir(json):
# from abtools.utils.pipeline import list_files
# json = list_files(json)
# else:
# json = [
# json,
# ]
# jsons = sorted([os.path.expanduser(j) for j in json if j.endswith(".json")])
# collections = _get_import_collections(
# jsons,
# delim,
# delim_occurance,
# delim1,
# delim1_occurance,
# delim2,
# delim2_occurance,
# )
# logger.info("Found {} files to import".format(len(jsons)))
# logger.info("")
# for i, (json_file, collection) in enumerate(zip(jsons, collections)):
# logger.info(
# "[ {} ] {} --> {}".format(i + 1, os.path.basename(json_file), collection)
# )
# # logger.info("Performing mongoimport on {}.".format(os.path.basename(json_file)))
# # logger.info("Importing the file into collection {}.".format(collection))
# if all([user, password]):
# host = "--host {} --port {} -username {} -password {}".format(
# ip, port, user, password
# )
# else:
# host = "--host {} --port {}".format(ip, port)
# mongo_cmd = "mongoimport {} --db {} --collection {} --file {}".format(
# host, database, collection, json_file
# )
# mongo = sp.Popen(mongo_cmd, shell=True, stdout=sp.PIPE, stderr=sp.PIPE)
# stdout, stderr = mongo.communicate()
def index(db, collection, fields, directions=None, desc=False, background=False):
    """
    Builds a simple (single field) or complex (multiple fields) index
    on a single collection in a MongoDB database.

    Args:

        db (Database): A pymongo Database object.

        collection (str): Collection name.

        fields: Can be one of two things:

            - the name of a single field, as a string
            - an iterable (list/tuple) of one or more field names

        desc (bool): If ``True``, all indexes will be created in descending order.
            Default is ``False``.

        directions (list): For complex indexes for which you'd like to have
            different indexing directions (ascending for some fields, descending
            for others), you can pass a list of pymongo direction objects (
            pymongo.ASCENDING and pymongo.DESCENDING), in the same order as the
            list of fields to be indexed. Must be the same length as the list
            of index fields. Default is ``None``.

        background (bool): If ``True``, the indexing operation will be processed
            in the background. When performing background indexes, the MongoDB
            database will not be locked.
    """
    import pymongo

    # a single field name may be passed as a bare string
    if isinstance(fields, str):
        fields = [
            fields,
        ]
    # if per-field directions were not supplied, apply the same direction
    # (ascending by default, descending if ``desc``) to every field
    if directions is None:
        _dir = pymongo.DESCENDING if desc else pymongo.ASCENDING
        directions = [_dir] * len(fields)
    field_tuples = list(zip(fields, directions))
    coll = db[collection]
    coll.create_index(field_tuples, background=background)
def remove_padding(db, collection, field="padding"):
    """
    Removes a padding field from all documents in a MongoDB collection.

    Args:

        db (Database): A pymongo Database object.

        collection (str): Collection name

        field (str): Name of the padding field. Default is ``padding``
    """
    # delegates to the module-level ``unset()`` helper, which performs the
    # actual ``$unset`` update across the collection
    unset(db, collection, field=field)
def _get_import_collections(
jsons, delim, delim_occurance, delim1, delim1_occurance, delim2, delim2_occurance
):
jnames = [os.path.basename(j) for j in jsons]
if not all([delim1, delim2]):
collections = [delim.join(j.split(delim)[:delim_occurance]) for j in jnames]
else:
pre_colls = [delim1.join(j.split(delim1)[delim1_occurance:]) for j in jnames]
collections = [
delim2.join(j.split(delim2)[:delim2_occurance]) for j in pre_colls
]
return collections
# def _get_import_collections(
# jsons, delim, delim_occurance, delim1, delim1_occurance, delim2, delim2_occurance
# ):
# jnames = [os.path.basename(j) for j in jsons]
# if not all([delim1, delim2]):
# collections = [delim.join(j.split(delim)[:delim_occurance]) for j in jnames]
# else:
# pre_colls = [delim1.join(j.split(delim1)[delim1_occurance:]) for j in jnames]
# collections = [
# delim2.join(j.split(delim2)[:delim2_occurance]) for j in pre_colls
# ]
# return collections
def _print_mongoimport_info(logger):
logger.info("")
logger.info("")
logger.info("")
logger.info("-" * 25)
logger.info("MONGOIMPORT")
logger.info("-" * 25)
logger.info("")
# def _print_mongoimport_info(logger):
# logger.info("")
# logger.info("")
# logger.info("")
# logger.info("-" * 25)
# logger.info("MONGOIMPORT")
# logger.info("-" * 25)
# logger.info("")
def _print_remove_padding():
    # Announce the padding-removal step via the shared "mongodb" logger.
    log.get_logger("mongodb").info("Removing MongoDB padding...")
# def _print_remove_padding():
# logger = log.get_logger("mongodb")
# logger.info("Removing MongoDB padding...")

@@ -26,970 +26,970 @@ #!/usr/bin/env python

from __future__ import absolute_import, division, print_function, unicode_literals
# from __future__ import absolute_import, division, print_function, unicode_literals
import colorsys
import itertools
import math
import multiprocessing as mp
import os
import platform
import random
import shutil
import string
import subprocess as sp
import sys
import tempfile
from collections import Counter
from copy import copy, deepcopy
from typing import Optional
# import colorsys
# import itertools
# import math
# import multiprocessing as mp
# import os
# import platform
# import random
# import shutil
# import string
# import subprocess as sp
# import sys
# import tempfile
# from collections import Counter
# from copy import copy, deepcopy
# from typing import Optional
# import abstar
import ete3
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import seaborn as sns
# # import abstar
# import ete3
# import matplotlib.pyplot as plt
# import numpy as np
# import pandas as pd
# import scipy
# import seaborn as sns
# from abstar.core.germline import get_imgt_germlines
from Bio import AlignIO, Phylo
from matplotlib.colors import ListedColormap
from scipy.ndimage import gaussian_filter1d
from scipy.special import rel_entr
# # from abstar.core.germline import get_imgt_germlines
# from Bio import AlignIO, Phylo
# from matplotlib.colors import ListedColormap
# from scipy.ndimage import gaussian_filter1d
# from scipy.special import rel_entr
from ..core.pair import Pair
from ..core.sequence import Sequence
from .alignment import mafft, muscle
from .cluster import cluster
from .color import get_cmap, hex_to_rgb
from .decorators import lazy_property
from .pipeline import make_dir
# from ..core.pair import Pair
# from ..core.sequence import Sequence
# from .alignment import mafft, muscle
# from .cluster import cluster
# from .color import get_cmap, hex_to_rgb
# from .decorators import lazy_property
# from .pipeline import make_dir
# String types used in ``type(x) in STR_TYPES`` checks. This module requires
# Python 3 (f-strings are used below), so the former Python 2 branch that
# referenced the undefined-in-Py3 ``unicode`` name was unreachable; ``str``
# is the only string type.
STR_TYPES = [
    str,
]
# if sys.version_info[0] > 2:
# STR_TYPES = [
# str,
# ]
# else:
# STR_TYPES = [str, unicode]
# --------------------------------
# PHYLOGENETIC RECONSTRUCTION
# --------------------------------
# # --------------------------------
# # PHYLOGENETIC RECONSTRUCTION
# # --------------------------------
def fasttree(
    aln: str,
    tree_file: Optional[str] = None,
    is_aa: bool = False,
    fasttree_bin: Optional[str] = None,
    debug: bool = False,
    quiet: bool = True,
) -> str:
    """
    Computes a tree file from a multiple seqeunce alignment using `FastTree`_.

    Parameters
    ----------
    aln : str
        Path to a multiple sequence alignment file, in FASTA format, or a
        FASTA-formatted multiple sequence alignment string. Required.

    tree_file : str
        Path to the tree file which will be output by FastTree. If the parent
        directory does not exist, it will be created. If not provided, the output
        (a Newick-formatted tree file) will be returned as a ``str``.

    is_aa : bool, default=False
        Must be set to ``True`` if the input multiple sequence alignment contains
        amino acid sequences. Default is ``False``, meaning FastTree will expect
        nucleotide sequences.

    fasttree_bin : str, optional
        Path to the desired FastTree binary. Default is to use the version of
        FastTree that is bundled with ``abutils``.

    debug : bool, default=False
        If ``True``, verbose output is printed. Default is False.

    quiet : bool, default=True
        Depricated, but retained for backwards compatibility. Use `debug` instead.

    Returns
    -------
    tree_file: str
        Path to the tree file produced by FastTree, or the Newick-formatted
        tree string if `tree_file` was not provided.

    .. _FastTree:
        http://www.microbesonline.org/fasttree/
    """
    # process input: ``aln`` may be a file path or a FASTA-formatted string;
    # strings are written to a temporary file for FastTree to consume
    if os.path.isfile(aln):
        alignment_file = os.path.abspath(aln)
    else:
        ff = tempfile.NamedTemporaryFile(delete=False)
        ff.close()
        alignment_file = ff.name
        with open(alignment_file, "w") as f:
            f.write(aln)
    # ``quiet`` is deprecated -- ``quiet=False`` implies ``debug=True``
    if not quiet:
        debug = True
    # set the FastTree binary (default: platform-specific binary bundled with abutils)
    if fasttree_bin is None:
        mod_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        fasttree_bin = os.path.join(
            mod_dir, f"bin/fasttree_{platform.system().lower()}"
        )
    # make output directory if necessary
    if tree_file is None:
        as_file = False
        tree_file = tempfile.NamedTemporaryFile(delete=False).name
    else:
        as_file = True
        tree_file = os.path.abspath(tree_file)
        if not os.path.isdir(os.path.dirname(tree_file)):
            make_dir(os.path.dirname(tree_file))
    # run FastTree using the selected binary
    # (fixed: the command previously hard-coded ``fasttree`` from $PATH,
    # silently ignoring ``fasttree_bin`` and the bundled binary)
    if is_aa:
        ft_cmd = "{} {} > {}".format(fasttree_bin, alignment_file, tree_file)
    else:
        ft_cmd = "{} -nt {} > {}".format(fasttree_bin, alignment_file, tree_file)
    ft = sp.Popen(ft_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout, stderr = ft.communicate()
    if debug:
        print(ft_cmd)
        print(stdout)
        print(stderr)
    # output: return the tree file path if one was supplied by the caller,
    # otherwise read the temporary tree file and return its contents
    if as_file:
        return tree_file
    else:
        with open(tree_file, "r") as f:
            tree_string = f.read()
        return tree_string
# .. _FastTree:
# http://www.microbesonline.org/fasttree/
# """
# # process input
# if os.path.isfile(aln):
# alignment_file = os.path.abspath(aln)
# else:
# ff = tempfile.NamedTemporaryFile(delete=False)
# ff.close()
# alignment_file = ff.name
# with open(alignment_file, "w") as f:
# f.write(aln)
# if not quiet:
# debug = True
# # set the FastTree binary
# if fasttree_bin is None:
# mod_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
# fasttree_bin = os.path.join(
# mod_dir, f"bin/fasttree_{platform.system().lower()}"
# )
# # make output directory if necessary
# if tree_file is None:
# as_file = False
# tree_file = tempfile.NamedTemporaryFile(delete=False).name
# else:
# as_file = True
# tree_file = os.path.abspath(tree_file)
# if not os.path.isdir(os.path.dirname(tree_file)):
# make_dir(os.path.dirname(tree_file))
# # run FastTree
# if is_aa:
# ft_cmd = "fasttree {} > {}".format(alignment_file, tree_file)
# else:
# ft_cmd = "fasttree -nt {} > {}".format(alignment_file, tree_file)
# ft = sp.Popen(ft_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
# stdout, stderr = ft.communicate()
# if debug:
# print(ft_cmd)
# print(stdout)
# print(stderr)
# # output
# if as_file:
# return tree_file
# else:
# with open(tree_file, "r") as f:
# tree_string = f.read()
# return tree_string
def lsd(
    tree,
    output_file=None,
    dates_file=None,
    outgroup_file=None,
    with_constraints=True,
    with_weights=True,
    reestimate_root_position=None,
    quiet=True,
):
    # Assemble the LSD (least-squares dating) command piece by piece, adding
    # each optional flag only when the corresponding argument was supplied.
    cmd_parts = ["lsd -i {}".format(os.path.abspath(tree))]
    if output_file is not None:
        cmd_parts.append("-o {}".format(os.path.abspath(output_file)))
    if dates_file is not None:
        cmd_parts.append("-d {}".format(os.path.abspath(dates_file)))
    if outgroup_file is not None:
        cmd_parts.append("-g {}".format(os.path.abspath(outgroup_file)))
    if with_constraints:
        cmd_parts.append("-c")
    if with_weights:
        cmd_parts.append("-v")
    if reestimate_root_position is not None:
        cmd_parts.append("-r {}".format(reestimate_root_position))
    lsd_cmd = " ".join(cmd_parts)
    proc = sp.Popen(lsd_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout, stderr = proc.communicate()
    if not quiet:
        print(lsd_cmd)
        print(stdout)
        print(stderr)
    # NOTE(review): returns ``output_file`` exactly as passed by the caller
    # (not the absolute path used in the command)
    return output_file
# def lsd(
# tree,
# output_file=None,
# dates_file=None,
# outgroup_file=None,
# with_constraints=True,
# with_weights=True,
# reestimate_root_position=None,
# quiet=True,
# ):
# lsd_cmd = "lsd -i {}".format(os.path.abspath(tree))
# if output_file is not None:
# lsd_cmd += " -o {}".format(os.path.abspath(output_file))
# if dates_file is not None:
# lsd_cmd += " -d {}".format(os.path.abspath(dates_file))
# if outgroup_file is not None:
# lsd_cmd += " -g {}".format(os.path.abspath(outgroup_file))
# if with_constraints:
# lsd_cmd += " -c"
# if with_weights:
# lsd_cmd += " -v"
# if reestimate_root_position is not None:
# lsd_cmd += " -r {}".format(reestimate_root_position)
# p = sp.Popen(lsd_cmd, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
# stdout, stderr = p.communicate()
# if not quiet:
# print(lsd_cmd)
# print(stdout)
# print(stderr)
# return output_file
def igphyml(
    input_file: str = None,
    tree_file: str = None,
    root: str = None,
    verbose: bool = False,
) -> str:
    """
    Computes a phylogenetic tree using IgPhyML.

    .. note::

        IgPhyML must be installed. It can be downloaded from https://github.com/kbhoehn/IgPhyML.

    Args:

        input_file (str): Path to a Phylip-formatted multiple sequence alignment. Required.

        tree_file (str): Path to the output tree file.

        root (str): Name of the root sequence. Required.

        verbose (bool): If `True`, prints the standard output and standard error for each IgPhyML run.
            Default is `False`.

    Returns:

        str: Path to the tree file produced by IgPhyML
            (``tree_file`` + ``"_igphyml_tree.txt"``).

    Raises:

        RuntimeError: If the ``igphyml`` executable cannot be found on $PATH.
    """
    if shutil.which("igphyml") is None:
        raise RuntimeError(
            "It appears that IgPhyML is not installed.\nPlease install and try again."
        )
    # first, tree topology is estimated with the M0/GY94 model
    # (fixed: the format argument previously referenced an undefined
    # ``aln_file`` name, raising NameError)
    igphyml_cmd1 = "igphyml -i {} -m GY -w M0 -t e --run_id gy94".format(input_file)
    # shell=True so the command string is interpreted by the shell, matching
    # how the sibling fasttree()/lsd() helpers invoke their binaries
    p1 = sp.Popen(igphyml_cmd1, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout1, stderr1 = p1.communicate()
    if verbose:
        # subprocess pipes return bytes; decode before concatenating with str
        print(stdout1.decode() + "\n")
        print(stderr1.decode() + "\n\n")
    # intermediate topology file written by the first IgPhyML run
    intermediate = input_file + "_igphyml_tree.txt_gy94"
    # now we fit the HLP17 model once the tree topology is fixed
    # (fixed: the format string mixed manual ``{0}``/``{1}`` and automatic
    # ``{}`` numbering, which raises ValueError, and the ``intermediate``
    # path was computed but never used)
    igphyml_cmd2 = "igphyml -i {} -m HLP17 --root {} -o lr -u {} -o {}".format(
        input_file, root, intermediate, tree_file
    )
    p2 = sp.Popen(igphyml_cmd2, stdout=sp.PIPE, stderr=sp.PIPE, shell=True)
    stdout2, stderr2 = p2.communicate()
    if verbose:
        print(stdout2.decode() + "\n")
        print(stderr2.decode() + "\n")
    return tree_file + "_igphyml_tree.txt"
# # now we fit the HLP17 model once the tree topology is fixed
# igphyml_cmd2 = "igphyml -i {0} -m HLP17 --root {1} -o lr -u {}_igphyml_tree.txt_gy94 -o {}".format(
# input_file, root, tree_file
# )
# p2 = sp.Popen(igphyml_cmd2, stdout=sp.PIPE, stderr=sp.PIPE)
# stdout2, stderr2 = p2.communicate()
# if verbose:
# print(stdout2 + "\n")
# print(stderr2 + "\n")
# return tree_file + "_igphyml_tree.txt"
# --------------------------------
# PHYLOGENETIC TREES
# --------------------------------
# # --------------------------------
# # PHYLOGENETIC TREES
# # --------------------------------
def phylogeny(
sequences=None,
project_dir=None,
name=None,
aln_file=None,
tree_file=None,
seq_field=None,
name_field=None,
aa=False,
species="human",
unrooted=False,
ladderize=True,
root=None,
root_name=None,
show_root_name=False,
color_dict=None,
color_function=None,
order_dict=None,
order_function=None,
color_node_labels=False,
label_colors=None,
scale=None,
branch_vert_margin=None,
fontsize=12,
show_names=True,
show_scale=False,
mirror=False,
min_order_fraction=0.1,
figname_prefix=None,
figname_suffix=None,
# linked_alignment=None, alignment_fontsize=11, alignment_height=50, alignment_width=50,
compact_alignment=False,
scale_factor=1,
rename_function=None,
linewidth=1.0,
delete_nodes=None,
quiet=True,
):
"""
Generates a lineage phylogeny figure.
# def phylogeny(
# sequences=None,
# project_dir=None,
# name=None,
# aln_file=None,
# tree_file=None,
# seq_field=None,
# name_field=None,
# aa=False,
# species="human",
# unrooted=False,
# ladderize=True,
# root=None,
# root_name=None,
# show_root_name=False,
# color_dict=None,
# color_function=None,
# order_dict=None,
# order_function=None,
# color_node_labels=False,
# label_colors=None,
# scale=None,
# branch_vert_margin=None,
# fontsize=12,
# show_names=True,
# show_scale=False,
# mirror=False,
# min_order_fraction=0.1,
# figname_prefix=None,
# figname_suffix=None,
# # linked_alignment=None, alignment_fontsize=11, alignment_height=50, alignment_width=50,
# compact_alignment=False,
# scale_factor=1,
# rename_function=None,
# linewidth=1.0,
# delete_nodes=None,
# quiet=True,
# ):
# """
# Generates a lineage phylogeny figure.
Args:
# Args:
sequences (list(Sequence)): A list of ``Sequence`` objects from which a phylogeny
will be calculated. Strictly speaking, they do not need to be ``Sequence`` objects,
rather, any object that contains the sequence name as the ``id`` attribute (or
by dictionary-style lookup using the provided ``name_field``) and contains the
sequence as the ``sequence`` attribute (or by dictionary-stype lookup using the
provided ``seq_field``).
# sequences (list(Sequence)): A list of ``Sequence`` objects from which a phylogeny
# will be calculated. Strictly speaking, they do not need to be ``Sequence`` objects,
# rather, any object that contains the sequence name as the ``id`` attribute (or
# by dictionary-style lookup using the provided ``name_field``) and contains the
# sequence as the ``sequence`` attribute (or by dictionary-stype lookup using the
# provided ``seq_field``).
project_dir (str): directory into which all phylogeny files will be deposited,
including alignment, tree and figure files.
# project_dir (str): directory into which all phylogeny files will be deposited,
# including alignment, tree and figure files.
name (str): Name to be used for naming alignment, tree, and phylogeny files. If not
provided, a random name will be generated.
# name (str): Name to be used for naming alignment, tree, and phylogeny files. If not
# provided, a random name will be generated.
aln_file (str): if a multiple sequence alignment has already been calculated,
passing the path to the alignment file (in FASTA format) will force Lineage.phylogeny()
to use the supplied msa instead of computing a new one.
# aln_file (str): if a multiple sequence alignment has already been calculated,
# passing the path to the alignment file (in FASTA format) will force Lineage.phylogeny()
# to use the supplied msa instead of computing a new one.
tree_file (str): if a tree file has already been calculated, passing the path
to the pre-computed tree file will force ``phylogeny()`` to use
the supplied tree file instead of computing a new one. It is important to note that
only sequence names will be parsed from the tree_file, so if ``order_function`` or
``color_function`` is also provided, ensure that these functions only require the
sequence ID rather than the entire sequence.
# tree_file (str): if a tree file has already been calculated, passing the path
# to the pre-computed tree file will force ``phylogeny()`` to use
# the supplied tree file instead of computing a new one. It is important to note that
# only sequence names will be parsed from the tree_file, so if ``order_function`` or
# ``color_function`` is also provided, ensure that these functions only require the
# sequence ID rather than the entire sequence.
aa (bool): if True, use amino acid sequences to compute the phylogeny.
Default is False.
# aa (bool): if True, use amino acid sequences to compute the phylogeny.
# Default is False.
root (Sequence, str: The root can be provided either as a ``Sequence`` object (if ``sequences``
are being provided) or as the name of a sequence that can be found either in
``sequences`` or in the provided ``aln_file`` or ``tree_file``. Note that if
either ``aln_file`` or ``tree_file`` are provided, the root must be provided
as the sequence name, not as a ``Sequence`` object (as the root sequence must
already be included in either ``aln_file`` or ``tree_file``. If the root is not
provided, the germline V-gene sequence of the
# root (Sequence, str: The root can be provided either as a ``Sequence`` object (if ``sequences``
# are being provided) or as the name of a sequence that can be found either in
# ``sequences`` or in the provided ``aln_file`` or ``tree_file``. Note that if
# either ``aln_file`` or ``tree_file`` are provided, the root must be provided
# as the sequence name, not as a ``Sequence`` object (as the root sequence must
# already be included in either ``aln_file`` or ``tree_file``. If the root is not
# provided, the germline V-gene sequence of the
color_dict (dict): Dictionary with sequence IDs as keys and colors (hex format) as values. If any
sequence IDs are not found in the dict, they will be colored black. If neither ``color_dict`` nor
``color_function`` is provided, all leaves will be colored black.
# color_dict (dict): Dictionary with sequence IDs as keys and colors (hex format) as values. If any
# sequence IDs are not found in the dict, they will be colored black. If neither ``color_dict`` nor
# ``color_function`` is provided, all leaves will be colored black.
color_function (func): Function that that accepts a ``Sequence`` object and returns the color
(as a hex code). If ``color_dict`` is also provided, ``color_function`` is ignored. Additionally,
``color_function`` will only be used if ``sequences`` are provided. If ``sequences`` are not provided
(instead using ``aln_file` or ``tree_file``), ``color_dict`` must be used instead of ``color_function``.
# color_function (func): Function that that accepts a ``Sequence`` object and returns the color
# (as a hex code). If ``color_dict`` is also provided, ``color_function`` is ignored. Additionally,
# ``color_function`` will only be used if ``sequences`` are provided. If ``sequences`` are not provided
# (instead using ``aln_file` or ``tree_file``), ``color_dict`` must be used instead of ``color_function``.
orders: a dictionary with sequence IDs as keys and orders (integers) as values.
If not provided, only the leaf branches will be colored (if <colors> or
<color_function> is provided).
# orders: a dictionary with sequence IDs as keys and orders (integers) as values.
# If not provided, only the leaf branches will be colored (if <colors> or
# <color_function> is provided).
chain: build a phylogeny using the given chain ('heavy' or 'light').
Default is 'heavy'.
# chain: build a phylogeny using the given chain ('heavy' or 'light').
# Default is 'heavy'.
filter_function: function used to filter sequences (identity-based clustering, for
example). The function should accept a list of Sequence objects and return
a list of Sequence objects.
# filter_function: function used to filter sequences (identity-based clustering, for
# example). The function should accept a list of Sequence objects and return
# a list of Sequence objects.
just_pairs: if True, compute the phylogeny using only paired sequences.
Default (False) will use all sequences of the appropriate chain, paired or not.
# just_pairs: if True, compute the phylogeny using only paired sequences.
# Default (False) will use all sequences of the appropriate chain, paired or not.
scale (float): passed to ete3.TreeStyle() to set the scale of the tree figure. Increased
scale results in a wider tree.
# scale (float): passed to ete3.TreeStyle() to set the scale of the tree figure. Increased
# scale results in a wider tree.
branch_vert_margin (int): passed to ete3.TreeStyle() to set the branch_vertical_margin of
the tree figure. Increased branch_vert_margin results in a taller tree.
# branch_vert_margin (int): passed to ete3.TreeStyle() to set the branch_vertical_margin of
# the tree figure. Increased branch_vert_margin results in a taller tree.
fontsize: size of the leaf labels. Default is 12.
# fontsize: size of the leaf labels. Default is 12.
show_names: show names of leaf nodes. Options are True (show labels for all leaf nodes),
False (don't show labels for any leaf nodes) or a list of sequence IDs for which
labels should be shown. Default is True.
# show_names: show names of leaf nodes. Options are True (show labels for all leaf nodes),
# False (don't show labels for any leaf nodes) or a list of sequence IDs for which
# labels should be shown. Default is True.
mirror: flip the orientation of the tree. Default is to draw the tree from left to right.
Setting mirror to True results in the tree being drawn from right to left.
# mirror: flip the orientation of the tree. Default is to draw the tree from left to right.
# Setting mirror to True results in the tree being drawn from right to left.
min_order_fraction: minimum fraction of downstream leaves requried to color a branch.
When coloring non-leaf nodes, the earliest 'order' with at least <min_order_fraction>
leaf nodes is used. Default is 0.1 (which corresponds to 10%).
# min_order_fraction: minimum fraction of downstream leaves requried to color a branch.
# When coloring non-leaf nodes, the earliest 'order' with at least <min_order_fraction>
# leaf nodes is used. Default is 0.1 (which corresponds to 10%).
figname_prefix: by default, figures will be named <lineage_id>.pdf. If prefix='prefix_' and
the lineage ID is 'ABC123', the figure file will be named 'prefix_ABC123.pdf'.
# figname_prefix: by default, figures will be named <lineage_id>.pdf. If prefix='prefix_' and
# the lineage ID is 'ABC123', the figure file will be named 'prefix_ABC123.pdf'.
figname_suffix: by default, figures will be named <lineage_id>.pdf. If suffix='_suffix' and
the lineage ID is 'ABC123', the figure file will be named 'ABC123_suffix.pdf'.
"""
# figname_suffix: by default, figures will be named <lineage_id>.pdf. If suffix='_suffix' and
# the lineage ID is 'ABC123', the figure file will be named 'ABC123_suffix.pdf'.
# """
from abstar.core.germline import get_imgt_germlines
# from abstar.core.germline import get_imgt_germlines
if project_dir is None:
print("\nERROR: project_dir is required\n")
sys.exit(1)
else:
project_dir = os.path.abspath(project_dir)
# if project_dir is None:
# print("\nERROR: project_dir is required\n")
# sys.exit(1)
# else:
# project_dir = os.path.abspath(project_dir)
# make a name if one isn't provided
if name is None:
name = "".join(
random.choice(string.ascii_uppercase + string.digits) for _ in range(8)
)
# # make a name if one isn't provided
# if name is None:
# name = "".join(
# random.choice(string.ascii_uppercase + string.digits) for _ in range(8)
# )
# if sequences are provided, need to process them
if sequences is not None and all([arg is None for arg in [aln_file, tree_file]]):
sequences = deepcopy(sequences)
root = copy(root)
# # if sequences are provided, need to process them
# if sequences is not None and all([arg is None for arg in [aln_file, tree_file]]):
# sequences = deepcopy(sequences)
# root = copy(root)
# if custom seq_field is specified, copy to the .sequence attribute
if seq_field is not None:
if not all([seq_field in list(s.annotations.keys()) for s in sequences]):
print(
"\nERROR: {} is not present in all of the supplied sequences.\n".format(
seq_field
)
)
sys.exit(1)
for s in sequences:
s.alignment_sequence = s[seq_field]
else:
for s in sequences:
s.alignment_sequence = s.sequence
# # if custom seq_field is specified, copy to the .sequence attribute
# if seq_field is not None:
# if not all([seq_field in list(s.annotations.keys()) for s in sequences]):
# print(
# "\nERROR: {} is not present in all of the supplied sequences.\n".format(
# seq_field
# )
# )
# sys.exit(1)
# for s in sequences:
# s.alignment_sequence = s[seq_field]
# else:
# for s in sequences:
# s.alignment_sequence = s.sequence
# if custom name_field is specified, copy to the .id attribute
if name_field is not None:
if not all([name_field in list(s.annotations.keys()) for s in sequences]):
print(
"\nERROR: {} is not present in all of the supplied sequences.\n".format(
name_field
)
)
sys.exit(1)
for s in sequences:
s.alignment_id = s[name_field]
else:
for s in sequences:
s.alignment_id = s.id
# # if custom name_field is specified, copy to the .id attribute
# if name_field is not None:
# if not all([name_field in list(s.annotations.keys()) for s in sequences]):
# print(
# "\nERROR: {} is not present in all of the supplied sequences.\n".format(
# name_field
# )
# )
# sys.exit(1)
# for s in sequences:
# s.alignment_id = s[name_field]
# else:
# for s in sequences:
# s.alignment_id = s.id
# parse the root sequence
if unrooted:
root = None
root_name = None
elif root is None:
if not quiet:
print(
"\nRoot sequence was was not provided. Using the germline V-gene."
)
if not all(["v_gene" in list(s.annotations.keys()) for s in sequences]):
print(
"\nInput sequences to not appear to be AbStar annotated. Annotating now..."
)
sequences = abstar.run(*[(s.id, s.sequence) for s in sequences])
print("Done.")
if not all(["full" in list(s["v_gene"].keys()) for s in sequences]):
print(
"\nInput sequences to not appear to be AbStar annotated. Annotating now..."
)
sequences = abstar.run(*[(s.id, s.sequence) for s in sequences])
print("Done.")
top_vgene = sorted(
list(Counter([s["v_gene"]["full"] for s in sequences]).items()),
key=lambda x: x[1],
reverse=True,
)[0][0]
vgene = get_imgt_germlines(species, "V", gene=top_vgene)
if aa:
root = Sequence(vgene.ungapped_aa_sequence, id=top_vgene)
else:
root = Sequence(vgene.ungapped_nt_sequence, id=top_vgene)
root.alignment_id = root.id
root.alignment_sequence = root.sequence
if not quiet:
print("Top V-gene: {}".format(root.alignment_id))
print(root.alignment_sequence)
elif type(root) in STR_TYPES:
root = [s for s in sequences if s.alignment_id == root][0]
if not root:
print(
"\nERROR: The name of the root sequence ({}) was not found in the list of input sequences.".format(
root
)
)
print("\n")
sys.exit(1)
sequences = [s for s in sequences if s.alignment_id != root.alignment_id]
elif type(root) == Sequence:
if seq_field is not None:
if seq_field not in list(root.anotations.keys()):
print(
"\nERROR: {} is not present in the supplied root sequence.\n".format(
seq_field
)
)
sys.exit(1)
root.alignment_sequence = root[seq_field]
if name_field is not None:
if name_field not in list(root.anotations.keys()):
print(
"\nERROR: {} is not present in the supplied root sequence.\n".format(
name_field
)
)
sys.exit(1)
root.alignment_id = root[name_field]
sequences = [s for s in sequences if s.alignment_id != root.alignment_id]
else:
print(
"\nERROR: If root is provided, it must be the name of a sequence \
found in the supplied list of sequences or it must be a Sequence object."
)
print("\n")
sys.exit(1)
if not unrooted:
if root_name is not None:
root.alignment_id = root_name
else:
root_name = root.alignment_id
sequences.append(root)
# # parse the root sequence
# if unrooted:
# root = None
# root_name = None
# elif root is None:
# if not quiet:
# print(
# "\nRoot sequence was was not provided. Using the germline V-gene."
# )
# if not all(["v_gene" in list(s.annotations.keys()) for s in sequences]):
# print(
# "\nInput sequences to not appear to be AbStar annotated. Annotating now..."
# )
# sequences = abstar.run(*[(s.id, s.sequence) for s in sequences])
# print("Done.")
# if not all(["full" in list(s["v_gene"].keys()) for s in sequences]):
# print(
# "\nInput sequences to not appear to be AbStar annotated. Annotating now..."
# )
# sequences = abstar.run(*[(s.id, s.sequence) for s in sequences])
# print("Done.")
# top_vgene = sorted(
# list(Counter([s["v_gene"]["full"] for s in sequences]).items()),
# key=lambda x: x[1],
# reverse=True,
# )[0][0]
# vgene = get_imgt_germlines(species, "V", gene=top_vgene)
# if aa:
# root = Sequence(vgene.ungapped_aa_sequence, id=top_vgene)
# else:
# root = Sequence(vgene.ungapped_nt_sequence, id=top_vgene)
# root.alignment_id = root.id
# root.alignment_sequence = root.sequence
# if not quiet:
# print("Top V-gene: {}".format(root.alignment_id))
# print(root.alignment_sequence)
# elif type(root) in STR_TYPES:
# root = [s for s in sequences if s.alignment_id == root][0]
# if not root:
# print(
# "\nERROR: The name of the root sequence ({}) was not found in the list of input sequences.".format(
# root
# )
# )
# print("\n")
# sys.exit(1)
# sequences = [s for s in sequences if s.alignment_id != root.alignment_id]
# elif type(root) == Sequence:
# if seq_field is not None:
# if seq_field not in list(root.anotations.keys()):
# print(
# "\nERROR: {} is not present in the supplied root sequence.\n".format(
# seq_field
# )
# )
# sys.exit(1)
# root.alignment_sequence = root[seq_field]
# if name_field is not None:
# if name_field not in list(root.anotations.keys()):
# print(
# "\nERROR: {} is not present in the supplied root sequence.\n".format(
# name_field
# )
# )
# sys.exit(1)
# root.alignment_id = root[name_field]
# sequences = [s for s in sequences if s.alignment_id != root.alignment_id]
# else:
# print(
# "\nERROR: If root is provided, it must be the name of a sequence \
# found in the supplied list of sequences or it must be a Sequence object."
# )
# print("\n")
# sys.exit(1)
# if not unrooted:
# if root_name is not None:
# root.alignment_id = root_name
# else:
# root_name = root.alignment_id
# sequences.append(root)
# parse sequences from aln_file, if provided
elif aln_file is not None:
if not unrooted and type(root) not in STR_TYPES:
print(
"\nERROR: If providing an aln_file, the name of the root sequence must \
be provided (as a string) using the root keyword argument"
)
print("\n")
sys.exit(1)
_sequences = []
_root = None
for rec in AlignIO.read(open(aln_file), "fasta"):
s = str(rec.seq).replace("-", "")
if rec.id == root:
_root = Sequence(s, rec.id)
_root.alignment_id = _root.id
else:
_s = Sequence(s, id=rec.id)
_s.alignment_id = rec.id
_sequences.append(_s)
if sequences is None:
sequences = _sequences
else:
sequence_ids = [s.id for s in sequences]
if any([_s.alignment_id not in sequence_ids for _s in _sequences]):
print(
"\nWARNING: Sequences were found in the alignment file that were not included \
in the input sequence list. This may cause problems."
)
for s in sequences:
s.alignment_id = s.id
s.alignment_sequence = s.sequence
if unrooted:
root = None
root_name = None
else:
if _root is None:
print(
"\nERROR: The specified root ({}) was not found in the provided alignment file.".format(
root
)
)
print("\n")
sys.exit(1)
root = _root
if root_name is not None:
root.alignment_id = root_name
else:
root_name = root.alignment_id
sequences = [
s
for s in sequences
if all(
[s.alignment_id != name for name in [root.id, root.alignment_id]]
)
]
sequences.append(root)
# # parse sequences from aln_file, if provided
# elif aln_file is not None:
# if not unrooted and type(root) not in STR_TYPES:
# print(
# "\nERROR: If providing an aln_file, the name of the root sequence must \
# be provided (as a string) using the root keyword argument"
# )
# print("\n")
# sys.exit(1)
# _sequences = []
# _root = None
# for rec in AlignIO.read(open(aln_file), "fasta"):
# s = str(rec.seq).replace("-", "")
# if rec.id == root:
# _root = Sequence(s, rec.id)
# _root.alignment_id = _root.id
# else:
# _s = Sequence(s, id=rec.id)
# _s.alignment_id = rec.id
# _sequences.append(_s)
# if sequences is None:
# sequences = _sequences
# else:
# sequence_ids = [s.id for s in sequences]
# if any([_s.alignment_id not in sequence_ids for _s in _sequences]):
# print(
# "\nWARNING: Sequences were found in the alignment file that were not included \
# in the input sequence list. This may cause problems."
# )
# for s in sequences:
# s.alignment_id = s.id
# s.alignment_sequence = s.sequence
# if unrooted:
# root = None
# root_name = None
# else:
# if _root is None:
# print(
# "\nERROR: The specified root ({}) was not found in the provided alignment file.".format(
# root
# )
# )
# print("\n")
# sys.exit(1)
# root = _root
# if root_name is not None:
# root.alignment_id = root_name
# else:
# root_name = root.alignment_id
# sequences = [
# s
# for s in sequences
# if all(
# [s.alignment_id != name for name in [root.id, root.alignment_id]]
# )
# ]
# sequences.append(root)
# parse sequences from tree_file, if provided
elif tree_file is not None:
if not unrooted and type(root) not in STR_TYPES:
print(
"\nERROR: If providing a tree_file, the name of the root sequence must \
be provided (as a string) using the root keyword argument"
)
print("\n")
sys.exit(1)
_sequences = []
_root = None
tree = Phylo.read(open(tree_file), "newick")
for leaf in tree.get_terminals():
s = ""
if leaf.name == root:
_root = Sequence(s, leaf.name)
_root.alignment_id = _root.id
else:
_s = Sequence(s, id=leaf.name)
_s.alignment_id = leaf.name
_sequences.append(_s)
if sequences is None:
sequences = _sequences
else:
sequence_ids = [s.id for s in sequences]
if any([_s.alignment_id not in sequence_ids for _s in _sequences]):
print(
"\nWARNING: Sequences were found in the alignment file that were not included \
in the input sequence list. This may cause problems."
)
for s in sequences:
s.alignment_id = s.id
s.alignment_sequence = s.sequence
if unrooted:
root = None
root_name = None
elif _root is None:
print(
"\nERROR: The specified root ({}) was not found in the provided tree file.".format(
root
)
)
print("\n")
sys.exit(1)
else:
root = _root
if root_name is not None:
root.alignment_id = root_name
else:
root_name = root.alignment_id
sequences = [
s
for s in sequences
if all(
[s.alignment_id != name for name in [root.id, root.alignment_id]]
)
]
sequences.append(root)
# # parse sequences from tree_file, if provided
# elif tree_file is not None:
# if not unrooted and type(root) not in STR_TYPES:
# print(
# "\nERROR: If providing a tree_file, the name of the root sequence must \
# be provided (as a string) using the root keyword argument"
# )
# print("\n")
# sys.exit(1)
# _sequences = []
# _root = None
# tree = Phylo.read(open(tree_file), "newick")
# for leaf in tree.get_terminals():
# s = ""
# if leaf.name == root:
# _root = Sequence(s, leaf.name)
# _root.alignment_id = _root.id
# else:
# _s = Sequence(s, id=leaf.name)
# _s.alignment_id = leaf.name
# _sequences.append(_s)
# if sequences is None:
# sequences = _sequences
# else:
# sequence_ids = [s.id for s in sequences]
# if any([_s.alignment_id not in sequence_ids for _s in _sequences]):
# print(
# "\nWARNING: Sequences were found in the alignment file that were not included \
# in the input sequence list. This may cause problems."
# )
# for s in sequences:
# s.alignment_id = s.id
# s.alignment_sequence = s.sequence
# if unrooted:
# root = None
# root_name = None
# elif _root is None:
# print(
# "\nERROR: The specified root ({}) was not found in the provided tree file.".format(
# root
# )
# )
# print("\n")
# sys.exit(1)
# else:
# root = _root
# if root_name is not None:
# root.alignment_id = root_name
# else:
# root_name = root.alignment_id
# sequences = [
# s
# for s in sequences
# if all(
# [s.alignment_id != name for name in [root.id, root.alignment_id]]
# )
# ]
# sequences.append(root)
# set up colors and color ordering
if order_dict is None:
if order_function is not None:
order_dict = {seq.alignment_id: order_function(seq) for seq in sequences}
if color_dict is None:
if color_function is not None:
color_dict = {seq.alignment_id: color_function(seq) for seq in sequences}
if color_dict is None:
color_dict = {}
# # set up colors and color ordering
# if order_dict is None:
# if order_function is not None:
# order_dict = {seq.alignment_id: order_function(seq) for seq in sequences}
# if color_dict is None:
# if color_function is not None:
# color_dict = {seq.alignment_id: color_function(seq) for seq in sequences}
# if color_dict is None:
# color_dict = {}
# make msa (if necessary)
if all([aln_file is None, tree_file is None]):
aln_file = os.path.abspath(os.path.join(project_dir, "{}.aln".format(name)))
# muscle(seqs, aln_file, as_file=True)
do_print = False if quiet else True
if do_print:
print("\n")
seqs = [(s.alignment_id, s.alignment_sequence) for s in sequences]
mafft(
seqs, aln_file, as_file=True, print_stdout=do_print, print_stderr=do_print
)
# # make msa (if necessary)
# if all([aln_file is None, tree_file is None]):
# aln_file = os.path.abspath(os.path.join(project_dir, "{}.aln".format(name)))
# # muscle(seqs, aln_file, as_file=True)
# do_print = False if quiet else True
# if do_print:
# print("\n")
# seqs = [(s.alignment_id, s.alignment_sequence) for s in sequences]
# mafft(
# seqs, aln_file, as_file=True, print_stdout=do_print, print_stderr=do_print
# )
# make treefile (if necessary)
if tree_file is None:
tree_file = os.path.abspath(os.path.join(project_dir, "{}.nw".format(name)))
fasttree(aln_file, tree_file, is_aa=aa, quiet=quiet)
# # make treefile (if necessary)
# if tree_file is None:
# tree_file = os.path.abspath(os.path.join(project_dir, "{}.nw".format(name)))
# fasttree(aln_file, tree_file, is_aa=aa, quiet=quiet)
# make phylogeny
prefix = "" if figname_prefix is None else figname_prefix
suffix = "" if figname_suffix is None else figname_suffix
fig_file = os.path.join(project_dir, "{}{}{}.pdf".format(prefix, name, suffix))
_make_tree_figure(
tree_file,
fig_file,
color_dict,
order_dict,
None if root is None else root.alignment_id,
rename_function=rename_function,
show_names=show_names,
name_field=name_field,
branch_vert_margin=branch_vert_margin,
scale=scale,
color_node_labels=color_node_labels,
label_colors=label_colors,
show_root_name=show_root_name,
tree_orientation=1 if mirror else 0,
fontsize=fontsize,
min_order_fraction=min_order_fraction,
# linked_alignment=linked_alignment,
# alignment_fontsize=alignment_fontsize,
# alignment_height=alignment_height,
# alignment_width=alignment_width,
show_scale=show_scale,
compact_alignment=compact_alignment,
scale_factor=scale_factor,
linewidth=linewidth,
ladderize=ladderize,
delete_nodes=delete_nodes,
)
# # make phylogeny
# prefix = "" if figname_prefix is None else figname_prefix
# suffix = "" if figname_suffix is None else figname_suffix
# fig_file = os.path.join(project_dir, "{}{}{}.pdf".format(prefix, name, suffix))
# _make_tree_figure(
# tree_file,
# fig_file,
# color_dict,
# order_dict,
# None if root is None else root.alignment_id,
# rename_function=rename_function,
# show_names=show_names,
# name_field=name_field,
# branch_vert_margin=branch_vert_margin,
# scale=scale,
# color_node_labels=color_node_labels,
# label_colors=label_colors,
# show_root_name=show_root_name,
# tree_orientation=1 if mirror else 0,
# fontsize=fontsize,
# min_order_fraction=min_order_fraction,
# # linked_alignment=linked_alignment,
# # alignment_fontsize=alignment_fontsize,
# # alignment_height=alignment_height,
# # alignment_width=alignment_width,
# show_scale=show_scale,
# compact_alignment=compact_alignment,
# scale_factor=scale_factor,
# linewidth=linewidth,
# ladderize=ladderize,
# delete_nodes=delete_nodes,
# )
def _make_tree_figure(
    tree,
    fig,
    colors,
    orders,
    root_name,
    scale=None,
    branch_vert_margin=None,
    fontsize=12,
    show_names=True,
    name_field="seq_id",
    rename_function=None,
    color_node_labels=False,
    label_colors=None,
    tree_orientation=0,
    min_order_fraction=0.1,
    show_root_name=False,
    chain=None,
    # linked_alignment=None, alignment_fontsize=11, alignment_height=50, alignment_width=50,
    compact_alignment=False,
    scale_factor=1,
    linewidth=1,
    show_scale=False,
    ladderize=True,
    delete_nodes=None,
):
    """
    Renders a Newick tree file to a figure file using ete3.

    Args:
        tree (str): path to a Newick-formatted tree file.
        fig (str): path to the output figure file (format is inferred by ete3
            from the file extension, e.g. '.pdf').
        colors (dict): maps either an 'order' (when ``orders`` is provided) or a
            node name to a hex color string.
        orders (dict): maps leaf names to an 'order' used for branch coloring, or
            ``None`` to color nodes individually via ``colors``.
        root_name (str): name of the outgroup/root node, or ``None`` to leave the
            tree unrooted.
        show_names: ``True`` to label every leaf, ``False`` to label none, or a
            list of node names to label.
        delete_nodes: a node name (str) or list of node names to delete from the
            tree before rendering.

    Returns:
        None. The rendered figure is written to ``fig``.
    """
    # normalize delete_nodes to a list
    if delete_nodes is None:
        delete_nodes = []
    elif type(delete_nodes) in STR_TYPES:
        delete_nodes = [
            delete_nodes,
        ]
    # If the root should be labeled, make sure its name is in show_names.
    # NOTE: the previous implementation unconditionally called
    # show_names.append(), which crashed with AttributeError when show_names
    # was a bool (its default is True) and mutated the caller's list otherwise.
    # When show_names is True, every name (including the root) is already shown.
    if show_root_name is True and show_names is not True:
        if show_names is False:
            show_names = [root_name]
        else:
            show_names = list(show_names) + [root_name]
    # if linked_alignment is not None:
    #     t = ete3.PhyloTree(tree, alignment=linked_alignment, alg_format='fasta')
    #     ete3.faces.SequenceItem = MySequenceItem
    t = ete3.Tree(tree)
    if root_name is not None:
        t.set_outgroup(t & root_name)
    # style the nodes
    for node in t.traverse():
        if node.name in delete_nodes:
            node.delete()
            continue
        # default branch color; previously `color` could be left unbound (first
        # node) or stale (later nodes) when no order met min_order_fraction
        color = "#000000"
        if orders is not None:
            leaves = node.get_leaf_names()
            order_count = Counter([orders[l] for l in leaves])
            # use the earliest 'order' with at least min_order_fraction of leaves
            for order in sorted(order_count.keys()):
                if float(order_count[order]) / len(leaves) >= min_order_fraction:
                    color = colors[order]
                    break
        else:
            color = colors.get(node.name, "#000000")
        # if linked_alignment is not None:
        #     node.add_feature('aln_fontsize', alignment_fontsize)
        #     node.add_feature('aln_height', alignment_height)
        #     node.add_feature('aln_width', alignment_width)
        #     node.add_feature('fontsize', fontsize)
        #     node.add_feature('format', 'seq')
        #     node.add_feature('scale_factor', scale_factor)
        style = ete3.NodeStyle()
        style["size"] = 0
        style["vt_line_width"] = float(linewidth)
        style["hz_line_width"] = float(linewidth)
        style["vt_line_color"] = color
        style["hz_line_color"] = color
        style["vt_line_type"] = 0
        style["hz_line_type"] = 0
        if show_names is True:
            tf = _build_node_text_face(
                node, color_node_labels, color, label_colors, fontsize, rename_function
            )
            node.add_face(tf, column=0)
        elif node.name in show_names:
            tf = _build_node_text_face(
                node, color_node_labels, color, label_colors, fontsize, rename_function
            )
            node.add_face(tf, column=0)
        node.set_style(style)
    t.dist = 0
    ts = ete3.TreeStyle()
    # if linked_alignment is not None:
    #     ts.layout_fn = _phyloalignment_layout_function
    ts.orientation = tree_orientation
    ts.show_leaf_name = False
    if scale is not None:
        ts.scale = int(scale)
    if branch_vert_margin is not None:
        ts.branch_vertical_margin = float(branch_vert_margin)
    ts.show_scale = show_scale
    if ladderize:
        t.ladderize()
    t.render(fig, tree_style=ts)
# def _make_tree_figure(
# tree,
# fig,
# colors,
# orders,
# root_name,
# scale=None,
# branch_vert_margin=None,
# fontsize=12,
# show_names=True,
# name_field="seq_id",
# rename_function=None,
# color_node_labels=False,
# label_colors=None,
# tree_orientation=0,
# min_order_fraction=0.1,
# show_root_name=False,
# chain=None,
# # linked_alignment=None, alignment_fontsize=11, alignment_height=50, alignment_width=50,
# compact_alignment=False,
# scale_factor=1,
# linewidth=1,
# show_scale=False,
# ladderize=True,
# delete_nodes=None,
# ):
# if delete_nodes is None:
# delete_nodes = []
# elif type(delete_nodes) in STR_TYPES:
# delete_nodes = [
# delete_nodes,
# ]
# if show_root_name is True:
# show_names.append(root_name)
# # if linked_alignment is not None:
# # t = ete3.PhyloTree(tree, alignment=linked_alignment, alg_format='fasta')
# # ete3.faces.SequenceItem = MySequenceItem
# t = ete3.Tree(tree)
# if root_name is not None:
# t.set_outgroup(t & root_name)
# # style the nodes
# for node in t.traverse():
# if node.name in delete_nodes:
# node.delete()
# continue
# if orders is not None:
# leaves = node.get_leaf_names()
# order_count = Counter([orders[l] for l in leaves])
# for order in sorted(order_count.keys()):
# if float(order_count[order]) / len(leaves) >= min_order_fraction:
# color = colors[order]
# break
# else:
# color = colors.get(node.name, "#000000")
# # if linked_alignment is not None:
# # node.add_feature('aln_fontsize', alignment_fontsize)
# # node.add_feature('aln_height', alignment_height)
# # node.add_feature('aln_width', alignment_width)
# # node.add_feature('fontsize', fontsize)
# # node.add_feature('format', 'seq')
# # node.add_feature('scale_factor', scale_factor)
# style = ete3.NodeStyle()
# style["size"] = 0
# style["vt_line_width"] = float(linewidth)
# style["hz_line_width"] = float(linewidth)
# style["vt_line_color"] = color
# style["hz_line_color"] = color
# style["vt_line_type"] = 0
# style["hz_line_type"] = 0
# if show_names is True:
# tf = _build_node_text_face(
# node, color_node_labels, color, label_colors, fontsize, rename_function
# )
# node.add_face(tf, column=0)
# elif node.name in show_names:
# tf = _build_node_text_face(
# node, color_node_labels, color, label_colors, fontsize, rename_function
# )
# node.add_face(tf, column=0)
# node.set_style(style)
# t.dist = 0
# ts = ete3.TreeStyle()
# # if linked_alignment is not None:
# # ts.layout_fn = _phyloalignment_layout_function
# ts.orientation = tree_orientation
# ts.show_leaf_name = False
# if scale is not None:
# ts.scale = int(scale)
# if branch_vert_margin is not None:
# ts.branch_vertical_margin = float(branch_vert_margin)
# ts.show_scale = show_scale
# if ladderize:
# t.ladderize()
# t.render(fig, tree_style=ts)
def _build_node_text_face(
    node, color_node_labels, color, label_colors, fontsize, rename_function
):
    """
    Builds an ete3 ``TextFace`` used to label ``node``.

    The label color is black unless ``color_node_labels`` is truthy, in which
    case it is taken from ``label_colors`` (a dict of node name -> color, or a
    list/tuple of node names that should use ``color``), falling back to
    ``color`` itself when ``label_colors`` is None. The displayed name is
    ``node.name``, optionally transformed by ``rename_function``.
    """
    node_color = "#000000"
    if color_node_labels:
        if label_colors is None:
            node_color = color
        elif type(label_colors) == dict:
            node_color = label_colors.get(node.name, "#000000")
        elif type(label_colors) in [list, tuple] and node.name in label_colors:
            node_color = color
    display_name = node.name if rename_function is None else rename_function(node.name)
    return ete3.TextFace(display_name, fsize=fontsize, fgcolor=node_color)
# def _build_node_text_face(
# node, color_node_labels, color, label_colors, fontsize, rename_function
# ):
# if color_node_labels:
# if label_colors is None:
# node_color = color
# elif type(label_colors) == dict:
# node_color = label_colors.get(node.name, "#000000")
# elif type(label_colors) in [list, tuple]:
# node_color = color if node.name in label_colors else "#000000"
# else:
# node_color = "#000000"
# else:
# node_color = "#000000"
# node_name = node.name if rename_function is None else rename_function(node.name)
# tf = ete3.TextFace(node_name, fsize=fontsize, fgcolor=node_color)
# return tf
# -----------------------------------------
# PHYLOGENETIC LAPLACIAN SPECTRA
# -----------------------------------------
# # -----------------------------------------
# # PHYLOGENETIC LAPLACIAN SPECTRA
# # -----------------------------------------
class PLS:
    """
    Object representing a collection of Phylogenetic Laplacian Spectra.

    Args:
        spectral_trees (list): collection of SpectralTree objects.
        n_processes (int): number of worker processes used for multiprocessing
            steps. The default (0) uses all available CPUs.
    """

    def __init__(self, spectral_trees, n_processes=0):
        self.trees = spectral_trees
        self.n_processes = mp.cpu_count() if n_processes == 0 else n_processes
        # eagerly compute each tree's KDE so later distance computations
        # (which may run in worker processes) find it already cached
        self.compute_kdes()

    @lazy_property
    def distance_matrix(self):
        """
        Pairwise Jensen-Shannon distances between all trees, as a pandas
        DataFrame indexed by tree name. Rows are computed in parallel.
        """
        dist = {}
        p = mp.Pool(processes=self.n_processes)
        async_results = []
        for t in self.trees:
            async_results.append(
                p.apply_async(self.calc_distance_row, args=(t, self.trees))
            )
        dists = [ar.get() for ar in async_results]
        p.close()
        p.join()
        for t1, row in zip(self.trees, dists):
            # initialize the inner dict before assigning into it
            # (previously missing, causing a KeyError on the first assignment)
            dist[t1.name] = {}
            for t2, d in zip(self.trees, row):
                dist[t1.name][t2.name] = d
        return pd.DataFrame(dist)

    def calc_distance_row(self, t1, trees):
        """Returns distances from ``t1`` to every tree in ``trees``, in order."""
        dists = []
        for t2 in trees:
            dists.append(self.compare_trees(t1, t2))
        return dists

    def compare_trees(self, t1, t2, n_points=100):
        """
        Jensen-Shannon distance between the eigenvalue KDEs of two trees,
        evaluated at ``n_points`` shared sample points spanning [0, max
        principal eigenvalue].
        """
        maxw = max([t1.principal_eigenvalue, t2.principal_eigenvalue])
        x = np.linspace(0, maxw, n_points)
        p = t1.pdf(x)
        q = t2.pdf(x)
        return self.jensenshannon(p, q)

    @staticmethod
    def jensenshannon(p, q, base=None):
        """
        Computes the Jensen-Shannon distance between two distributions.

        Args:
            p, q: array-like distributions; normalized internally.
            base: optional logarithm base. Natural log is used if None.

        Returns:
            The Jensen-Shannon distance (square root of the JS divergence).
        """
        p = np.asarray(p)
        q = np.asarray(q)
        p = p / np.sum(p, axis=0)
        q = q / np.sum(q, axis=0)
        m = (p + q) / 2.0
        left = rel_entr(p, m)
        right = rel_entr(q, m)
        js = np.sum(left, axis=0) + np.sum(right, axis=0)
        if base is not None:
            js /= np.log(base)
        return np.sqrt(js / 2.0)

    def compute_kdes(self):
        """Computes KDEs for all trees, serially or in parallel."""
        if self.n_processes == 1:
            self._compute_kdes_sp()
        else:
            self._compute_kdes_mp()

    def _compute_kdes_sp(self):
        # single process: touching each tree's lazy `kde` property caches it
        for t in self.trees:
            kde = t.kde

    def _compute_kdes_mp(self):
        # multiprocess: compute KDEs in worker processes, then cache the
        # results on each tree in the parent process
        p = mp.Pool(processes=self.n_processes)
        async_results = []
        for t in self.trees:
            async_results.append(p.apply_async(self.compute_kde, args=(t,)))
        kdes = [ar.get() for ar in async_results]
        for t, kde in zip(self.trees, kdes):
            t.kde = kde
        p.close()
        p.join()

    @staticmethod
    def compute_kde(tree):
        """
        For a SpectralTree, computes the KDE.

        Args:
            tree (SpectralTree): tree for which the KDE will be computed

        Returns:
            the KDE function
        """
        return tree.kde

    def clustermap(self, cmap=None, figfile=None):
        """
        Plots the distance matrix as a seaborn clustermap. Saves to
        ``figfile`` if provided, otherwise shows the plot interactively.
        """
        cm = sns.clustermap(self.distance_matrix, cmap=cmap)
        if figfile is not None:
            cm.savefig(figfile)
        else:
            plt.show()
class SpectralTree:
    """
    A phylogenetic tree together with its graph-Laplacian spectral properties
    (pairwise distance matrix, Laplacian eigenvalues, and their KDE).
    """

    def __init__(self, tree_file, subject=None):
        self.subject = subject
        self.name = os.path.basename(tree_file)
        self.tree_file = tree_file
        self.t = ete3.Tree(tree_file)

    @lazy_property
    def nodes(self):
        # traverse the tree, assigning sequential numeric names to any
        # unnamed (typically internal) nodes along the way
        unnamed_counter = 0
        all_nodes = []
        for node in self.t.traverse():
            if not node.name:
                node.name = str(unnamed_counter)
                unnamed_counter += 1
            all_nodes.append(node)
        return all_nodes

    @lazy_property
    def node_names(self):
        return [node.name for node in self.nodes]

    @lazy_property
    def distances(self):
        # pairwise patristic distance matrix; each diagonal entry is the
        # negated sum of the off-diagonal entries in its row
        rows = []
        for src in self.nodes:
            row = {}
            for dst in self.nodes:
                if src == dst:
                    continue
                row[dst] = self.t.get_distance(src, dst)
            row[src] = -1.0 * sum(row.values())
            rows.append([row[node] for node in self.nodes])
        return np.asarray(rows)

    @lazy_property
    def eigenvalues(self):
        # eigenvalues of the graph Laplacian, largest first, keeping only
        # those >= 1
        laplacian = scipy.sparse.csgraph.laplacian(self.distances)
        eigvals, _ = np.linalg.eig(laplacian)
        return [val for val in sorted(eigvals, reverse=True) if val >= 1]

    @lazy_property
    def log_eigenvalues(self):
        return np.real(np.log(self.eigenvalues))

    @property
    def principal_eigenvalue(self):
        return max(self.eigenvalues)

    @lazy_property
    def kde(self):
        return scipy.stats.gaussian_kde(self.eigenvalues)

    def pdf(self, x):
        # real components of the KDE evaluated at each point in x
        return [value.real for value in self.kde.pdf(x)]

    def plot(self, xs=None, figfile=None, xlim=None):
        """
        Plots the spectral density. Saves to ``figfile`` if provided,
        otherwise shows the plot interactively.
        """
        if xs is None:
            xs = np.linspace(-0.5, self.principal_eigenvalue, 200)
        ys = self.pdf(xs)
        # draw the density curve with a shaded fill beneath it
        plt.figure(figsize=[3, 4])
        plt.plot(xs, ys, lw=2)
        plt.fill_between(xs, ys, alpha=0.2)
        # axis styling
        ax = plt.gca()
        ax.set_xlabel("eigenvalue (ln)")
        ax.set_ylabel("density")
        if xlim is not None:
            ax.set_xlim(xlim)
        if figfile is not None:
            plt.tight_layout()
            plt.savefig(figfile)
        else:
            plt.show()

# @@ -26,286 +26,286 @@ #!/usr/bin/env python  (diff-hunk artifact: start of a second concatenated file)

import os
import subprocess as sp
import sys
import tarfile
# import os
# import subprocess as sp
# import sys
# import tarfile
from . import log
# from . import log
# Python 2/3 compatibility: on Python 3 the ``raw_input`` builtin was removed,
# so alias it to ``input`` to let downstream code call raw_input() unconditionally.
if sys.version_info[0] > 2:
    raw_input = input
# if sys.version_info[0] > 2:
# raw_input = input
def compress_and_upload(
    data,
    compressed_file,
    s3_path,
    multipart_chunk_size_mb=500,
    method="gz",
    delete=False,
    access_key=None,
    secret_key=None,
):
    """
    Compresses data and uploads to S3.

    S3 upload uses ``s3cmd``, so you must either:

        1) Manually configure ``s3cmd`` prior to use (typically using ``s3cmd --configure``).
        2) Configure ``s3cmd`` using ``s3.configure()``.
        3) Pass your access key and secret key to ``compress_and_upload``,
           which will automatically configure s3cmd.

    .. note:

        ``s3cmd`` configuration only needs to be done once per computer,
        which means that relaunching a cloud instance or Docker image will
        require re-configuration of ``s3cmd``.

    Args:

        data: Can be one of three things:

            1) Path to a single file
            2) Path to a directory
            3) A list of one or more paths to files or directories

        compressed_file (str): Path to the compressed file. Required.

        s3_path (str): The S3 path, with the filename omitted. The S3 filename
            will be the basename of the ``compressed_file``. For example::

                compress_and_upload(data='/path/to/data',
                                    compressed_file='/path/to/compressed.tar.gz',
                                    s3_path='s3://my_bucket/path/to/')

            will result in an uploaded S3 path of ``s3://my_bucket/path/to/compressed.tar.gz``

        method (str): Compression method. Options are ``'gz'`` (gzip) or ``'bz2'`` (bzip2).
            Default is ``'gz'``.

        delete (bool): If ``True``, the ``compressed_file`` will be deleted after upload
            to S3. Default is ``False``.

        access_key (str): AWS access key.

        secret_key (str): AWS secret key.
    """
    logger = log.get_logger("s3")
    # only auto-configure s3cmd when BOTH credentials were supplied
    if all([access_key, secret_key]):
        configure(access_key=access_key, secret_key=secret_key, logger=logger)
    compress(data, compressed_file, fmt=method, logger=logger)
    put(
        compressed_file,
        s3_path,
        multipart_chunk_size_mb=multipart_chunk_size_mb,
        logger=logger,
    )
    if delete:
        os.unlink(compressed_file)
def put(f, s3_path, multipart_chunk_size_mb=500, logger=None):
    """
    Uploads a single file to S3, using s3cmd.

    Args:

        f (str): Path to a single file.

        s3_path (str): The S3 path, with the filename omitted. The S3 filename
            will be the basename of the ``f``. For example::

                put(f='/path/to/myfile.tar.gz', s3_path='s3://my_bucket/path/to/')

            will result in an uploaded S3 path of ``s3://my_bucket/path/to/myfile.tar.gz``

        multipart_chunk_size_mb (int): Multipart chunk size, in MB. Default is 500.

        logger: Logger instance. A logger named ``'s3'`` is created if not provided.
    """
    if not logger:
        logger = log.get_logger("s3")
    fname = os.path.basename(f)
    target = os.path.join(s3_path, fname)
    # build the command as an argument list (shell=False) so paths containing
    # spaces or shell metacharacters can't break the command or inject shell code
    # (the original built a single string and ran it with shell=True)
    s3cmd_cline = [
        "s3cmd",
        "put",
        f,
        target,
        "--multipart-chunk-size-mb",
        str(multipart_chunk_size_mb),
    ]
    print_put_info(fname, target, logger)
    s3cmd = sp.Popen(s3cmd_cline, stdout=sp.PIPE, stderr=sp.PIPE)
    stdout, stderr = s3cmd.communicate()
def print_put_info(fname, target, logger):
    """Log a banner describing an upcoming S3 upload."""
    banner = ["", "", "", "-" * 25, "UPLOADING TO S3", "-" * 25, ""]
    for line in banner:
        logger.info(line)
    logger.info("File: {}".format(fname))
    logger.info("Target S3 location: {}".format(target))
def compress(d, output, fmt="gz", logger=None):
    """
    Creates a compressed/uncompressed tar file.

    Args:

        d: Can be one of three things:

            1. the path to a single file, as a string
            2. the path to a single directory, as a string
            3. an iterable of file or directory paths

        output (str): Output file path.

        fmt: Compression method. Options are ``'gz'`` (gzip),
            ``'bz2'`` (bzip2) and ``'none'`` (uncompressed). Default is ``'gz'``.

        logger: Logger instance. A logger named ``'s3'`` is created if not provided.

    Returns:

        str: Path to the tar file that was written.
    """
    if not logger:
        logger = log.get_logger("s3")
    if type(d) not in [list, tuple]:
        d = [
            d,
        ]
    d = [os.path.expanduser(_d) for _d in d]
    # BUGFIX: pass the compression format string, not the enclosing function
    # object -- the original passed ``compress`` itself, so print_compress_info
    # crashed calling ``.lower()`` on a function.
    print_compress_info(d, output, fmt, logger)
    if fmt.lower() == "none":
        fmt = ""
    elif fmt.lower() not in ["gz", "bz2"]:
        logger.info(
            'Compression option ("{}") is invalid.\nFalling back to uncompressed.'.format(
                fmt
            )
        )
        fmt = ""
    output = os.path.expanduser(output)
    # context manager guarantees the archive is finalized even if tar.add raises
    with tarfile.open(output, "w:{}".format(fmt)) as tar:
        for obj in d:
            tar.add(obj)
    return output
def print_compress_info(d, output, compress, logger):
    """
    Logs a summary of an impending compression operation.

    Args:

        d (list): File and/or directory paths that will be compressed.

        output (str): Path to the output (tar) file.

        compress (str): Compression format name (e.g. ``'gz'``, ``'bz2'``, ``'none'``).

        logger: Logger instance. A logger named ``'s3'`` is created if not provided.
    """
    if not logger:
        logger = log.get_logger("s3")
    dirs = [obj for obj in d if os.path.isdir(obj)]
    files = [obj for obj in d if os.path.isfile(obj)]
    logger.info("")
    logger.info("")
    logger.info("")
    logger.info("-" * 25)
    logger.info("COMPRESSING DATA")
    logger.info("-" * 25)
    logger.info("")
    # typo fix: the original logged "Ouptut file"
    logger.info("Output file: {}".format(output))
    logger.info("Compression: {}".format(compress.lower()))
    if dirs:
        d = "directories" if len(dirs) > 1 else "directory"
        logger.info("Found {} {} to compress: {}".format(len(dirs), d, ", ".join(dirs)))
    if files:
        f = "files" if len(files) > 1 else "file"
        logger.info(
            "Found {} {} to compress: {}".format(len(files), f, ", ".join(files))
        )
def configure(access_key=None, secret_key=None, logger=None):
    """
    Configures s3cmd prior to first use.

    If no arguments are provided, you will be prompted to enter
    the access key and secret key interactively.

    Args:

        access_key (str): AWS access key

        secret_key (str): AWS secret key

        logger: Logger instance. A logger named ``'s3'`` is created if not provided.
    """
    if not logger:
        logger = log.get_logger("s3")
    # fall back to interactive prompts when either credential is missing
    if not all([access_key, secret_key]):
        logger.info("")
        access_key = input("AWS Access Key: ")
        secret_key = input("AWS Secret Key: ")
    _write_config(access_key, secret_key)
    logger.info("")
    logger.info("Completed writing S3 config file.")
    logger.info("")
def _write_config(access_key, secret_key):
    """
    Writes an s3cmd config file (``~/.s3cfg``) containing the supplied AWS
    credentials plus the defaults in ``CONFIG_DEFAULTS``.
    """
    cfg_string = "[default]\n"
    cfg_string += "access_key = {}\n".format(access_key)
    cfg_string += "secret_key = {}\n".format(secret_key)
    cfg_string += CONFIG_DEFAULTS
    cfg_file = os.path.expanduser("~/.s3cfg")
    # context manager ensures the handle is flushed and closed
    # (the original leaked the file object)
    with open(cfg_file, "w") as cfg_handle:
        cfg_handle.write(cfg_string)
    # NOTE(review): this file contains AWS secrets; consider restricting
    # permissions with os.chmod(cfg_file, 0o600)
# Default .s3cfg contents (everything except the credentials, which are
# prepended by _write_config). Placeholders like %(bucket)s are expanded
# by s3cmd itself.
CONFIG_DEFAULTS = """
access_token =
add_encoding_exts =
add_headers =
bucket_location = US
cache_file =
cloudfront_host = cloudfront.amazonaws.com
default_mime_type = binary/octet-stream
delay_updates = False
delete_after = False
delete_after_fetch = False
delete_removed = False
dry_run = False
enable_multipart = True
encoding = UTF-8
encrypt = False
expiry_date =
expiry_days =
expiry_prefix =
follow_symlinks = False
force = False
get_continue = False
gpg_command = /usr/bin/gpg
gpg_decrypt = %(gpg_command)s -d --verbose --no-use-agent --batch --yes --passphrase-fd %(passphrase_fd)s -o %(output_file)s %(input_file)s
gpg_encrypt = %(gpg_command)s -c --verbose --no-use-agent --batch --yes --passphrase-fd %(passphrase_fd)s -o %(output_file)s %(input_file)s
gpg_passphrase =
guess_mime_type = True
host_base = s3.amazonaws.com
host_bucket = %(bucket)s.s3.amazonaws.com
human_readable_sizes = False
ignore_failed_copy = False
invalidate_default_index_on_cf = False
invalidate_default_index_root_on_cf = True
invalidate_on_cf = False
list_md5 = False
log_target_prefix =
max_delete = -1
mime_type =
multipart_chunk_size_mb = 250
preserve_attrs = True
progress_meter = True
proxy_host =
proxy_port = 0
put_continue = False
recursive = False
recv_chunk = 4096
reduced_redundancy = False
restore_days = 1
send_chunk = 4096
server_side_encryption = False
simpledb_host = sdb.amazonaws.com
skip_existing = False
socket_timeout = 300
urlencoding_mode = normal
use_https = True
use_mime_magic = True
verbosity = WARNING
website_endpoint = http://%(bucket)s.s3-website-%(location)s.amazonaws.com/
website_error =
website_index = index.html
"""
# CONFIG_DEFAULTS = """
# access_token =
# add_encoding_exts =
# add_headers =
# bucket_location = US
# cache_file =
# cloudfront_host = cloudfront.amazonaws.com
# default_mime_type = binary/octet-stream
# delay_updates = False
# delete_after = False
# delete_after_fetch = False
# delete_removed = False
# dry_run = False
# enable_multipart = True
# encoding = UTF-8
# encrypt = False
# expiry_date =
# expiry_days =
# expiry_prefix =
# follow_symlinks = False
# force = False
# get_continue = False
# gpg_command = /usr/bin/gpg
# gpg_decrypt = %(gpg_command)s -d --verbose --no-use-agent --batch --yes --passphrase-fd %(passphrase_fd)s -o %(output_file)s %(input_file)s
# gpg_encrypt = %(gpg_command)s -c --verbose --no-use-agent --batch --yes --passphrase-fd %(passphrase_fd)s -o %(output_file)s %(input_file)s
# gpg_passphrase =
# guess_mime_type = True
# host_base = s3.amazonaws.com
# host_bucket = %(bucket)s.s3.amazonaws.com
# human_readable_sizes = False
# ignore_failed_copy = False
# invalidate_default_index_on_cf = False
# invalidate_default_index_root_on_cf = True
# invalidate_on_cf = False
# list_md5 = False
# log_target_prefix =
# max_delete = -1
# mime_type =
# multipart_chunk_size_mb = 250
# preserve_attrs = True
# progress_meter = True
# proxy_host =
# proxy_port = 0
# put_continue = False
# recursive = False
# recv_chunk = 4096
# reduced_redundancy = False
# restore_days = 1
# send_chunk = 4096
# server_side_encryption = False
# simpledb_host = sdb.amazonaws.com
# skip_existing = False
# socket_timeout = 300
# urlencoding_mode = normal
# use_https = True
# use_mime_magic = True
# verbosity = WARNING
# website_endpoint = http://%(bucket)s.s3-website-% (location)s.amazonaws.com/
# website_error =
# website_index = index.html
# """

@@ -26,308 +26,308 @@ #!/usr/bin/env python

import json
import os
import sys
from abc import ABCMeta, abstractmethod
# import json
# import os
# import sys
# from abc import ABCMeta, abstractmethod
from Bio import SeqIO
# from Bio import SeqIO
from ..core.sequence import Sequence
from ..io import list_files
from . import mongodb
from .decorators import lazy_property
# from ..core.sequence import Sequence
# from ..io import list_files
# from . import mongodb
# from .decorators import lazy_property
# if sys.version_info[0] > 2:
# STR_TYPES = [
# str,
# ]
# else:
# STR_TYPES = [str, unicode]
# # if sys.version_info[0] > 2:
# # STR_TYPES = [
# # str,
# # ]
# # else:
# # STR_TYPES = [str, unicode]
# def read_input(input, data_type,
# collection=None, mongo_ip='localhost', mongo_port=27017, mongo_user=None, mongo_password=None,
# query=None, projection=None, verbose=False, **kwargs):
# '''
# Returns an Input class based on the data information provided.
# # def read_input(input, data_type,
# # collection=None, mongo_ip='localhost', mongo_port=27017, mongo_user=None, mongo_password=None,
# # query=None, projection=None, verbose=False, **kwargs):
# # '''
# # Returns an Input class based on the data information provided.
# Args:
# # Args:
# data_type (str): One of the following: `'fasta'`, `'json'`, or `'mongodb'`.
# # data_type (str): One of the following: `'fasta'`, `'json'`, or `'mongodb'`.
# input (str): Path to an input file for FASTA or JSON data types, or the database name for MongoDB data.
# # input (str): Path to an input file for FASTA or JSON data types, or the database name for MongoDB data.
# collection (str): Name of a MongoDB collection. Required for the MongoDB data type.
# # collection (str): Name of a MongoDB collection. Required for the MongoDB data type.
# mongo_ip (str): IP address of the MongoDB server. Default is `'localhost'` if not provided.
# # mongo_ip (str): IP address of the MongoDB server. Default is `'localhost'` if not provided.
# mongo_port (int): Port of the MongoDB server. Default is `27017` if not provided.
# # mongo_port (int): Port of the MongoDB server. Default is `27017` if not provided.
# query (dict): Query to limit the results returned from a MongoDB database.
# # query (dict): Query to limit the results returned from a MongoDB database.
# projection (dict): Projection to specify fields to be retuired from a MongoDB database.
# '''
# if data_type.lower() == 'mongodb':
# return MongoDBInput(database=input, collection=collection, ip=mongo_ip, port=mongo_port,
# user=mongo_user, password=mongo_password, query=query, projection=projection)
# elif data_type.lower() == 'json':
# return JSONInput(input)
# elif data_type.lower() == 'fasta':
# return FASTAInput(input)
# else:
# err = '\n\nERROR: data_type must be one of the following:\n'
# err += 'json, mongodb, or fasta\n\n'
# print(err)
# sys.exit(1)
# # projection (dict): Projection to specify fields to be retuired from a MongoDB database.
# # '''
# # if data_type.lower() == 'mongodb':
# # return MongoDBInput(database=input, collection=collection, ip=mongo_ip, port=mongo_port,
# # user=mongo_user, password=mongo_password, query=query, projection=projection)
# # elif data_type.lower() == 'json':
# # return JSONInput(input)
# # elif data_type.lower() == 'fasta':
# # return FASTAInput(input)
# # else:
# # err = '\n\nERROR: data_type must be one of the following:\n'
# # err += 'json, mongodb, or fasta\n\n'
# # print(err)
# # sys.exit(1)
def read_fasta(fasta_file, verbose=False):
    """Create a :class:`FASTAInput` wrapping ``fasta_file``."""
    return FASTAInput(data=fasta_file, verbose=verbose)
def from_fasta(fasta_file, verbose=False):
    """Create a :class:`FASTAInput` wrapping ``fasta_file`` (alias of ``read_fasta``)."""
    return FASTAInput(data=fasta_file, verbose=verbose)
def from_json(json_file, seq_field="vdj_nt", verbose=False):
    """Create a :class:`JSONInput` wrapping ``json_file``."""
    return JSONInput(data=json_file, seq_field=seq_field, verbose=verbose)
def from_mongodb(
    db,
    collection=None,
    ip="localhost",
    port=27017,
    user=None,
    password=None,
    query=None,
    projection=None,
    seq_field="vdj_nt",
    verbose=False,
):
    """Create a :class:`MongoDBInput` for the given database and collection(s)."""
    mongo_options = {
        "database": db,
        "collection": collection,
        "ip": ip,
        "port": port,
        "user": user,
        "password": password,
        "query": query,
        "projection": projection,
        "seq_field": seq_field,
        "verbose": verbose,
    }
    return MongoDBInput(**mongo_options)
class BaseInput(metaclass=ABCMeta):
    """
    Base class for parsing inputs (JSON, MongoDB, FASTA files, etc)
    """

    # BUGFIX: the original used ``__metaclass__ = ABCMeta``, which is Python 2
    # syntax and a silent no-op on Python 3 -- the abstract methods below were
    # never actually enforced. Declaring the metaclass in the class header
    # restores enforcement; all subclasses in this module implement the
    # abstract names, so they are unaffected.

    def __init__(self):
        pass

    @property
    @abstractmethod
    def data_type(self):
        "Returns the data type"
        pass

    @property
    @abstractmethod
    def as_list(self):
        "Returns the input as a list of Sequence objects"
        pass

    @property
    @abstractmethod
    def as_generator(self):
        "Returns the input as a generator of Sequence objects"
        pass
class FASTAInput(BaseInput):
    """
    Representation of FASTA input data.

    ``data`` may be a path to a FASTA file, a path to a directory of FASTA
    files, or a list of paths.
    """

    def __init__(self, data, verbose=False):
        self.input = data
        self.verbose = verbose  # if True, print each file name as it is parsed

    @property
    def data_type(self):
        """Input type identifier (``'fasta'``)."""
        return "fasta"

    @property
    def files(self):
        """Input file paths, normalized to a list."""
        if isinstance(self.input, str):
            if os.path.isdir(self.input):
                # BUGFIX: directory inputs should yield FASTA files -- the
                # original filtered on the "json" extension, a copy/paste
                # from JSONInput. NOTE(review): files ending in ".fa" would
                # still be missed; confirm whether that extension is needed.
                return list_files(self.input, "fasta")
            else:
                return [
                    self.input,
                ]
        else:
            return self.input

    @lazy_property
    def as_list(self):
        """All records as a (cached) list of ``Sequence`` objects."""
        sequences = []
        for input_file in self.files:
            if self.verbose:
                print(input_file)
            with open(input_file, "r") as f:
                for seq in SeqIO.parse(f, "fasta"):
                    sequences.append(Sequence(str(seq.seq), id=seq.id))
        return sequences

    @property
    def as_generator(self):
        """Records as a generator of ``Sequence`` objects."""
        for input_file in self.files:
            if self.verbose:
                print(input_file)
            with open(input_file, "r") as f:
                for seq in SeqIO.parse(f, "fasta"):
                    yield Sequence(str(seq.seq), id=seq.id)
class JSONInput(BaseInput):
    """
    Representation of JSON input data.

    Expects files with one JSON document per line, possibly wrapped in a
    top-level JSON array (leading ``[``, trailing ``]`` and commas are
    stripped from each line before parsing).
    """

    def __init__(self, data, seq_field="vdj_nt", verbose=False):
        self.input = data
        self.seq_field = seq_field  # document field containing the sequence
        self.verbose = verbose

    @property
    def data_type(self):
        """Input type identifier (``'json'``)."""
        return "json"

    @property
    def files(self):
        """Input file paths, normalized to a list."""
        if isinstance(self.input, str):
            if os.path.isdir(self.input):
                return list_files(self.input, "json")
            else:
                return [
                    self.input,
                ]
        else:
            return self.input

    @lazy_property
    def as_list(self):
        """All documents as a (cached) list of ``Sequence`` objects."""
        sequences = []
        for input_file in self.files:
            if self.verbose:
                print(input_file)
            with open(input_file, "r") as f:
                for line in f:
                    # BUGFIX: the original stripped "]" from the LEFT of the
                    # line (so array-opening lines like "[{...}" failed to
                    # parse); use the same cleanup chain as as_generator.
                    j = json.loads(
                        line.strip().lstrip("[").rstrip("]").rstrip().rstrip(",")
                    )
                    sequences.append(Sequence(j, seq_key=self.seq_field))
        return sequences

    @property
    def as_generator(self):
        """Documents as a generator of ``Sequence`` objects."""
        for input_file in self.files:
            if self.verbose:
                print(input_file)
            with open(input_file, "r") as f:
                for line in f:
                    j = json.loads(
                        line.strip().lstrip("[").rstrip("]").rstrip().rstrip(",")
                    )
                    yield Sequence(j, seq_key=self.seq_field)
class MongoDBInput(BaseInput):
    """
    Representation of MongoDB input data.

    Sequences are read from one or more collections in a MongoDB database;
    documents lacking ``seq_field`` are silently skipped.
    """

    def __init__(
        self,
        database,
        collection,
        ip,
        port,
        user,
        password,
        query,
        projection,
        seq_field="vdj_nt",
        verbose=False,
    ):
        # connection/query parameters are stored as-is; the actual database
        # connection is made lazily via the ``db`` property
        self.db_name = database
        self.raw_collections = collection
        self.ip = ip
        self.port = port
        self.user = user
        self.password = password
        self.query = query
        self.projection = projection
        self.seq_field = seq_field  # document field containing the sequence
        self.verbose = verbose

    @property
    def data_type(self):
        """Input type identifier (``'mongodb'``)."""
        return "mongodb"

    @property
    def db(self):
        # connection is built on each property access -- presumably
        # mongodb.get_db handles caching/pooling; TODO confirm
        return mongodb.get_db(
            self.db_name,
            ip=self.ip,
            port=self.port,
            user=self.user,
            password=self.password,
        )

    @property
    def collections(self):
        """Collection names to query, normalized to a list.

        A single name is wrapped in a list; ``None`` means every collection
        in the database.
        """
        if isinstance(self.raw_collections, str):
            return [
                self.raw_collections,
            ]
        elif self.raw_collections is None:
            return mongodb.get_collections(self.db)
        else:
            return self.raw_collections

    @lazy_property
    def as_list(self):
        """All matching documents as a (cached) list of ``Sequence`` objects."""
        sequences = []
        for collection in self.collections:
            if self.verbose:
                print(collection)
            res = self.db[collection].find(self.query, self.projection)
            for r in res:
                # skip documents that don't contain the sequence field
                if self.seq_field not in r:
                    continue
                sequences.append(Sequence(r, seq_key=self.seq_field))
        return sequences

    @property
    def as_generator(self):
        """Matching documents as a generator of ``Sequence`` objects."""
        for collection in self.collections:
            if self.verbose:
                print(collection)
            res = self.db[collection].find(self.query, self.projection)
            for r in res:
                if self.seq_field not in r:
                    continue
                yield Sequence(r, seq_key=self.seq_field)

    def _process_collections(self, collection):
        # NOTE(review): appears unused -- duplicates the ``collections``
        # property logic; candidate for removal.
        if isinstance(collection, str):
            return [
                collection,
            ]
        elif collection is None:
            return mongodb.get_collections(self.db)
        else:
            return collection

@@ -6,2 +6,2 @@ # Store the version here so:

__version__ = "0.5.0"
__version__ = "0.5.1"
Metadata-Version: 2.4
Name: abutils
Version: 0.5.0
Version: 0.5.1
Summary: Utilities for analysis of adaptive immune receptor repertoire (AIRR) data

@@ -13,8 +13,8 @@ Home-page: https://github.com/briney/abutils

Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Programming Language :: Python :: 3.12
Classifier: Programming Language :: Python :: 3.13
Classifier: Topic :: Scientific/Engineering :: Bio-Informatics
Requires-Python: >=3.8
Requires-Python: >=3.10
Description-Content-Type: text/markdown

@@ -25,5 +25,3 @@ License-File: LICENSE

Requires-Dist: biopython>=1.78
Requires-Dist: celery
Requires-Dist: dnachisel
Requires-Dist: ete3
Requires-Dist: fastcluster

@@ -42,3 +40,2 @@ Requires-Dist: matplotlib

Requires-Dist: pyfastx
Requires-Dist: pymongo
Requires-Dist: pytest

@@ -106,3 +103,3 @@ Requires-Dist: python-circos

### requirements
**python 3.8+**
**python 3.10+**

@@ -109,0 +106,0 @@ abstar

@@ -42,3 +42,3 @@ ![](https://img.shields.io/pypi/v/abutils.svg?colorB=blue)

### requirements
**python 3.8+**
**python 3.10+**

@@ -45,0 +45,0 @@ abstar

abstar>=0.6.3
baltic
biopython >= 1.78
celery
# celery
dnachisel
ete3
# ete3
fastcluster

@@ -20,3 +20,3 @@ matplotlib

pyfastx
pymongo
# pymongo
pytest

@@ -23,0 +23,0 @@ python-circos

@@ -97,11 +97,9 @@ # import os

"Programming Language :: Python :: 3",
# 'Programming Language :: Python :: 3.6',
# 'Programming Language :: Python :: 3.7',
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Bio-Informatics",
],
python_requires=">=3.8",
python_requires=">=3.10",
)
#!/usr/bin/env python
# filename: pipeline.py
#
# Copyright (c) 2015 Bryan Briney
# License: The MIT license (http://opensource.org/licenses/MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software
# and associated documentation files (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge, publish, distribute,
# sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or
# substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING
# BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# from __future__ import absolute_import, division, print_function, unicode_literals
import glob
import os
import sys
from typing import Iterable, Optional, Union
from . import log
# for backward compatibility
def list_files(*args, **kwargs):
    """Backward-compatibility shim: forwards to :func:`abutils.io.list_files`."""
    from ..io import list_files as _io_list_files
    return _io_list_files(*args, **kwargs)
def make_dir(*args, **kwargs):
    """Backward-compatibility shim: forwards to :func:`abutils.io.make_dir`."""
    from ..io import make_dir as _io_make_dir
    return _io_make_dir(*args, **kwargs)
def initialize(log_file, project_dir=None, debug=False):
    """
    Initializes an AbTools pipeline.

    Initialization includes printing the AbTools splash, setting up logging,
    creating the project directory, and logging both the project directory
    and the log location.

    Parameters
    ----------
    log_file : str
        Path to the log file. Required.

    project_dir : str, optional
        Path to the project directory. If not provided, the project directory
        won't be created and the location won't be logged.

    debug : bool, default=False
        If ``True``, the logging level will be set to ``logging.DEBUG``.

    Returns
    -------
    logger
        A logger instance.
    """
    print_splash()
    log.setup_logging(log_file, print_log_location=False, debug=debug)
    logger = log.get_logger("pipeline")
    # the project directory is optional -- only create and announce it when given
    if project_dir is not None:
        make_dir(os.path.normpath(project_dir))
        logger.info(f"PROJECT DIRECTORY: {project_dir}")
        logger.info("")
    logger.info(f"LOG LOCATION: {log_file}")
    print("")
    return logger
# def make_dir(directory: str) -> None:
# """
# Makes a directory, if it doesn't already exist.
# Parameters
# ----------
# directory : str
# Path to a directory.
# """
# if not os.path.exists(directory):
# os.makedirs(os.path.abspath(directory))
# def list_files(
# directory: str, extension: Union[str, Iterable, None] = None
# ) -> Iterable[str]:
# """
# Lists files in a given directory.
# Parameters
# ----------
# directory : str
# Path to a directory.
# extension : str
# If supplied, only files that end with the specificied extension(s) will be returned. Can be either
# a string or a list of strings. Extension evaluation is case-insensitive and can match complex
# extensions (e.g. '.fastq.gz'). Default is ``None``, which returns all files in the directory,
# regardless of extension.
# Returns
# -------
# Iterable[str]
# """
# if os.path.isdir(directory):
# expanded_dir = os.path.expanduser(directory)
# files = sorted(glob.glob(expanded_dir + "/*"))
# else:
# files = [
# directory,
# ]
# if extension is not None:
# if isinstance(extension, str):
# extension = [
# extension,
# ]
# files = [
# f
# for f in files
# if any(
# [
# any([f.lower().endswith(e.lower()) for e in extension]),
# any([f.endswith(e.upper()) for e in extension]),
# any([f.endswith(e.lower()) for e in extension]),
# ]
# )
# ]
# return files
def print_splash():
    """Prints the AbUtils ASCII-art splash and the MIT license notice to stdout.

    Returns ``None``; purely a console side effect, emitted with a blank
    line before and after the banner.
    """
    splash = """
 _ _ _ _ _ _ _ ____ _ _ _
/ \ | |__ | | | | |_(_) |___ | _ \(_)_ __ ___| (_)_ __ ___
/ _ \ | '_ \| | | | __| | / __| | |_) | | '_ \ / _ \ | | '_ \ / _ \\
/ ___ \| |_) | |_| | |_| | \__ \ | __/| | |_) | __/ | | | | | __/
/_/ \_\_.__/ \___/ \__|_|_|___/ |_| |_| .__/ \___|_|_|_| |_|\___|
 |_|
 """
    notice = "(c) 2023 Bryan Briney\nDistributed under the MIT License (http://opensource.org/licenses/MIT)"
    # blank line, banner, license notice, blank line -- same output as four prints
    for chunk in ("", splash, notice, ""):
        print(chunk)

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is too big to display