You're Invited: Meet the Socket Team at RSAC and BSidesSF 2026, March 23–26. RSVP
Socket
Book a Demo · Sign in
Socket

burn-tensor

Package Overview
Dependencies
Maintainers
1
Versions
39
Alerts
File Explorer

Advanced tools

Socket logo

Install Socket

Detect and block malicious and high-risk dependencies

Install

burn-tensor - cargo Package Compare versions

Comparing version
0.21.0-pre.1
to
0.21.0-pre.2
+1
-1
.cargo_vcs_info.json
{
"git": {
"sha1": "d4dbe9ef847e6795027bbe543d712d443c970e9b"
"sha1": "91e784e3fb10ee19400707fcd17669a1ba7f2780"
},
"path_in_vcs": "crates/burn-tensor"
}

@@ -15,3 +15,3 @@ # THIS FILE IS AUTOMATICALLY GENERATED BY CARGO

name = "burn-tensor"
version = "0.21.0-pre.1"
version = "0.21.0-pre.2"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]

@@ -81,7 +81,7 @@ build = false

[dependencies.burn-backend]
version = "=0.21.0-pre.1"
version = "=0.21.0-pre.2"
default-features = false
[dependencies.burn-std]
version = "=0.21.0-pre.1"
version = "=0.21.0-pre.2"
default-features = false

@@ -88,0 +88,0 @@

@@ -563,1 +563,86 @@ use crate::backend::Backend;

}
/// Applies the HardShrink function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{hard\_shrink}(x) =
\begin{cases}
x & \text{if } x > \lambda \newline
x & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`hard_shrink(x) = x if x > lambda, x if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Hard Shrink formulation. Default is 0.5.
pub fn hard_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
    // Elements whose magnitude does not exceed the threshold are zeroed;
    // every other element passes through unchanged.
    let magnitude = tensor.clone().abs();
    let within_threshold = magnitude.lower_equal_elem(lambda);
    tensor.mask_fill(within_threshold, 0)
}
/// Applies the SoftShrink function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{soft\_shrink}(x) =
\begin{cases}
x - \lambda & \text{if } x > \lambda \newline
x + \lambda & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`soft_shrink(x) = x - lambda if x > lambda, x + lambda if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Soft Shrink formulation. Default is 0.5.
pub fn soft_shrink<const D: usize, B: Backend>(tensor: Tensor<B, D>, lambda: f64) -> Tensor<B, D> {
    // Soft shrink is the shrink operation with the bias equal to the threshold:
    // values outside the dead zone are pulled toward zero by exactly lambda.
    let magnitude = tensor.clone().abs();
    let direction = tensor.clone().sign();
    let shifted = tensor.sub(direction.mul_scalar(lambda));
    shifted.mask_fill(magnitude.lower_equal_elem(lambda), 0)
}
/// Applies the Shrink function element-wise.
///
#[cfg_attr(
    doc,
    doc = r#"
$$
\text{shrink}(x) =
\begin{cases}
x - \text{bias} & \text{if } x > \lambda \newline
x + \text{bias} & \text{if } x < -\lambda \newline
0 & \text{otherwise}
\end{cases}
$$
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`shrink(x) = x - bias if x > lambda, x + bias if x < -lambda, 0 otherwise`"
)]
/// # Arguments
/// - `lambda`: the lambda value for the Shrink formulation.
/// - `bias`: the bias value for the Shrink formulation.
pub fn shrink<const D: usize, B: Backend>(
    tensor: Tensor<B, D>,
    lambda: f64,
    bias: f64,
) -> Tensor<B, D> {
    // The dead zone is |x| <= lambda; elements inside it become zero.
    let dead_zone = tensor.clone().abs().lower_equal_elem(lambda);
    // Outside the dead zone, move each element toward zero by `bias`
    // (subtract for positives, add for negatives — handled via sign()).
    let direction = tensor.clone().sign();
    let shifted = tensor.sub(direction.mul_scalar(bias));
    shifted.mask_fill(dead_zone, 0)
}

@@ -242,3 +242,3 @@ use crate::{Bool, Int, Shape, Tensor, TensorData, TensorPrimitive, backend::Backend};

let dims = indices.shape().dims;
let dims = indices.shape();
indices

@@ -245,0 +245,0 @@ .chunk(dims[1], 1)

@@ -123,3 +123,3 @@ use crate::ops::FloatElem;

