Automatic-Class-Balanced MSE Loss for PyTorch (ACB-MSE) to combat class imbalanced datasets and stabilise fluctuating loss gradients.
This repository contains the PyTorch implementation of the ACB-MSE (Automatic Class Balanced Mean Squared Error) loss function, originally developed for the DEEPCLEAN3D Denoiser to combat class imbalance and to stabilise loss-gradient fluctuations caused by dramatically varying class frequencies.
Available on PyPI:

```bash
pip install acb_mse
```
`ACBLoss` takes two optional arguments:

- `zero_weighting` (float, optional): Weighting coefficient for the MSE loss of zero-valued pixels. Default is 1.
- `nonzero_weighting` (float, optional): Weighting coefficient for the MSE loss of non-zero pixels. Default is 1.

Example usage:

```python
import torch
from acb_mse import ACBLoss

# Select a weighting for each class if not using the default 1:1 weighting
zero_weighting = 1.0
nonzero_weighting = 1.2

# Create an instance of the ACB-MSE loss function with the chosen weighting coefficients
loss_function = ACBLoss(zero_weighting, nonzero_weighting)

# Dummy target and reconstructed image tensors (B=10, C=3, H=256, W=256).
# The target is made sparse (about 2% non-zero 'hits') so that both the
# zero and non-zero classes contain pixels, as in real detector data.
target_image = torch.rand(10, 3, 256, 256) * (torch.rand(10, 3, 256, 256) < 0.02)
reconstructed_image = torch.rand(10, 3, 256, 256)

# Calculate the ACB-MSE loss
loss = loss_function(reconstructed_image, target_image)
print("ACB-MSE Loss:", loss)
```
Internally, the loss is computed from two boolean masks derived from the target image:

- `zero_mask`: True for zero-valued pixels in the target image.
- `nonzero_mask`: True for non-zero-valued pixels in the target image.

The MSE is calculated separately over each mask, and the two results are multiplied by the weighting coefficients (`zero_weighting` and `nonzero_weighting`), which allow the user to adjust the balance away from the default 1:1. Because the function relies on knowing the indices of all hits and non-hits in the true label image, which are then compared with the values at the corresponding indices in the recovered image, ACB-MSE is unsuitable for unsupervised learning tasks. The ACB-MSE loss function is given by:
$$ \text{Loss} = A\left(\frac{1}{N_h}\sum_{i=1}^{N_h}(y_i - \hat{y}_i)^2\right) + B\left(\frac{1}{N_n}\sum_{i=1}^{N_n}(y_i - \hat{y}_i)^2\right) $$
where $y_i$ is the true value of the $i$-th pixel in a class, $\hat{y}_i$ is the predicted value of that pixel, and $N_h$ and $N_n$ are the total numbers of pixels in the 'hits' and 'no hits' classes respectively (the formulation can be extended to $n$ classes). Because the mean squared error is taken over each class separately, summing the per-class errors automatically scales each class's summed squared error by the inverse of its pixel count ($1/N_h$ or $1/N_n$), normalising the class balance to 1:1. The additional coefficients $A$ and $B$ let the user fine-tune this balance manually: $A$ weights the hits (non-zero pixels) and corresponds to `nonzero_weighting`, while $B$ weights the no-hits (zero pixels) and corresponds to `zero_weighting`.
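To make the formula concrete, here is a minimal sketch of the two-class computation in PyTorch. It is not the package's source code: the function name `acb_mse_sketch` is hypothetical, and returning zero for a class that is absent from the target is an assumption made here so that the mean over an empty selection does not produce NaN.

```python
import torch
import torch.nn.functional as F


def acb_mse_sketch(reconstructed: torch.Tensor, target: torch.Tensor,
                   zero_weighting: float = 1.0,
                   nonzero_weighting: float = 1.0) -> torch.Tensor:
    """Hypothetical two-class ACB-MSE: per-class MSE, then a weighted sum."""
    zero_mask = target == 0          # 'no hits' class (coefficient B)
    nonzero_mask = ~zero_mask        # 'hits' class (coefficient A)

    def class_mse(mask: torch.Tensor) -> torch.Tensor:
        # Mean squared error over the pixels selected by the mask only.
        # Assumption: an empty class contributes zero rather than NaN.
        if not mask.any():
            return reconstructed.new_zeros(())
        return torch.mean((target[mask] - reconstructed[mask]) ** 2)

    return (zero_weighting * class_mse(zero_mask)
            + nonzero_weighting * class_mse(nonzero_mask))


# Usage: a sparse 4x4 target with two 'hit' pixels of value 1.0.
target = torch.zeros(4, 4)
target[0, 0] = 1.0
target[2, 3] = 1.0
reconstructed = torch.zeros(4, 4)
reconstructed[0, 0] = 1.0            # only one of the two hits is recovered

print(acb_mse_sketch(reconstructed, target))  # hit MSE 0.5 + no-hit MSE 0.0
print(F.mse_loss(reconstructed, target))      # one squared error over 16 pixels
```

Because each class is averaged over its own pixel count, the single missed hit costs 0.5 under this sketch, whereas a plain MSE over all 16 pixels reports 1/16 = 0.0625.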
The ACB-MSE loss function was designed for data taken from particle detectors, which often have a majority of unlit 'pixels' and a very sparse pattern of lit pixels. In this scenario ACB-MSE provides two main benefits: it addresses the class imbalance between lit and unlit pixels, and it stabilises the loss gradient during training. The additional parameters $A$ and $B$ allow the user to set a custom balance between the classes.
Fluctuations in the number of hit pixels across images during training can destabilise the loss. ACB-MSE remedies this by recomputing the class weights from the class frequencies of each target image.
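One way to see this re-weighting is to rewrite the ACB-MSE formula as an ordinary weighted sum over all $N = N_h + N_n$ pixels, where each pixel's weight depends on the class counts of the current target image:

$$
\text{Loss} = \sum_{i=1}^{N} w_i \,(y_i - \hat{y}_i)^2,
\qquad
w_i =
\begin{cases}
A / N_h & \text{if pixel } i \text{ is a hit,} \\
B / N_n & \text{if pixel } i \text{ is not a hit.}
\end{cases}
$$

When an image contains fewer hits, $N_h$ shrinks and each hit's weight $A/N_h$ grows, so the hit class keeps the same total share of the loss; plain MSE instead gives every pixel the fixed weight $1/N$.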
The above plot demonstrates how each of the loss functions (ACB-MSE, MSE and MAE) behaves as the number of hits in the true signal changes. Two dummy images were created: the first contains a simulated signal, and the second (the recovered image) has 50% of that signal correctly identified, simulating a 50% signal recovery by the network. To generate the plot, the signal in the first image was grown in two-pixel increments while the second image kept a constant 50% recovery, and the loss was calculated for the pair of images at each iteration. The MSE and MAE losses vary as the signal grows even though the recovery is fixed at 50%, whereas the ACB-MSE loss stays constant regardless of the frequency of the signal class.
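The experiment behind the plot can be approximated in a few lines of plain Python. This is a hypothetical reconstruction, not the script used to make the figure: the image size (11,264 pixels, borrowed from the DEEPCLEAN3D figure quoted in the next paragraph), the hit value of 1.0, and the choice of which half of the hits counts as "recovered" are all assumptions.

```python
def mean(values: list[float]) -> float:
    # Assumption: an empty class contributes 0 instead of raising an error.
    return sum(values) / len(values) if values else 0.0


def losses(target: list[float], recovered: list[float],
           a: float = 1.0, b: float = 1.0) -> tuple[float, float, float]:
    """Return (ACB-MSE, MSE, MAE) for two flattened images of equal length."""
    pairs = list(zip(target, recovered))
    hit_sq = [(t - r) ** 2 for t, r in pairs if t != 0]      # 'hits' class
    empty_sq = [(t - r) ** 2 for t, r in pairs if t == 0]    # 'no hits' class
    acb = a * mean(hit_sq) + b * mean(empty_sq)
    mse = mean([(t - r) ** 2 for t, r in pairs])
    mae = mean([abs(t - r) for t, r in pairs])
    return acb, mse, mae


NUM_PIXELS = 11_264  # assumption: DEEPCLEAN3D-sized image
results = []
for num_hits in range(2, 102, 2):  # grow the signal two pixels at a time
    target = [1.0] * num_hits + [0.0] * (NUM_PIXELS - num_hits)
    # Recovered image: the first half of the hits is correct (50% recovery),
    # the second half is missed (left at 0).
    found = num_hits // 2
    recovered = [1.0] * found + [0.0] * (NUM_PIXELS - found)
    results.append((num_hits, *losses(target, recovered)))

for num_hits, acb, mse, mae in results[::10]:
    print(f"hits={num_hits:3d}  ACB-MSE={acb:.4f}  MSE={mse:.6f}  MAE={mae:.6f}")
```

With every missed hit contributing an error of 1 and half the hits missed, the hit-class MSE is always 0.5 and the no-hit class contributes nothing, so ACB-MSE stays at 0.5 while MSE and MAE grow linearly with the number of hits.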
Class imbalance is an issue that arises when the interesting features are contained in the minority class. In the case of the DEEPCLEAN3D data, the input images contained 11,264 pixels in total, of which only around 200 were hits. For the network, guessing that all pixels are non-hits (zero-valued) yields a very respectable reconstruction loss and is a simple transfer function to learn, and this local minimum proved hard for the network to escape. Class balancing based on class frequency is a simple solution that reshapes the loss landscape, making it less favourable for the network to guess all pixels as non-hits. This enabled the DEEPCLEAN3D network to escape the local minimum and begin to learn a useful transfer function for the input features.
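The arithmetic behind this local minimum can be checked directly. The sketch below uses the pixel counts quoted above and assumes hit pixels have value 1.0; the second prediction, which finds every hit but adds 200 false positives, is a hypothetical comparison case chosen for illustration.

```python
NUM_PIXELS = 11_264     # pixels per DEEPCLEAN3D input image (from the text)
NUM_HITS = 200          # "around 200" hits per image (from the text)
NUM_EMPTY = NUM_PIXELS - NUM_HITS
HIT_VALUE = 1.0         # assumption: hit pixels are normalised to 1.0


def mse(hit_sq_err: float, empty_sq_err: float) -> float:
    """Plain MSE: total squared error averaged over every pixel."""
    return (hit_sq_err + empty_sq_err) / NUM_PIXELS


def acb_mse(hit_sq_err: float, empty_sq_err: float,
            a: float = 1.0, b: float = 1.0) -> float:
    """ACB-MSE: each class averaged over its own pixel count, then weighted."""
    return a * hit_sq_err / NUM_HITS + b * empty_sq_err / NUM_EMPTY


# Prediction 1: zero everywhere, so every hit is missed.
all_zero = (NUM_HITS * HIT_VALUE ** 2, 0.0)
# Prediction 2 (hypothetical): every hit recovered exactly, but 200 empty
# pixels are also lit with value 1.0 (false positives).
noisy_hits = (0.0, 200 * HIT_VALUE ** 2)

for name, errs in [("all zero", all_zero), ("noisy hits", noisy_hits)]:
    print(f"{name:10s}  MSE={mse(*errs):.4f}  ACB-MSE={acb_mse(*errs):.4f}")
```

Under plain MSE the two predictions score identically (about 0.018), so nothing pushes the network away from the trivial all-zero output. Under ACB-MSE the all-zero guess scores 1.0, far worse than the noisy prediction, which is the shift in the loss landscape described above.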
This project is licensed under the MIT License - see the LICENSE.md file for details.
Contributions to this codebase are welcome! If you encounter any issues or have suggestions for improvements, please open an issue or a pull request on the GitHub repository.
For any inquiries, feel free to reach out to me at adillwmaa@gmail.com.