
Extends unittest.TestCase such that assertions support PyTorch tensors and parameters.
Using Python's unittest package turns out to be cumbersome when we are working with PyTorch and need to write assertions that include tensors, parameters, and so forth.
The main reason for this is that PyTorch compares tensors element-wise by default: an expression such as a == b yields a tensor of element-wise results rather than a single bool, which is why the assertions provided by the class unittest.TestCase do not work out of the box.
A possible workaround is to use TestCase.assertTrue for every assertion that we need to make, yet this commonly leads to convoluted code that is hard to read and maintain.
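To make the problem concrete, the following sketch uses only unittest and PyTorch (no torchtestcase). It shows that the stock assertEqual raises a RuntimeError, not an AssertionError, for multi-element tensors, and what the assertTrue workaround looks like. The class name PlainTensorTest is illustrative only.

```python
import unittest

import torch


class PlainTensorTest(unittest.TestCase):
    """Plain unittest, without torchtestcase."""

    def test_builtin_assert_equal_breaks(self):
        a, b = torch.zeros(3), torch.zeros(3)
        # `a == b` is a 3-element bool tensor; unittest calls bool() on it,
        # which PyTorch rejects as ambiguous by raising a RuntimeError.
        with self.assertRaises(RuntimeError):
            self.assertEqual(a, b)

    def test_assert_true_workaround(self):
        a = torch.zeros(3)
        # Each check has to be spelled out as an expression returning a bool.
        self.assertTrue(torch.equal(a, torch.zeros(3)))
        self.assertTrue(torch.allclose(a, a + 1e-4, atol=1e-3))
        self.assertTrue(bool(torch.all(a <= torch.ones(3))))
```

Run it with python -m unittest on the containing module. Every check needs its own hand-written boolean expression, which is the readability cost described above.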
The module torchtestcase defines the class TorchTestCase, which extends unittest.TestCase such that many assertions support instances of various PyTorch classes.
Updates:
This module can be installed from PyPI:
pip install torchtestcase
This section describes the assertions provided by the class TorchTestCase that support PyTorch.
If you are not familiar with the package unittest, then read about it first here.
Notice:
With the release of PyTorch 0.4.0, tensors and variables have been merged, which means that Variables are treated just like any other tensors, and thus there is no need to make use of the class torch.autograd.Variable anymore.
Accordingly, assertions for Variables in particular have been removed in version 2018.1 of torchtestcase.
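As a minimal illustration of the merge (plain PyTorch, independent of torchtestcase), a regular tensor created with requires_grad=True takes part in autograd directly, with no Variable wrapper:

```python
import torch

# Since PyTorch 0.4.0, a plain tensor can record operations for autograd;
# no torch.autograd.Variable wrapper is needed.
x = torch.ones(3, requires_grad=True)
y = (2 * x).sum()
y.backward()

print(x.grad)  # tensor([2., 2., 2.])
```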
Equality Assertions (assertEqual, assertNotEqual)
Equality assertions support objects that are any kind of PyTorch tensor as well as instances of torch.nn.Parameter and torch.nn.utils.rnn.PackedSequence.
Notice, however, that an AssertionError is raised if the compared objects are instances of different types:
self.assertEqual(torch.zeros(4), nn.Parameter(torch.zeros(4))) # -> AssertionError
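The following is a hypothetical, much-simplified sketch of how a TestCase subclass might implement the documented equality behavior: exact comparison via torch.equal and an AssertionError on mismatched types. MiniTensorTestCase and Demo are illustrative names; this is not torchtestcase's actual implementation, and PackedSequence support is omitted.

```python
import unittest

import torch
from torch import nn


class MiniTensorTestCase(unittest.TestCase):
    """Hypothetical, simplified stand-in for torchtestcase.TorchTestCase."""

    def assertEqual(self, first, second, msg=None):
        if isinstance(first, torch.Tensor) or isinstance(second, torch.Tensor):
            # Objects of different types never compare equal, e.g. a plain
            # tensor vs. an nn.Parameter (Parameter subclasses Tensor).
            if type(first) is not type(second):
                raise self.failureException(
                    msg or f"type mismatch: {type(first).__name__} != {type(second).__name__}"
                )
            # torch.equal compares shapes and values and returns a plain bool.
            if not torch.equal(first, second):
                raise self.failureException(msg or "tensors are not equal")
        else:
            super().assertEqual(first, second, msg)


class Demo(MiniTensorTestCase):
    def test_equal_tensors(self):
        self.assertEqual(torch.zeros(4), torch.zeros(4))

    def test_equal_parameters(self):
        self.assertEqual(nn.Parameter(torch.zeros(4)), nn.Parameter(torch.zeros(4)))

    def test_different_values(self):
        with self.assertRaises(AssertionError):
            self.assertEqual(torch.zeros(4), torch.ones(4))

    def test_different_types(self):
        with self.assertRaises(AssertionError):
            self.assertEqual(torch.zeros(4), nn.Parameter(torch.zeros(4)))
```

Raising self.failureException (which is AssertionError by default) keeps failures reported as test failures rather than errors, matching how unittest's own assertions behave.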
Occasionally, we do not expect two tensors to match each other exactly, for example, when we anticipate numerical instabilities.
For such cases, TorchTestCase allows us to specify a tolerance:
self.eps = 0.001 # specify tolerance for equality assertions
self.assertEqual(torch.zeros(3), 0.001 * torch.ones(3)) # -> no AssertionError
self.assertEqual(torch.zeros(3), 0.0011 * torch.ones(3)) # -> AssertionError
Notice that a specified tolerance is taken into account only for equality assertions between two tensors.
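A minimal sketch of how an eps attribute could work, assuming the tolerance bounds the element-wise absolute difference (the README does not state the exact rule). MiniToleranceTestCase is a hypothetical name, not torchtestcase's code, and the values 0.0005 and 0.002 are chosen to stay clear of floating-point boundary effects at exactly 0.001.

```python
import unittest

import torch


class MiniToleranceTestCase(unittest.TestCase):
    """Hypothetical sketch of an `eps` tolerance; not torchtestcase's actual code."""

    eps = None  # None means exact comparison

    def assertEqual(self, first, second, msg=None):
        if isinstance(first, torch.Tensor) or isinstance(second, torch.Tensor):
            # Short-circuits on a type mismatch before touching `.shape`.
            if type(first) is not type(second) or first.shape != second.shape:
                raise self.failureException(msg or "type or shape mismatch")
            if self.eps is None:
                ok = torch.equal(first, second)
            else:
                # Every element-wise absolute difference must be <= eps.
                # (This sketch also applies eps to Parameter pairs; the
                # package's behavior there is not covered by the README.)
                ok = bool(torch.all((first - second).abs() <= self.eps))
            if not ok:
                raise self.failureException(msg or f"tensors differ by more than eps={self.eps}")
        else:
            # Non-tensor comparisons ignore eps and use unittest's default.
            super().assertEqual(first, second, msg)


class Demo(MiniToleranceTestCase):
    def setUp(self):
        self.eps = 0.001

    def test_within_tolerance(self):
        self.assertEqual(torch.zeros(3), 0.0005 * torch.ones(3))

    def test_outside_tolerance(self):
        with self.assertRaises(AssertionError):
            self.assertEqual(torch.zeros(3), 0.002 * torch.ones(3))

    def test_numbers_ignore_eps(self):
        with self.assertRaises(AssertionError):
            self.assertEqual(0.0, 0.0005)
```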
Order Assertions (assertGreater, assertGreaterEqual, assertLess, assertLessEqual)
In general, order assertions are considered fulfilled if they hold element-wise. For example:
x = torch.FloatTensor([0, 0, 1])
y = torch.FloatTensor([1, 1, 1])
self.assertLessEqual(x, y) # -> no AssertionError
self.assertLess(x, y) # -> AssertionError
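The element-wise rule can be sketched with a hypothetical TestCase subclass that reduces a bool tensor of comparisons with torch.all. MiniOrderTestCase is an illustrative name, not torchtestcase's implementation, and the sketch assumes the operands broadcast against each other, which the README does not specify.

```python
import unittest

import torch


class MiniOrderTestCase(unittest.TestCase):
    """Hypothetical element-wise order assertions; a sketch, not torchtestcase's code."""

    def _check_all(self, mask, msg):
        # `mask` is a bool tensor of element-wise results; all of them must hold.
        if not bool(torch.all(mask)):
            raise self.failureException(msg)

    def assertLess(self, a, b, msg=None):
        if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
            self._check_all(a < b, msg or "not element-wise less")
        else:
            super().assertLess(a, b, msg)

    def assertLessEqual(self, a, b, msg=None):
        if isinstance(a, torch.Tensor) or isinstance(b, torch.Tensor):
            self._check_all(a <= b, msg or "not element-wise less or equal")
        else:
            super().assertLessEqual(a, b, msg)


class Demo(MiniOrderTestCase):
    def test_example_from_readme(self):
        x = torch.tensor([0.0, 0.0, 1.0])
        y = torch.tensor([1.0, 1.0, 1.0])
        self.assertLessEqual(x, y)  # holds for every element
        with self.assertRaises(AssertionError):
            self.assertLess(x, y)  # fails on the last element, since 1 < 1 is false

    def test_against_number(self):
        # A Python number is broadcast against every element.
        self.assertLessEqual(torch.tensor([0.0, 0.5, 1.0]), 1)
```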
In addition, it is possible to compare tensors or Parameters to a number, in which case each element of the tensor is compared to that number.
For example, if we want to ensure that every element of a tensor lies in the unit interval, then we may use the following assertions:
self.assertGreaterEqual(some_tensor, 0)
self.assertLessEqual(some_tensor, 1)
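In plain PyTorch, the condition that these two assertions check corresponds to the expression below. The tensor values are hypothetical: torch.rand samples from [0, 1), so the first check holds, while the second tensor deliberately contains an element outside the interval.

```python
import torch

# Hypothetical data: torch.rand samples uniformly from [0, 1).
some_tensor = torch.rand(5)

# A Python number is broadcast against every element, so each comparison
# yields a bool tensor with one entry per element.
in_unit_interval = bool(torch.all((some_tensor >= 0) & (some_tensor <= 1)))
print(in_unit_interval)  # True

outside = torch.tensor([0.5, 1.5])
print(bool(torch.all((outside >= 0) & (outside <= 1))))  # False
```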
When we make order assertions, we usually do not care about the actual types of the objects involved.
Therefore, it is possible to compare different kinds of tensors with each other as well as with Parameters:
self.assertLess(torch.zeros(3), nn.Parameter(torch.ones(3))) # -> no AssertionError
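At the PyTorch level, mixing types is unproblematic for order comparisons because the comparison itself works across a plain tensor and a Parameter and always yields a bool tensor. This plain-PyTorch snippet (no torchtestcase) shows the element-wise result behind the assertion above:

```python
import torch
from torch import nn

plain = torch.zeros(3)
param = nn.Parameter(torch.ones(3))

# Element-wise comparison works between a plain tensor and a Parameter;
# the result is a bool tensor regardless of the operand types.
mask = plain < param
print(mask)  # tensor([True, True, True])
print(bool(mask.all()))  # True
```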