if start >= tensor.shape().dims[dim] {
if start >= tensor.shape()[dim] {
check = check.register(

@@ -130,3 +130,3 @@ "Narrow",

this dimension (Size={})",
tensor.shape().dims[dim]
tensor.shape()[dim]
)),

@@ -136,3 +136,3 @@ );

if start + length > tensor.shape().dims[dim] {
if start + length > tensor.shape()[dim] {
check = check.register(

@@ -143,3 +143,3 @@ "Narrow",

along this dimension (Size={})",
tensor.shape().dims[dim]
tensor.shape()[dim]
)),

@@ -556,4 +556,4 @@ );

let dim_lhs = shape_lhs.dims[D - 1];
let dim_rhs = shape_rhs.dims[D - 2];
let dim_lhs = shape_lhs[D - 1];
let dim_rhs = shape_rhs[D - 2];

@@ -569,3 +569,3 @@ if dim_lhs != dim_rhs {

"Lhs shape {:?}, rhs shape {:?}.",
shape_lhs.dims, shape_rhs.dims
shape_lhs, shape_rhs
)),

@@ -603,4 +603,4 @@ );

let dim_size_lhs = shape_lhs.dims[dim];
let dim_size_rhs = shape_rhs.dims[dim];
let dim_size_lhs = shape_lhs[dim];
let dim_size_rhs = shape_rhs[dim];

@@ -619,4 +619,4 @@ if dim_size_lhs != 3 || dim_size_rhs != 3 {

if i != dim {
let l = shape_lhs.dims[i];
let r = shape_rhs.dims[i];
let l = shape_lhs[i];
let r = shape_rhs[i];
if l != r && l != 1 && r != 1 {

@@ -716,3 +716,3 @@ check = check.register(

let mut shape_reference = tensors.first().unwrap().shape();
shape_reference.dims[dim] = 1; // We want to check every dims except the one where the
shape_reference[dim] = 1; // We want to check every dims except the one where the
// concatenation happens.

@@ -779,3 +779,3 @@

Tensor shape {:?}.",
end, d_tensor, i, shape.dims,
end, d_tensor, i, shape,
)),

@@ -826,3 +826,3 @@ );

let d_tensor = shape[i];
let d_tensor_value = shape_value.dims[i];
let d_tensor_value = shape_value[i];
let range = slice.to_range(d_tensor);

@@ -840,3 +840,3 @@

Current tensor shape {:?}, value tensor shape {:?}.",
range.start, range.end, d_tensor, i, shape.dims, shape_value.dims,
range.start, range.end, d_tensor, i, shape, shape_value,
)),

@@ -866,4 +866,4 @@ );

i,
shape.dims,
shape_value.dims,
shape,
shape_value,
)),

@@ -918,3 +918,3 @@ );

"The shape differs: {:?} != {:?}",
shape_indices.dims, shape_value.dims
shape_indices, shape_value
)),

@@ -985,3 +985,3 @@ );

