torch-cif

A fast parallel implementation of continuous integrate-and-fire (CIF) https://arxiv.org/abs/1905.11235

  • 0.2.0
  • PyPI

A fast, parallel, pure PyTorch implementation of "CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition" (https://arxiv.org/abs/1905.11235).

Installation

PyPI

pip install torch-cif

Locally

git clone https://github.com/George0828Zhang/torch_cif
cd torch_cif
python setup.py install

Usage

def cif_function(
    inputs: Tensor,
    alpha: Tensor,
    beta: float = 1.0,
    tail_thres: float = 0.5,
    padding_mask: Optional[Tensor] = None,
    target_lengths: Optional[Tensor] = None,
    eps: float = 1e-4,
    unbound_alpha: bool = False
) -> Dict[str, List[Tensor]]:
    r""" A fast parallel implementation of continuous integrate-and-fire (CIF)
    https://arxiv.org/abs/1905.11235

    Shapes:
        N: batch size
        S: source (encoder) sequence length
        C: source feature dimension
        T: target sequence length

    Args:
        inputs (Tensor): (N, S, C) Input features to be integrated.
        alpha (Tensor): (N, S) Weights corresponding to each element in the
            inputs. Expected to be the output of a sigmoid function.
        beta (float): the threshold used to determine firing.
        tail_thres (float): the threshold used to determine firing for tail handling.
        padding_mask (Tensor, optional): (N, S) A binary mask representing
            padded elements in the inputs. 1 is padding, 0 is not.
        target_lengths (Tensor, optional): (N,) Desired length of the targets
            for each sample in the minibatch.
        eps (float, optional): Epsilon to prevent underflow for divisions.
            Default: 1e-4
        unbound_alpha (bool, optional): Whether to check if 0 <= alpha <= 1.

    Returns -> Dict[str, List[Tensor]]: Key/values described below.
        cif_out: (N, T, C) The output integrated from the source.
        cif_lengths: (N,) The output length for each element in batch.
        alpha_sum: (N,) The sum of alpha for each element in batch.
            Can be used to compute the quantity loss.
        delays: (N, T) The expected delay (in terms of source tokens) for
            each target token in the batch.
        tail_weights: (N,) During inference, return the tail.
        scaled_alpha: (N, S) alpha after applying weight scaling.
        cumsum_alpha: (N, S) cumsum of alpha after scaling.
        right_indices: (N, S) right scatter indices, or floor(cumsum(alpha)).
        right_weights: (N, S) right scatter weights.
        left_indices: (N, S) left scatter indices.
        left_weights: (N, S) left scatter weights.
    """
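
To make the argument and return descriptions more concrete, here is a minimal sequential sketch of integrate-and-fire for a single sequence in plain Python. It is not the package's code: cif_sequential is a hypothetical helper, features are plain lists of floats, there is no batching, padding, weight scaling, or tail handling, and at each boundary it uses the weight beta - accumulation described in the Note section. The returned list plays the role of cif_out for one sample, and the leftover accumulation is loosely analogous to tail_weights.

```python
from typing import List, Tuple


def cif_sequential(
    inputs: List[List[float]], alpha: List[float], beta: float = 1.0
) -> Tuple[List[List[float]], float]:
    """Hypothetical reference loop: integrate features until the weight reaches beta."""
    dim = len(inputs[0])
    outputs: List[List[float]] = []
    acc = 0.0              # weight accumulated for the current target
    state = [0.0] * dim    # weighted feature sum for the current target
    for feat, a in zip(inputs, alpha):
        remaining = a
        # Fire while the accumulated weight reaches the threshold.
        # A single source step may fire more than once when its weight is large.
        while acc + remaining >= beta:
            w = beta - acc  # weight that completes the current target
            state = [s + w * f for s, f in zip(state, feat)]
            outputs.append(state)
            remaining -= w
            acc = 0.0
            state = [0.0] * dim
        # Whatever is left starts (or continues) the next target.
        acc += remaining
        state = [s + remaining * f for s, f in zip(state, feat)]
    return outputs, acc


outputs, tail = cif_sequential([[1.0], [3.0], [2.0], [4.0]], [0.5, 0.5, 0.25, 0.75])
print(outputs, tail)  # [[2.0], [3.5]] 0.0
```

The weights in this example are chosen as exact binary fractions so the threshold comparison is not affected by floating-point rounding.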

Note

  • This implementation uses cumsum and floor to determine the firing positions, and uses scatter to merge the weighted source features. The figure below demonstrates this concept using the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4).
    [Figure: cumsum, floor, and scatter applied to the scaled weight sequence (0.4, 1.8, 1.2, 1.2, 1.4)]
  • Running the tests requires pip install hypothesis expecttest.
  • If beta != 1, our implementation differs slightly from Algorithm 1 in the paper [1]:
    • When a boundary is located, the original algorithm adds the last feature to the current integration with weight 1 - accumulation (line 11 in Algorithm 1), which causes a negative weight in the next integration when alpha < 1 - accumulation.
    • We use beta - accumulation instead, so the weight in the next integration, alpha - (beta - accumulation), is never negative.
  • Feel free to contact me if there are bugs in the code.
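
The cumsum and floor step from the first note can be sketched in plain Python on the weight sequence (0.4, 1.8, 1.2, 1.2, 1.4). This is a small illustration under stated assumptions, not the package's code: firing_indices is a hypothetical helper, the left index is assumed to be the floor of the previous cumulative sum, and Fraction is used only to keep the arithmetic exact so that floor is not affected by floating-point rounding.

```python
from fractions import Fraction
from itertools import accumulate
from typing import List, Tuple


def firing_indices(
    alpha: List[Fraction], beta: Fraction = Fraction(1)
) -> Tuple[List[int], List[int], List[int]]:
    """Hypothetical helper: locate firings from the cumulative weights."""
    csum = list(accumulate(alpha))
    # Right index: number of targets completed once this source step is consumed.
    right = [c // beta for c in csum]
    # Left index (assumption): number of targets completed before this step.
    left = [0] + right[:-1]
    # A source step fires once for every threshold crossed inside it.
    fires = [r - l for l, r in zip(left, right)]
    return left, right, fires


alpha = [Fraction(x) for x in ("0.4", "1.8", "1.2", "1.2", "1.4")]
left, right, fires = firing_indices(alpha)
print(right)  # [0, 2, 3, 4, 6]
print(fires)  # [0, 2, 1, 1, 2]
```

The weights sum to 6, which matches the total number of firings. The second step (weight 1.8) crosses two thresholds at once, which is the case where a source feature has to contribute to more than two targets.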

References

  1. CIF: Continuous Integrate-and-Fire for End-to-End Speech Recognition
  2. Exploring Continuous Integrate-and-Fire for Adaptive Simultaneous Speech Translation