if shape_value.dims[dim] != shape_indices.dims[0] {
if shape_value[dim] != shape_indices[0] {
check = check.register(

@@ -992,4 +992,4 @@ "Select Assign",

"Number of indices ({}) should be equal to value tensor dimensions {:?} on axis (dim={dim})",
shape_indices.dims[0],
shape_value.dims
shape_indices[0],
shape_value
),

@@ -1037,3 +1037,3 @@ )

let tensor_dim_i = shape[i];
let indices_dim_i = shape_indices.dims[i];
let indices_dim_i = shape_indices[i];

@@ -1062,7 +1062,7 @@ if tensor_dim_i != indices_dim_i {

let mut check = Self::Ok;
if shape_weight.dims[0] == 1 {
if shape_weight[0] == 1 {
check
} else if D >= 2 {
let channels = shape_tensor.dims[1];
let num_weights = shape_weight.dims[0];
let channels = shape_tensor[1];
let num_weights = shape_weight[0];
if channels != num_weights {

@@ -1089,3 +1089,3 @@ check = check.register(

"Got no. of channels: 1, no. of weights: {}",
shape_weight.dims[0]
shape_weight[0]
)),

@@ -1216,4 +1216,4 @@ );

for i in 0..D {
let d_lhs = lhs.dims[i];
let d_rhs = rhs.dims[i];
let d_lhs = lhs[i];
let d_rhs = rhs[i];

@@ -1233,3 +1233,3 @@ if d_lhs != d_rhs {

broadcasted. Lhs tensor shape {:?}, Rhs tensor shape {:?}.",
i, d_lhs, d_rhs, lhs.dims, rhs.dims,
i, d_lhs, d_rhs, lhs, rhs,
),

@@ -1279,3 +1279,3 @@ ),

let d_to = if i >= start_index_to {
to.dims[i - start_index_to]
to[i - start_index_to]
} else {

@@ -1298,4 +1298,4 @@ 1

d_to,
shape.dims,
to.dims,
shape,
to,
)),

@@ -1375,3 +1375,3 @@ );

let mut check = TensorCheck::Ok;
if shape.dims[D - 1] != shape.dims[D - 2] {
if shape[D - 1] != shape[D - 2] {
check = check.register(

@@ -1381,3 +1381,3 @@ ops,

"Got tensor with shape {:?}, expected last two dimensions to be equal",
shape.dims
shape
)),

@@ -1468,7 +1468,7 @@ );

pub(crate) fn unwrap_shape_reshape(result: Result<Shape, burn_std::ShapeError>) -> Shape {
pub(crate) fn unwrap_shape_reshape(result: Result<Shape, burn_std::MetadataError>) -> Shape {
match result {
Ok(shape) => shape,
// `shape.reshape(new_shape)` should only return `ShapeError::Invalid`.
Err(burn_std::ShapeError::Invalid { reason }) => {
// `shape.reshape(new_shape)` should only return `MetadataError::Invalid`.
Err(burn_std::MetadataError::Invalid { reason }) => {
macros::check!({

@@ -1475,0 +1475,0 @@ TensorCheck::Ok.register("Reshape", crate::check::TensorError::new(reason))

@@ -14,2 +14,3 @@ pub(crate) mod check;

mod pad;
pub use pad::IntoPadding;
mod take;

@@ -16,0 +17,0 @@ mod transaction;

@@ -684,2 +684,7 @@ use burn_backend::Scalar;

/// ```
///
/// # Notes
///
/// For signed integer dtypes, this operation uses two's-complement wraparound semantics, similar to
/// `x.wrapping_abs()`. For example, `abs(i64::MIN) == i64::MIN`.
pub fn abs(self) -> Self {

@@ -727,3 +732,3 @@ Self::new(K::abs(self.primitive))

// last two dimensions
let shape = &self.shape().dims[D - 2..].to_owned();
let shape = &self.shape()[D - 2..].to_owned();

@@ -773,3 +778,3 @@ let mask = Tensor::<B, 2, Bool>::triu_mask(shape, diagonal, &self.device()).unsqueeze();

// last two dimensions
let shape = &self.shape().dims[D - 2..].to_owned();
let shape = &self.shape()[D - 2..].to_owned();
let mask = Tensor::<B, 2, Bool>::tril_mask(shape, diagonal, &self.device()).unsqueeze();

@@ -776,0 +781,0 @@

@@ -8,2 +8,76 @@ use alloc::vec::Vec;

/// Trait for types that can be used as padding specifications.
///
/// Padding is specified as `(before, after)` pairs per dimension, returned as a
/// fixed-size array `[(usize, usize); D]`. If fewer pairs than dimensions are provided,
/// they apply to the **last** N dimensions (earlier dimensions are left unpadded).
pub trait IntoPadding<const D: usize> {
    /// Converts into a fixed-size array of `(before, after)` padding pairs.
    fn into_padding(self) -> [(usize, usize); D];
}

impl<const D: usize> IntoPadding<D> for &[(usize, usize)] {
    /// Right-aligns the provided pairs onto the trailing dimensions.
    ///
    /// # Panics
    /// Panics if more pairs are provided than there are dimensions.
    fn into_padding(self) -> [(usize, usize); D] {
        assert!(
            self.len() <= D,
            "Padding has {} pairs but tensor only has {} dimensions",
            self.len(),
            D
        );
        let mut result = [(0usize, 0usize); D];
        // Right-align: the last provided pair pads the last dimension.
        let offset = D - self.len();
        for (i, &pair) in self.iter().enumerate() {
            result[offset + i] = pair;
        }
        result
    }
}

impl<const D: usize, const N: usize> IntoPadding<D> for [(usize, usize); N] {
    fn into_padding(self) -> [(usize, usize); D] {
        // Delegate to the slice impl so the alignment logic lives in one place.
        self.as_slice().into_padding()
    }
}

impl<const D: usize> IntoPadding<D> for Vec<(usize, usize)> {
    fn into_padding(self) -> [(usize, usize); D] {
        // Delegate to the slice impl so the alignment logic lives in one place.
        self.as_slice().into_padding()
    }
}

/// Backward-compatible: `(left, right, top, bottom)` maps to last 2 dimensions.
///
/// Equivalent to `[(top, bottom), (left, right)]`.
impl<const D: usize> IntoPadding<D> for (usize, usize, usize, usize) {
    /// # Panics
    /// Panics if the tensor has fewer than two dimensions, since the 4-tuple
    /// form always targets the last two dimensions.
    fn into_padding(self) -> [(usize, usize); D] {
        // Guard explicitly: without this, `D - 2` below would fail with a
        // cryptic arithmetic-underflow panic for 0- or 1-dimensional tensors.
        assert!(
            D >= 2,
            "Tuple padding (left, right, top, bottom) pads the last two dimensions, \
             but tensor only has {} dimensions",
            D
        );
        let (left, right, top, bottom) = self;
        let mut result = [(0usize, 0usize); D];
        result[D - 2] = (top, bottom);
        result[D - 1] = (left, right);
        result
    }
}
/// Helper to build a range array for slice_assign, selecting a portion of one dimension.

@@ -36,10 +110,18 @@ fn build_slice_ranges<const D: usize>(

{
/// Pads the tensor on the last two dimensions using the specified padding mode.
/// Pads the tensor using the specified padding mode.
///
/// **Note**: Currently, padding is only supported on the last two dimensions of a tensor
/// (typically height and width for image data in NCHW format).
/// Padding is specified as `(before, after)` pairs. If fewer pairs than tensor dimensions
/// are provided, they apply to the **last** N dimensions (unspecified leading dimensions
/// are left unpadded).
///
/// For backward compatibility, a `(left, right, top, bottom)` tuple is also accepted,
/// which pads the last two dimensions.
///
/// # Arguments
///
/// * `padding` - A tuple `(left, right, top, bottom)` specifying padding for the last two dimensions.
/// * `padding` - Padding specification. Accepts:
/// - `[(before, after); N]` fixed-size array of pairs (N <= D)
/// - `&[(before, after)]` slice of pairs per dimension
/// - `Vec<(before, after)>` vector of pairs
/// - `(left, right, top, bottom)` tuple for last-2-dim backward compatibility
/// * `mode` - The padding mode: `Constant(value)`, `Reflect`, or `Edge`.

@@ -53,2 +135,3 @@ ///

///
/// - Panics if more padding pairs are provided than tensor dimensions.
/// - `Reflect` mode panics if padding exceeds `dimension_size - 1`.

@@ -68,25 +151,18 @@ /// - `Edge` mode panics if padding is applied to a zero-sized dimension.

///
/// // Constant padding with value 0.0
/// // Constant padding with value 0.0 (backward-compatible tuple)
/// let padded = tensor.clone().pad((1, 1, 1, 1), PadMode::Constant(0.0));
/// // [
/// // [0.0, 0.0, 0.0, 0.0, 0.0],
/// // [0.0, 12.0, -2.0, 3.0, 0.0],
/// // [0.0, 5.0, 3.0, 6.0, 0.0],
/// // [0.0, 0.0, 0.0, 0.0, 0.0]
/// // ]
///
/// // Reflect padding
/// let padded = tensor.clone().pad((1, 1, 0, 0), PadMode::Reflect);
/// // [[−2.0, 12.0, −2.0, 3.0, −2.0], [3.0, 5.0, 3.0, 6.0, 3.0]]
/// // Pad arbitrary dimensions with slice of (before, after) pairs
/// let padded = tensor.clone().pad([(1, 1), (2, 2)], PadMode::Constant(0.0));
///
/// // Edge padding
/// let padded = tensor.pad((1, 1, 0, 0), PadMode::Edge);
/// // [[12.0, 12.0, −2.0, 3.0, 3.0], [5.0, 5.0, 3.0, 6.0, 6.0]]
/// // Pad only the last dimension
/// let padded = tensor.pad([(1, 1)], PadMode::Reflect);
/// }
/// ```
pub fn pad(self, padding: (usize, usize, usize, usize), mode: impl Into<PadMode>) -> Self {
pub fn pad(self, padding: impl IntoPadding<D>, mode: impl Into<PadMode>) -> Self {
let pairs = padding.into_padding();
match mode.into() {
PadMode::Constant(value) => pad_constant(self, padding, value),
PadMode::Reflect => pad_reflect(self, padding),
PadMode::Edge => pad_edge(self, padding),
PadMode::Constant(value) => pad_constant(self, &pairs, value),
PadMode::Reflect => pad_reflect(self, &pairs),
PadMode::Edge => pad_edge(self, &pairs),
}

@@ -97,5 +173,5 @@ }

/// Pad with a constant value.
pub fn pad_constant<B, const D: usize, K, E>(
fn pad_constant<B, const D: usize, K, E>(
tensor: Tensor<B, D, K>,
padding: (usize, usize, usize, usize),
padding: &[(usize, usize); D],
value: E,

@@ -109,31 +185,21 @@ ) -> Tensor<B, D, K>

{
let (left, right, top, bottom) = padding;
let mut padded_dims: [usize; D] = tensor.dims();
// Update the last two dimensions with padding
padded_dims[D - 2] += top + bottom;
padded_dims[D - 1] += left + right;
for (i, &(before, after)) in padding.iter().enumerate() {
padded_dims[i] += before + after;
}
// Create the ranges for the padded tensor
let ranges: [core::ops::Range<usize>; D] = padded_dims
let ranges: [Range<usize>; D] = padded_dims
.iter()
.enumerate()
.map(|(i, &dim)| {
if i == D - 2 {
top..dim - bottom
} else if i == D - 1 {
left..dim - right
} else {
0..dim
}
let (before, after) = padding[i];
before..dim - after
})
.collect::<Vec<core::ops::Range<usize>>>()
.collect::<Vec<Range<usize>>>()
.try_into()
.unwrap();
// Create the padded tensor
let padded_tensor = Tensor::full(padded_dims, value, &tensor.device());
// Assign the original tensor data to the appropriate slice of the padded tensor
padded_tensor.slice_assign(ranges, tensor)

@@ -146,5 +212,5 @@ }

/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[3, 2, 1, 2, 3, 4]`
pub fn pad_reflect<B, const D: usize, K>(
fn pad_reflect<B, const D: usize, K>(
tensor: Tensor<B, D, K>,
padding: (usize, usize, usize, usize),
padding: &[(usize, usize); D],
) -> Tensor<B, D, K>

@@ -156,34 +222,25 @@ where

{
let (left, right, top, bottom) = padding;
let dims = tensor.dims();
// Validate padding doesn't exceed tensor dimensions
// For reflect mode, padding must be less than the corresponding dimension
assert!(
top < dims[D - 2] && bottom < dims[D - 2],
"Reflect padding on height ({}, {}) must be less than height dimension ({})",
top,
bottom,
dims[D - 2]
);
assert!(
left < dims[D - 1] && right < dims[D - 1],
"Reflect padding on width ({}, {}) must be less than width dimension ({})",
left,
right,
dims[D - 1]
);
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
assert!(
before < dims[i] && after < dims[i],
"Reflect padding ({}, {}) must be less than dimension {} size ({})",
before,
after,
i,
dims[i]
);
}
}
let mut result = tensor;
// Pad height dimension (D - 2): top and bottom
if top > 0 || bottom > 0 {
result = pad_reflect_dim(result, D - 2, top, bottom);
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
result = pad_reflect_dim(result, i, before, after);
}
}
// Pad width dimension (D - 1): left and right
if left > 0 || right > 0 {
result = pad_reflect_dim(result, D - 1, left, right);
}
result

@@ -241,5 +298,5 @@ }

/// Example: `[1, 2, 3, 4]` with left padding 2 becomes `[1, 1, 1, 2, 3, 4]`
pub fn pad_edge<B, const D: usize, K>(
fn pad_edge<B, const D: usize, K>(
tensor: Tensor<B, D, K>,
padding: (usize, usize, usize, usize),
padding: &[(usize, usize); D],
) -> Tensor<B, D, K>

@@ -251,31 +308,22 @@ where

{
let (left, right, top, bottom) = padding;
let dims = tensor.dims();
// Validate dimensions are non-zero when padding is requested
if top > 0 || bottom > 0 {
assert!(
dims[D - 2] > 0,
"Cannot apply edge padding to zero-sized height dimension"
);
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
assert!(
dims[i] > 0,
"Cannot apply edge padding to zero-sized dimension {}",
i
);
}
}
if left > 0 || right > 0 {
assert!(
dims[D - 1] > 0,
"Cannot apply edge padding to zero-sized width dimension"
);
}
let mut result = tensor;
// Pad height dimension (D - 2): top and bottom
if top > 0 || bottom > 0 {
result = pad_edge_dim(result, D - 2, top, bottom);
for (i, &(before, after)) in padding.iter().enumerate() {
if before > 0 || after > 0 {
result = pad_edge_dim(result, i, before, after);
}
}
// Pad width dimension (D - 1): left and right
if left > 0 || right > 0 {
result = pad_edge_dim(result, D - 1, left, right);
}
result

@@ -282,0 +330,0 @@ }

@@ -7,4 +7,4 @@ use crate::{

ops::{
ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode, PaddedConvOptions,
UnfoldOptions,
AttentionModuleOptions, ConvOptions, ConvTransposeOptions, InterpolateOptions, PadMode,
PaddedConvOptions, UnfoldOptions,
},

@@ -498,4 +498,5 @@ };

/// Computes scaled dot-product attention: softmax(QKᵗ / √d) · V,
/// optionally applying a 4D mask to the attention scores.
/// Computes scaled dot-product attention: softmax(QKᵗ * scale) · V,
/// where scale defaults to 1/sqrt(head_dim) (configurable via `options.scale`).
/// Optionally applies masking, additive bias, causal masking, and softcap.
///

@@ -505,8 +506,11 @@ /// # Arguments

/// - `key`: Key tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, head_dim]`
/// - `value`: Value tensor of shape `[batch_size, num_heads, seq_len_k, val_dim]`
/// - `mask`: Optional boolean mask of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`,
/// where `true` indicates positions to mask (i.e. set to -∞ before softmax).
/// where `true` indicates positions to mask (i.e. set to -inf before softmax).
/// - `attn_bias`: Optional float tensor of shape `[batch_size, num_heads, seq_len_q, seq_len_k]`
/// added to the attention scores before softmax (e.g. ALiBi, relative position biases).
/// - `options`: Additional attention options (custom scale, softcap, causal masking).
///
/// # Returns
/// A tensor of shape `[batch_size, num_heads, seq_len_q, head_dim]`
/// A tensor of shape `[batch_size, num_heads, seq_len_q, val_dim]`
/// representing the attended context per head.

@@ -522,2 +526,4 @@ ///

mask: Option<Tensor<B, 4, Bool>>,
attn_bias: Option<Tensor<B, 4>>,
options: AttentionModuleOptions,
) -> Tensor<B, 4> {

@@ -529,7 +535,9 @@ Tensor::new(TensorPrimitive::Float(B::attention(

mask.map(|mask| mask.primitive),
attn_bias.map(|bias| bias.primitive.tensor()),
options,
)))
}
/// Exports naive attention to test backend's attention against
pub fn naive_attention<B: Backend>(
/// Exports attention fallback to test backend's attention against.
pub fn attention_fallback<B: Backend>(
query: Tensor<B, 4>,

@@ -539,5 +547,7 @@ key: Tensor<B, 4>,

mask: Option<Tensor<B, 4, Bool>>,
attn_bias: Option<Tensor<B, 4>>,
options: AttentionModuleOptions,
) -> Tensor<B, 4> {
Tensor::new(TensorPrimitive::Float(
crate::ops::attention::naive_attention::<B>(
crate::ops::attention::attention_fallback::<B>(
query.primitive.tensor(),

@@ -547,4 +557,6 @@ key.primitive.tensor(),

mask.map(|mask| mask.primitive),
attn_bias.map(|bias| bias.primitive.tensor()),
options,
),
))
}

@@ -35,3 +35,3 @@ use alloc::format;

"NamedTensor[shape={:?}, dims={}]",
self.shape().dims,
self.shape(),
ND::to_string(),

@@ -38,0 +38,0 @@ ))

@@ -14,3 +14,3 @@ use crate::{Tensor, backend::Backend};

) -> Tensor<B, D> {
let n = tensor.shape().dims[dim] - 1;
let n = tensor.shape()[dim] - 1;
var_with_mean_n(tensor, mean, dim, n)

@@ -29,3 +29,3 @@ }

) -> Tensor<B, D> {
let n = tensor.shape().dims[dim];
let n = tensor.shape()[dim];
var_with_mean_n(tensor, mean, dim, n)

@@ -32,0 +32,0 @@ }

Sorry, the diff of this file is too big to display

Sorry, the diff of this file is not supported yet

Sorry, the diff of this file is too big to display